Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,22 @@ async def _experimental_fork_checkpoint(

shutil.copytree(source_checkpoint_dir, dest_checkpoint_dir)

# Ensure the trainer picks up the forked LoRA weights.
# 1. Invalidate the _state cache so create_unsloth_train_context re-initializes
# with the forked checkpoint path.
# 2. Store the forked checkpoint path so the first training call can explicitly
# load the adapter weights via load_lora_adapter. This is necessary because
# from_pretrained may set up the LoRA architecture without loading the actual
# trained weights.
service = await self._get_service(cast(TrainableModel, model))
if hasattr(service, "_state") and "_state" in service.__dict__:
del service.__dict__["_state"]
if verbose:
print(
"Invalidated UnslothService _state cache to pick up forked checkpoint"
)
service._forked_checkpoint_dir = dest_checkpoint_dir # type: ignore[union-attr]

if verbose:
print(
f"Successfully forked checkpoint from {from_model} (step {selected_step}) to {model.name}"
Expand Down
84 changes: 84 additions & 0 deletions src/art/unsloth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,77 @@ def reload_to_gpu(self, device: str = "cuda:0") -> None:

self._is_offloaded = False

    async def load_lora_adapter(self, lora_path: str) -> None:
        """Load LoRA adapter weights from a checkpoint directory into the peft model.

        Used by fork_checkpoint to explicitly replace the adapter weights after
        from_pretrained may have initialized fresh LoRA layers instead of loading
        the forked weights (e.g. across precision mismatches).

        Args:
            lora_path: Checkpoint directory expected to contain either
                ``adapter_model.safetensors`` or ``adapter_model.bin``.

        Raises:
            RuntimeError: If the adapter weights cannot be read from disk, or
                cannot be applied to ``self.peft_model`` (the original cause is
                chained via ``from exc``).
        """
        # Best-effort: drain any pending training results so we do not swap
        # weights under an in-flight batch. Swallowing all exceptions is
        # intentional — the queue may not be active at fork time.
        try:
            await self.results_queue.join()
        except Exception:
            pass
        # Best-effort: wait for outstanding CUDA kernels before mutating
        # parameters; may raise when CUDA is unavailable, which is fine.
        try:
            torch.cuda.synchronize()
        except Exception:
            pass

        import importlib

        # safetensors is treated as optional here: if the import fails we fall
        # back to the pickle-based .bin path below.
        try:
            load_safetensors = importlib.import_module("safetensors.torch").load_file
        except Exception:
            load_safetensors = None  # type: ignore[assignment]

        state_dict = None
        st_path = os.path.join(lora_path, "adapter_model.safetensors")
        bin_path = os.path.join(lora_path, "adapter_model.bin")
        try:
            # Prefer the safetensors file; otherwise fall back to the .bin
            # pickle. Tensors are loaded to CPU — device placement is assumed
            # to be handled when the state dict is applied (TODO confirm).
            if os.path.exists(st_path) and load_safetensors is not None:
                state_dict = load_safetensors(st_path, device="cpu")
            elif os.path.exists(bin_path):
                # NOTE(review): torch.load on a .bin is a pickle load — only
                # safe because checkpoints here are produced by this backend.
                state_dict = torch.load(bin_path, map_location="cpu")  # type: ignore[call-arg]
            else:
                raise FileNotFoundError(f"No adapter weights found in {lora_path}")
        except Exception as exc:
            # Normalize all read failures (missing file, corrupt tensor, …)
            # into a single error type for the caller.
            raise RuntimeError(f"Failed to load LoRA adapter weights: {exc}") from exc

        with torch.no_grad():
            # Discard stale gradients AND optimizer state: momentum/variance
            # accumulated for the pre-fork weights would corrupt the first
            # post-fork optimizer step.
            self.peft_model.zero_grad(set_to_none=True)
            optimizer = getattr(self.trainer, "optimizer", None)
            if optimizer is not None:
                # The trainer's optimizer may wrap the real one; unwrap one
                # level if an inner `.optimizer` exists, else keep as-is.
                optimizer = getattr(optimizer, "optimizer", optimizer)
                if hasattr(optimizer, "zero_grad"):
                    optimizer.zero_grad(set_to_none=True)  # type: ignore[arg-type]
                if hasattr(optimizer, "state") and isinstance(optimizer.state, dict):
                    optimizer.state.clear()

        try:
            # The import location of set_peft_model_state_dict differs across
            # peft versions; try the submodule path first, then the package
            # top level.
            try:
                from peft.utils.save_and_load import (
                    set_peft_model_state_dict as _set_peft_model_state_dict,
                )
            except Exception:
                from peft import (
                    set_peft_model_state_dict as _set_peft_model_state_dict,  # type: ignore
                )

            # Apply the forked weights onto the currently active adapter
            # ("default" when the model exposes no active_adapter attribute).
            active_adapter = getattr(self.peft_model, "active_adapter", "default")
            _set_peft_model_state_dict(
                self.peft_model,
                state_dict,
                adapter_name=active_adapter,
            )
            # Re-select the adapter so subsequent forward passes use it.
            self.peft_model.set_adapter(active_adapter)
        except Exception as exc:
            raise RuntimeError(f"Failed to set LoRA weights in-place: {exc}") from exc

        # Best-effort barrier so the new weights are fully materialized before
        # the caller resumes training.
        try:
            torch.cuda.synchronize()
        except Exception:
            pass


# ============================================================================
# Service
Expand All @@ -319,6 +390,7 @@ class UnslothService:
_is_sleeping: bool = False
_last_training_mode: Literal["sft", "rl"] | None = None
_latest_step: int = 0
_forked_checkpoint_dir: str | None = None
_lora_id_counter: int = 1 # Start from 1 since 0 is reserved
# Dedicated mode subprocess state
_vllm_process: subprocess.Popen | None = field(default=None, repr=False) # type: ignore[type-arg]
Expand Down Expand Up @@ -612,6 +684,12 @@ async def _train_dedicated(
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
"""Train in dedicated mode — no sleep/wake, vLLM keeps running on separate GPU."""
# Load forked adapter weights on first training call if needed.
forked_dir = getattr(self, "_forked_checkpoint_dir", None)
if forked_dir is not None:
self._forked_checkpoint_dir = None
await self._state.load_lora_adapter(forked_dir)

self._reset_optimizer_if_mode_changed("rl")
optimizer = _get_trainer_optimizer(self._state.trainer)

Expand Down Expand Up @@ -673,6 +751,12 @@ async def _train_shared(
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
"""Train in shared mode — sleep/wake cycle with in-process vLLM."""
# Load forked adapter weights on first training call if needed.
forked_dir = getattr(self, "_forked_checkpoint_dir", None)
if forked_dir is not None:
self._forked_checkpoint_dir = None
await self._state.load_lora_adapter(forked_dir)

llm = await self.llm

# Pause generation to prevent new requests during training
Expand Down