Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 88 additions & 5 deletions moss_tts_delay/finetuning/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--max-train-steps", type=int, default=None)
parser.add_argument("--max-grad-norm", type=float, default=1.0)
parser.add_argument("--logging-steps", type=int, default=1)
parser.add_argument(
"--wandb-project",
type=str,
default=None,
help="If set, log metrics to Weights & Biases (main process only). Requires: pip install wandb",
)
parser.add_argument("--wandb-run-name", type=str, default=None)
parser.add_argument("--wandb-entity", type=str, default=None)
parser.add_argument(
"--wandb-tags",
type=str,
default=None,
help="Comma-separated tags for the W&B run.",
)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--mixed-precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
parser.add_argument("--attn-implementation", type=str, default="auto")
Expand Down Expand Up @@ -550,10 +564,6 @@ def main() -> None:
if accelerator.is_main_process:
output_root.mkdir(parents=True, exist_ok=True)

global_step = 0
completed_epochs = 0
last_log_time = time.perf_counter()
last_logged_step = 0
train_args_to_save = vars(args).copy()
train_args_to_save["global_batch_size"] = global_batch_size
train_args_to_save["global_batch_size_formula"] = global_batch_formula
Expand All @@ -562,6 +572,66 @@ def main() -> None:
train_args_to_save["resolved_warmup_steps"] = warmup_steps
train_args_to_save["resolved_channelwise_loss_weight"] = resolved_channelwise_loss_weight

wandb_module = None
if args.wandb_project and accelerator.is_main_process:
try:
import wandb
except ImportError as exc:
raise ImportError(
"wandb is not installed. Install with: pip install wandb"
) from exc
init_kwargs: Dict[str, Any] = {
"project": args.wandb_project,
"config": train_args_to_save,
}
if args.wandb_entity:
init_kwargs["entity"] = args.wandb_entity
if args.wandb_run_name:
init_kwargs["name"] = args.wandb_run_name
if args.wandb_tags:
init_kwargs["tags"] = [t.strip() for t in args.wandb_tags.split(",") if t.strip()]
wandb.init(**init_kwargs)
wandb_module = wandb

try:
_training_loop(
accelerator,
args,
model,
train_dataloader,
optimizer,
lr_scheduler,
resolved_channelwise_loss_weight,
max_train_steps,
global_batch_size,
output_root,
train_args_to_save,
wandb_module,
)
finally:
if wandb_module is not None:
wandb_module.finish()


def _training_loop(
accelerator: Accelerator,
args: argparse.Namespace,
model: MossTTSDelayModel,
train_dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
lr_scheduler: Any,
resolved_channelwise_loss_weight: Optional[List[float]],
max_train_steps: int,
global_batch_size: int,
output_root: Path,
train_args_to_save: Dict[str, Any],
wandb_module: Optional[Any],
) -> None:
global_step = 0
completed_epochs = 0
last_log_time = time.perf_counter()
last_logged_step = 0

for epoch in range(args.num_epochs):
model.train()
for batch in train_dataloader:
Expand Down Expand Up @@ -595,16 +665,29 @@ def main() -> None:
samples_per_sec = (global_batch_size * steps_since_last_log) / elapsed
eta_seconds = max(max_train_steps - global_step, 0) / steps_per_sec
logged_loss = accelerator.gather(loss.detach().float().reshape(1)).mean().item()
lr_val = lr_scheduler.get_last_lr()[0]
accelerator.print(
f"[{format_timestamp()}] "
f"epoch={epoch} step={global_step}/{max_train_steps} "
f"loss={logged_loss:.4f} "
f"lr={lr_scheduler.get_last_lr()[0]:.2e} "
f"lr={lr_val:.2e} "
f"step_time={step_time:.2f}s "
f"steps_per_sec={steps_per_sec:.3f} "
f"samples_per_sec={samples_per_sec:.2f} "
f"eta={format_duration(eta_seconds)}"
)
if wandb_module is not None:
wandb_module.log(
{
"train/loss": logged_loss,
"train/lr": lr_val,
"train/step_time": step_time,
"train/steps_per_sec": steps_per_sec,
"train/samples_per_sec": samples_per_sec,
"train/epoch": epoch,
},
step=global_step,
)

