Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions dev/profile.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
" \"enforce_eager\": True,\n",
" \"gpu_memory_utilization\": 0.9,\n",
" },\n",
" \"peft_args\": {\n",
" # \"use_gradient_checkpointing\": False,\n",
" },\n",
" },\n",
")\n",
"state = ModelState(config)"
Expand Down
4 changes: 4 additions & 0 deletions dev/yes-no-maybe-megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ async def main() -> None:
max_tokens = int(os.environ.get("MAX_TOKENS", "100"))
timeout = float(os.environ.get("TIMEOUT", "100"))
learning_rate = float(os.environ.get("LEARNING_RATE", "1e-4"))
lora_rank = os.environ.get("LORA_RANK")
packed_sequence_length = int(
os.environ.get(
"PACKED_SEQUENCE_LENGTH",
Expand All @@ -202,6 +203,9 @@ async def main() -> None:
name=model_name,
project=project,
base_model=base_model,
lora_config=(
art.LoRAConfig(rank=int(lora_rank)) if lora_rank is not None else None
),
report_metrics=[],
_internal_config=build_internal_config(),
)
Expand Down
3 changes: 0 additions & 3 deletions dev/yes-no-maybe.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ async def main():
# # engine_args=art.dev.EngineArgs(
# # max_lora_rank=1,
# # ),
# # peft_args=art.dev.PeftArgs(
# # r=1,
# # ),
# tinker_args=art.dev.TinkerArgs(
# renderer_name="qwen3_instruct",
# training_client_args=art.dev.TinkerTrainingClientArgs(
Expand Down
4 changes: 1 addition & 3 deletions examples/hn_title_generator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,11 @@ async def main():
name=MODEL_NAME,
project=PROJECT,
base_model=BASE_MODEL,
lora_config=art.LoRAConfig(alpha=8),
_internal_config=art.dev.InternalModelConfig(
init_args=art.dev.InitArgs(
gpu_memory_utilization=0.75,
),
peft_args=art.dev.PeftArgs(
lora_alpha=8,
),
trainer_args=art.dev.TrainerArgs(
max_grad_norm=0.1,
),
Expand Down
2 changes: 2 additions & 0 deletions src/art/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from .auto_trajectory import auto_trajectory, capture_auto_trajectory
from .backend import Backend
from .batches import trajectory_group_batches
from .dev import LoRAConfig
from .gather import gather_trajectories, gather_trajectory_groups
from .model import Model, TrainableModel
from .serverless import ServerlessBackend
Expand Down Expand Up @@ -85,6 +86,7 @@
"Backend",
"LocalBackend",
"LocalTrainResult",
"LoRAConfig",
"ServerlessBackend",
"ServerlessTrainResult",
"Messages",
Expand Down
4 changes: 4 additions & 0 deletions src/art/dev/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .engine import EngineArgs
from .model import (
BackendModelConfig,
InitArgs,
InternalModelConfig,
LoRAConfig,
PeftArgs,
TinkerArgs,
TinkerNativeArgs,
Expand All @@ -14,8 +16,10 @@

__all__ = [
"EngineArgs",
"BackendModelConfig",
"InternalModelConfig",
"InitArgs",
"LoRAConfig",
"PeftArgs",
"TinkerArgs",
"TinkerNativeArgs",
Expand Down
25 changes: 14 additions & 11 deletions src/art/dev/get_model_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from ..megatron.model_support import default_target_modules_for_model
from .engine import EngineArgs
from .model import (
PEFT_ARGS_MIGRATION_MESSAGE,
BackendModelConfig,
InitArgs,
InternalModelConfig,
PeftArgs,
LoRAConfig,
TrainerArgs,
)
from .validate import is_dedicated_mode
Expand All @@ -17,11 +19,14 @@ def get_model_config(
base_model: str,
output_dir: str,
config: "InternalModelConfig | None",
) -> "InternalModelConfig":
lora_config: "LoRAConfig | None" = None,
) -> "BackendModelConfig":
from ..local.checkpoints import get_last_checkpoint_dir

if config is None:
config = InternalModelConfig()
if "peft_args" in config:
raise ValueError(PEFT_ARGS_MIGRATION_MESSAGE)

dedicated = is_dedicated_mode(config)
rollout_weights_mode = config.get("rollout_weights_mode", "lora")
Expand All @@ -36,7 +41,6 @@ def get_model_config(
max_seq_length=32768,
model_name=base_model,
)
target_modules = default_target_modules(base_model)
engine_args = EngineArgs(
allowed_local_media_path="/tmp",
enable_sleep_mode=enable_sleep_mode,
Expand All @@ -47,18 +51,17 @@ def get_model_config(
init_args.update(config.get("init_args", {}))
if last_checkpoint_dir := get_last_checkpoint_dir(output_dir):
init_args["model_name"] = last_checkpoint_dir
peft_args = PeftArgs(
lora_alpha=16,
r=8,
merged_lora_config = LoRAConfig(
random_state=3407,
target_modules=target_modules,
target_modules=default_target_modules(base_model),
use_gradient_checkpointing="unsloth",
)
peft_args.update(config.get("peft_args", {}))
if lora_config:
merged_lora_config.update(lora_config)
if rollout_weights_mode == "lora" and "lora_target_modules" not in config.get(
"engine_args", {}
):
engine_args["lora_target_modules"] = peft_args["target_modules"]
engine_args["lora_target_modules"] = merged_lora_config["target_modules"]
trainer_args = TrainerArgs(
adam_beta1=0.9,
adam_beta2=0.99,
Expand All @@ -78,10 +81,10 @@ def get_model_config(
weight_decay=0.1,
)
trainer_args.update(config.get("trainer_args", {}))
result = InternalModelConfig(
result = BackendModelConfig(
init_args=init_args,
engine_args=engine_args,
peft_args=peft_args,
lora_config=merged_lora_config,
rollout_weights_mode=rollout_weights_mode,
tinker_args=config.get("tinker_args"),
trainer_args=trainer_args,
Expand Down
25 changes: 17 additions & 8 deletions src/art/dev/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Literal
from typing import Literal, NoReturn

from typing_extensions import Required, TypedDict

Expand Down Expand Up @@ -115,7 +115,6 @@ class InternalModelConfig(TypedDict, total=False):
Args:
init: Arguments for initializing an Unsloth FastLanguageModel.
engine: Arguments for the vLLM engine.
peft: Arguments for creating an Unsloth PEFT model wrapper.
tinker: Arguments for the Tinker training client.
trainer: Arguments for the GRPO trainer.
trainer_gpu_ids: GPU IDs for training (e.g., [0]). When set with
Expand All @@ -142,7 +141,6 @@ class InternalModelConfig(TypedDict, total=False):

init_args: "InitArgs"
engine_args: "EngineArgs"
peft_args: "PeftArgs"
tinker_args: "TinkerArgs | None"
tinker_native_args: "TinkerNativeArgs | None"
trainer_args: "TrainerArgs"
Expand All @@ -157,6 +155,10 @@ class InternalModelConfig(TypedDict, total=False):
allow_unvalidated_arch: bool


class BackendModelConfig(InternalModelConfig, total=False):
    """``InternalModelConfig`` extended with a ``lora_config`` entry.

    Declared with ``total=False``, so ``lora_config`` is an optional key
    from the type checker's point of view.
    """

    # LoRA adapter settings carried alongside the inherited config keys.
    lora_config: "LoRAConfig"


class TinkerArgs(TypedDict, total=False):
renderer_name: Required[str]
training_client_args: "TinkerTrainingClientArgs"
Expand Down Expand Up @@ -203,11 +205,11 @@ class InitArgs(TypedDict, total=False):
use_async: bool


class PeftArgs(TypedDict, total=False):
r: int
class LoRAConfig(TypedDict, total=False):
rank: int
target_modules: list[str]
lora_alpha: int
lora_dropout: int
alpha: int
dropout: float
bias: str
layers_to_transform: list[int] | None
layers_pattern: str | None
Expand All @@ -216,11 +218,18 @@ class PeftArgs(TypedDict, total=False):
max_seq_length: int
use_rslora: bool
modules_to_save: list[str] | None
init_lora_weights: bool
init_weights: bool
loftq_config: dict
temporary_location: str


# Error text shown when the removed `peft_args` option is used; it lists the
# key renames and the keys that keep their name under `lora_config`.
PEFT_ARGS_MIGRATION_MESSAGE = (
    "`peft_args` has been replaced by top-level "
    "`TrainableModel(lora_config=...)`. "
    "Rename keys: r->rank, lora_alpha->alpha, lora_dropout->dropout, "
    "init_lora_weights->init_weights. "
    "Keep these keys under lora_config: target_modules, bias, "
    "layers_to_transform, layers_pattern, use_gradient_checkpointing, "
    "random_state, max_seq_length, use_rslora, modules_to_save, "
    "loftq_config, temporary_location."
)


def PeftArgs(*_: object, **__: object) -> NoReturn:
    """Reject any use of the removed ``PeftArgs`` constructor.

    Accepts and ignores arbitrary positional and keyword arguments so that
    every call shape reaches the same error.

    Raises:
        ValueError: Always, carrying ``PEFT_ARGS_MIGRATION_MESSAGE``.
    """
    migration_error = ValueError(PEFT_ARGS_MIGRATION_MESSAGE)
    raise migration_error


class TrainerArgs(TypedDict, total=False):
output_dir: str | None
overwrite_output_dir: bool
Expand Down
1 change: 1 addition & 0 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
base_model=model.base_model,
output_dir=get_model_dir(model=model, art_path=self._path),
config=model._internal_config,
lora_config=model.lora_config,
)
validate_dedicated_config(config)
dedicated = is_dedicated_mode(config)
Expand Down
13 changes: 8 additions & 5 deletions src/art/megatron/lora.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
import math
from typing import Any, Literal, cast

Expand Down Expand Up @@ -1376,17 +1376,20 @@ def wrap_shared_experts_mlp(
def apply_lora_adapters(
    model: Sequence[torch.nn.Module],
    provider: GPTModelProvider,
    lora_config: Mapping[str, Any] | None = None,
) -> list[torch.nn.Module]:
    """Apply LoRA adapters to ``model`` through the provider's support handler.

    Args:
        model: Modules handed to the handler's ``apply_lora_adapters``.
        provider: Provider exposing ``_art_model_support_handler`` and
            ``_art_model_support_spec``.
        lora_config: Optional overrides. The keys read here are
            ``target_modules``, ``rank`` and ``alpha``; a missing key falls
            back to ``spec.default_target_modules``,
            ``default_lora_rank_for_handler(handler)`` and ``LORA_ALPHA``
            respectively. ``None`` means "use every default".

    Returns:
        ``list(model)``, taken after the handler call.
    """
    # Treat a missing config as empty so every lookup below can use .get().
    overrides: Mapping[str, Any] = lora_config or {}
    # The support attributes are attached dynamically, so drop to Any here.
    provider = cast(Any, provider)
    handler = provider._art_model_support_handler
    spec = provider._art_model_support_spec
    # Copy into a fresh list so neither the caller's config nor the spec's
    # default sequence is shared with the handler.
    target_modules = list(
        overrides.get("target_modules", spec.default_target_modules)
    )
    # Each setting is passed exactly once: caller override first, otherwise
    # the handler/module default.
    handler.apply_lora_adapters(
        model,
        provider,
        target_modules=target_modules,
        rank=overrides.get("rank", default_lora_rank_for_handler(handler)),
        alpha=overrides.get("alpha", LORA_ALPHA),
    )
    return list(model)
1 change: 1 addition & 0 deletions src/art/megatron/runtime/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
base_model=model.base_model,
output_dir=get_model_dir(model=model, art_path=self._path),
config=model._internal_config,
lora_config=model.lora_config,
)
self._services[model.name] = MegatronService(
model_name=model.name,
Expand Down
Loading
Loading