Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions scripts/run-qwen3.5-35B-A3B-sft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ WANDB_ARGS=(
# --wandb-group qwen3.5-35B-sft
)

TB_ARGS=(
--use-tensorboard
--tb-project-name qwen3.5-35B-A3B-sft
)

ENTROPY_ARGS=(
--log-sft-entropy
)

VAL_ARGS=(
# Uncomment to enable val loss monitoring (val-batch-size defaults to 64, val-input-key defaults to "messages")
# --val-data ${BASE_FOLDER}/val_data.jsonl
# --val-interval 10
)

MISC_ARGS=(
# default dropout in megatron is 0.1
--attention-dropout 0.0
Expand Down Expand Up @@ -157,6 +172,9 @@ ray job submit --address="http://127.0.0.1:8265" \
${SFT_ARGS[@]} \
${OPTIMIZER_ARGS[@]} \
${WANDB_ARGS[@]} \
${TB_ARGS[@]} \
${ENTROPY_ARGS[@]} \
${VAL_ARGS[@]} \
${PERF_ARGS[@]} \
${EVAL_ARGS[@]} \
${MISC_ARGS[@]} \
Expand Down
32 changes: 32 additions & 0 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,38 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data

log_perf_data(rollout_id, self.args)

def compute_val_loss(self, rollout_id: int) -> None:
"""Compute validation loss with full DP coordination.

Called periodically by train_async.py (controlled by --val-interval).
Each DP rank independently tokenizes its shard of val data and runs
forward-only; results are gathered and logged on the source rank.
"""
if self.args.debug_rollout_only:
return

if not getattr(self.args, "val_data", None):
return

from .val_loss import ValDataLoader, compute_val_loss

# Lazy-initialize val data loader (each rank gets its own shard)
if not hasattr(self, "_val_data_loader"):
self._val_data_loader = ValDataLoader(
self.args,
dp_rank=mpu.get_data_parallel_rank(with_context_parallel=False),
dp_size=mpu.get_data_parallel_world_size(with_context_parallel=False),
)

if self.args.offload_train:
self.wake_up()

with torch.no_grad():
compute_val_loss(self.args, self.model, self._val_data_loader, rollout_id)

if self.args.offload_train:
self.sleep()

@timer
def save_model(self, rollout_id: int, force_sync: bool = False) -> None:
if self.args.debug_rollout_only:
Expand Down
45 changes: 36 additions & 9 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,13 @@ def sft_loss_function(
"""Compute supervised fine-tuning loss over response tokens.

Computes log-probabilities of the ground-truth tokens in the response
segments and returns the negative log-likelihood as the loss.
segments and returns the negative log-likelihood as the loss. Optionally
computes and logs token-level entropy when ``args.log_sft_entropy`` is set.

Entropy is computed under ``torch.no_grad()`` since it is only used for
logging and does not participate in the loss. This avoids retaining the
extra ``[T, V]`` clone in the autograd graph which would cause OOM for
large-vocabulary models.

Args:
args: Configuration (passed through to helpers).
Expand All @@ -1097,12 +1103,15 @@ def sft_loss_function(
sum_of_sample_mean: Reduction function that averages per-sample values.

Returns:
Tuple of `(loss, metrics)` where `metrics` contains a single detached
scalar "loss".
Tuple of `(loss, metrics)` where `metrics` contains detached scalars
"loss" and optionally "entropy".
"""
response_lengths = batch["response_lengths"]
total_lengths = batch["total_lengths"]

log_entropy = getattr(args, "log_sft_entropy", False)

# Step 1: compute log_probs for loss (with gradient)
_, log_probs_and_entropy = get_log_probs_and_entropy(
logits,
args=args,
Expand All @@ -1121,12 +1130,30 @@ def sft_loss_function(
if log_probs.numel() == 0:
loss += 0 * logits.sum()

return (
loss,
{
"loss": loss.clone().detach(),
},
)
reported_loss = {
"loss": loss.clone().detach(),
}

# Step 2: compute entropy for logging only (no_grad to avoid OOM)
# The logits.clone() inside calculate_log_probs_and_entropy won't be
# retained in the autograd graph, so it's freed after computation.
if log_entropy:
with torch.no_grad():
_, entropy_result = get_log_probs_and_entropy(
logits,
args=args,
unconcat_tokens=batch["unconcat_tokens"],
total_lengths=total_lengths,
response_lengths=response_lengths,
with_entropy=True,
max_seq_lens=batch.get("max_seq_lens", None),
)
entropy = entropy_result["entropy"]
entropy = torch.cat(entropy, dim=0)
mean_entropy = sum_of_sample_mean(entropy)
reported_loss["entropy"] = mean_entropy.detach()

return loss, reported_loss


def loss_function(
Expand Down
Loading