if global_step >= max_train_steps:
break
Expand Down
93 changes: 88 additions & 5 deletions moss_tts_local/finetuning/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--max-train-steps", type=int, default=None)
parser.add_argument("--max-grad-norm", type=float, default=1.0)
parser.add_argument("--logging-steps", type=int, default=1)
parser.add_argument(
"--wandb-project",
type=str,
default=None,
help="If set, log metrics to Weights & Biases (main process only). Requires: pip install wandb",
)
parser.add_argument("--wandb-run-name", type=str, default=None)
parser.add_argument("--wandb-entity", type=str, default=None)
parser.add_argument(
"--wandb-tags",
type=str,
default=None,
help="Comma-separated tags for the W&B run.",
)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--mixed-precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
parser.add_argument("--attn-implementation", type=str, default="auto")
Expand Down Expand Up @@ -554,10 +568,6 @@ def main() -> None:
if accelerator.is_main_process:
output_root.mkdir(parents=True, exist_ok=True)

global_step = 0
completed_epochs = 0
last_log_time = time.perf_counter()
last_logged_step = 0
train_args_to_save = vars(args).copy()
train_args_to_save["global_batch_size"] = global_batch_size
train_args_to_save["global_batch_size_formula"] = global_batch_formula
Expand All @@ -566,6 +576,66 @@ def main() -> None:
train_args_to_save["resolved_warmup_steps"] = warmup_steps
train_args_to_save["resolved_channelwise_loss_weight"] = resolved_channelwise_loss_weight

wandb_module = None
if args.wandb_project and accelerator.is_main_process:
try:
import wandb
except ImportError as exc:
raise ImportError(
"wandb is not installed. Install with: pip install wandb"
) from exc
init_kwargs: Dict[str, Any] = {
"project": args.wandb_project,
"config": train_args_to_save,
}
if args.wandb_entity:
init_kwargs["entity"] = args.wandb_entity
if args.wandb_run_name:
init_kwargs["name"] = args.wandb_run_name
if args.wandb_tags:
init_kwargs["tags"] = [t.strip() for t in args.wandb_tags.split(",") if t.strip()]
wandb.init(**init_kwargs)
wandb_module = wandb

try:
_training_loop(
accelerator,
args,
model,
train_dataloader,
optimizer,
lr_scheduler,
resolved_channelwise_loss_weight,
max_train_steps,
global_batch_size,
output_root,
train_args_to_save,
wandb_module,
)
finally:
if wandb_module is not None:
wandb_module.finish()


def _training_loop(
accelerator: Accelerator,
args: argparse.Namespace,
model: MossTTSDelayModel,
train_dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
lr_scheduler: Any,
resolved_channelwise_loss_weight: Optional[List[float]],
max_train_steps: int,
global_batch_size: int,
output_root: Path,
train_args_to_save: Dict[str, Any],
wandb_module: Optional[Any],
) -> None:
global_step = 0
completed_epochs = 0
last_log_time = time.perf_counter()
last_logged_step = 0

for epoch in range(args.num_epochs):
model.train()
for batch in train_dataloader:
Expand Down Expand Up @@ -599,16 +669,29 @@ def main() -> None:
samples_per_sec = (global_batch_size * steps_since_last_log) / elapsed
eta_seconds = max(max_train_steps - global_step, 0) / steps_per_sec
logged_loss = accelerator.gather(loss.detach().float().reshape(1)).mean().item()
lr_val = lr_scheduler.get_last_lr()[0]
accelerator.print(
f"[{format_timestamp()}] "
f"epoch={epoch} step={global_step}/{max_train_steps} "
f"loss={logged_loss:.4f} "
f"lr={lr_scheduler.get_last_lr()[0]:.2e} "
f"lr={lr_val:.2e} "
f"step_time={step_time:.2f}s "
f"steps_per_sec={steps_per_sec:.3f} "
f"samples_per_sec={samples_per_sec:.2f} "
f"eta={format_duration(eta_seconds)}"
)
if wandb_module is not None:
wandb_module.log(
{
"train/loss": logged_loss,
"train/lr": lr_val,
"train/step_time": step_time,
"train/steps_per_sec": steps_per_sec,
"train/samples_per_sec": samples_per_sec,
"train/epoch": epoch,
},
step=global_step,
)

if global_step >= max_train_steps:
break
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ dependencies = [
flash-attn = ["flash-attn"]
finetune = [
"accelerate>=1.10.1",
"wandb>=0.16.0",
]
finetune-deepspeed = [
"accelerate>=1.10.1",
"deepspeed>=0.16.0",
"wandb>=0.16.0",
]

# Default PyTorch runtime stack for the original pipeline.
Expand Down