gspo: GSPO loss + DeepSpeed parity fixes (loss/grad divisors, SDP, fp32_lm_head, docs_per_step, temperature) #502
Conversation
Adds GRPO metrics parity with DeepSpeed: old_logprobs, ratio, ratio_sum, ratio_sq_sum, kl_new_old, clamp_frac, advantage, max/min_advantage, num_tokens, and optional per-token entropy.

New files:
- fast_llm/layers/language_model/loss/pg_metrics.py: reusable PolicyGradientMetrics dataclass + compute_policy_gradient_metrics() (callable by future PPO), with chunked vocab-parallel entropy support.
- tests/layers/test_grpo_metrics.py: 8 unit tests covering single-seq, packed multi-seq, masked tokens, clamp fraction, entropy correctness, mock SDP correctness, mock vocab-parallel entropy, and normalization parity.

Config additions to LanguageModelGRPOLossConfig:
- compute_extra_metrics (default False): log all non-entropy metrics
- compute_entropy_metric (default False): additionally log per-token entropy
- entropy_chunk_size (default 4096): batch chunk size for the entropy pass

Normalization matches the existing new_logprobs_mean: sum(v * mask / label_counts), then divided by num_documents_in_batch. MAX/MIN use LossDef's ReductionType and the correct ReduceOp so they aggregate correctly across microbatches and SDP/sequence-parallel ranks.
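The normalization scheme described above can be sketched in plain Python. The function name and argument names are illustrative, not the fast-LLM API: each token carries the label count of its document, so the per-token division by `label_counts` turns the sum into a per-document mean before dividing by the document count.

```python
def normalized_metric(values, loss_mask, label_counts, num_documents_in_batch):
    # sum(v * mask / label_counts), then divide by the number of documents,
    # matching the existing new_logprobs_mean normalization.
    total = sum(v * m / c for v, m, c in zip(values, loss_mask, label_counts))
    return total / num_documents_in_batch
```

With two documents (one with 2 labeled tokens, one with 1), each document contributes its token mean, and the result is the mean over documents.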
Rename four metrics to match DeepSpeed's naming exactly so runs on both backends produce comparable WandB keys:
- ratio → ratio_new_old
- ratio_sum → ratio_new_old_sum
- ratio_sq_sum → ratio_new_old_squared_sum
- clamp_frac → clamp_log_ratio_new_old_indicator
Implements GSPO (geometric-mean sequence-level policy-gradient loss) as an alternative to the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo".

Key changes:
- data pipeline: expose per-token document_index when return_document_index=True
- LanguageModelKwargs.document_index: new kwarg constant
- LanguageModelLoss: store the SDP dim for cross-rank segment aggregation
- grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across SDP ranks before computing segment-level R_s and A_s; the gradient derivation exploits tok_count cancellation so every token in a segment gets the same gradient factor R_s * clip_indicator_s
- tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed, ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff, per-token metrics unchanged)
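A minimal single-rank sketch of the segment-level objective: group per-token log-ratios by document index, take the geometric-mean importance ratio per segment, and apply PPO-style clipping at the sequence level. Names are illustrative; the real fused kernel operates on tensors and first all-reduces (lrn_sum, adv_sum, tok_count) across SDP ranks.

```python
import math

def gspo_segment_losses(new_logprobs, old_logprobs, segment_ids, advantages,
                        eps_low=0.2, eps_high=0.2):
    # Group per-token log-ratios by document (segment) index.
    log_ratios = {}
    for nl, ol, s in zip(new_logprobs, old_logprobs, segment_ids):
        log_ratios.setdefault(s, []).append(nl - ol)
    losses = {}
    for s, lr in log_ratios.items():
        # Geometric-mean importance ratio: R_s = exp(mean log-ratio).
        r = math.exp(sum(lr) / len(lr))
        r_clip = min(max(r, 1.0 - eps_low), 1.0 + eps_high)
        a = advantages[s]
        # Sequence-level clipped objective; every token of segment s
        # shares the same gradient factor.
        losses[s] = -min(r * a, r_clip * a)
    return losses
```

With new == old log-probs the ratio is exactly 1, so the loss reduces to the negated advantage, matching the ratio-1 equivalence test in the commit.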
Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict
computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel ×
breadth_first_micro_batches) before sub-configs are created (and frozen).
Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1
each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8
gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally.
YAML usage:
schedule:
  rollouts_per_step: 1024  # replaces manual depth_first_micro_batches
model:
  distributed:
    data_parallel: 8  # used for the division
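The division can be sketched as a hypothetical helper mirroring the arithmetic above (not the TrainerConfig._from_dict code itself):

```python
def depth_first_micro_batches(rollouts_per_step, batch_data_parallel,
                              breadth_first_micro_batches=1):
    # Each microbatch holds one rollout, so dividing the global rollout
    # target by (data parallelism x breadth-first microbatches) yields the
    # number of depth-first (accumulation) microbatches per rank.
    denominator = batch_data_parallel * breadth_first_micro_batches
    assert rollouts_per_step % denominator == 0
    return rollouts_per_step // denominator
```

For the YAML above: 1024 rollouts over 8 data-parallel ranks gives 128 depth-first microbatches per rank.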
- Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first
is now determined at runtime rather than statically in _from_dict
- Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs}
properties so per-step schedules share the same config object as the runner
- Add Trainer._prefetch_to_doc_target: fetches microbatches one at a time,
all-reduces doc count per microbatch, stops when global total ≥ docs_per_step,
then resets num_documents_in_batch to the step total on all inputs
- Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with
_depth_first_override=N//breadth_first_micro_batches
- Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True
both GRPO and GSPO paths divide by num_documents_in_batch instead of
num_labels_in_batch (matches DeepSpeed's per-rollout normalization)
- Add tests/layers/test_docs_per_step.py: 13 unit tests covering divisor
scaling, normalize_by_documents layer routing, Schedule._eff_* properties,
and _prefetch_to_doc_target accumulation logic
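The accumulation logic of _prefetch_to_doc_target might look roughly like this pure-Python sketch, with the fetch and all-reduce injected as callables (illustrative names, not the Trainer API):

```python
def prefetch_to_doc_target(fetch_microbatch, all_reduce_sum, docs_per_step):
    # Fetch microbatches one at a time until the *global* document count
    # reaches the target, then stamp the step total onto every microbatch
    # so the normalization denominator is consistent across all inputs.
    batches, total = [], 0
    while total < docs_per_step:
        mb = fetch_microbatch()
        total += all_reduce_sum(mb["num_documents"])
        batches.append(mb)
    for mb in batches:
        mb["num_documents_in_batch"] = total
    return batches, total
```

With one document per microbatch on each of 8 ranks, reaching docs_per_step=1024 takes 128 fetches per rank.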
Add temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probs are computed at the same temperature as the stored old log-probs, so the IS ratio starts near 1.0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted for logits_scale_factor at all three callsites in _forward_backward (GRPO path, GSPO path, _register_pg_metrics). Default temperature=1.0 preserves existing behaviour exactly.
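Why folding the temperature into the logits scale works: log-probs at sampling temperature T are just the log-softmax of logits / T, so scaling the logits by 1/T before the softmax reproduces the actor's distribution. A plain-Python sketch (not the fast-LLM kernel):

```python
import math

def log_softmax(logits, temperature=1.0):
    # Log-probabilities at sampling temperature: log_softmax(logits / T),
    # computed with the usual max-shift for numerical stability.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]
```

When new log-probs are computed at the same T as the stored old log-probs, the importance ratio exp(new − old) starts at 1.0 for identical logits.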
Add fp32_lm_head to LanguageModelHeadConfig. When enabled, input hidden states and output_weights are cast to float32 before the lm_head linear, producing FP32 logits. This matches vLLM's bf16_last_layer_fp32 quantization (pipelinerl/vllm_quantization.py) and the DeepSpeed trainer's apply_fp32_lm_head() patch, so new_logprobs and old_logprobs are computed at the same numerical precision and the IS ratio starts near 1.0 at init. The gradient flowing back through the linear is cast to the original input dtype (bf16) before returning, keeping the transformer backward pass in its native dtype.
…accumulation

Detaching the FP32 weight copy (requires_grad=False) prevents output_parallel_linear_backward from trying to write to a non-existent grad_buffer on the copy. The weight grad is then computed explicitly from the FP32 matmul and accumulated into the original BF16 param's grad_buffer via accumulate_gradient, restoring the correct FSDP gradient contract.
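A minimal torch sketch of the dtype handling described in these two commits. The function name, shapes, and the explicit-gradient formulation are illustrative assumptions, not the fast-LLM kernel: it upcasts to FP32 for the logits matmul, computes the weight gradient explicitly from the FP32 operands, and casts the input gradient back to BF16 for the rest of the backward pass.

```python
import torch

def fp32_lm_head_forward_backward(hidden_bf16, weight_bf16, grad_logits_fp32):
    # Upcast input and a *detached* weight copy to FP32 (the detached copy
    # has no grad_buffer, so the stock backward would skip the weight grad).
    x32 = hidden_bf16.float()
    w32 = weight_bf16.float().detach()
    logits = x32 @ w32.t()  # FP32 logits, matching vLLM bf16_last_layer_fp32
    # Input grad is cast back to the transformer's native dtype (bf16).
    grad_input = (grad_logits_fp32 @ w32).to(hidden_bf16.dtype)
    # Weight grad computed explicitly from the FP32 matmul; in the real code
    # this would be accumulated into the BF16 param's grad_buffer.
    grad_weight = grad_logits_fp32.t() @ x32
    return logits, grad_input, grad_weight
```

The key invariants are the dtypes: FP32 logits out, BF16 input gradient back, and a weight gradient shaped like the original parameter.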
When normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024×
larger than DeepSpeed's for the equivalent loss, causing the default
gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10
reward points slower than DS GSPO at the same step count. The lm_head_loss
metric was also off — 1024× smaller than DS's rl/loss in the previous
divisor=num_documents² formulation, then 2× too large from SDP doubling.
Root cause analysis
-------------------
DeepSpeed has TWO 1/batch_size factors with different sources:
1. Loss reported (rl/loss) uses /batch_size via tokens_weights = 1/batch_size
(pipelinerl/finetune/rl/__init__.py:246). The reported `rl/loss = -1.7`
value is the raw policy_loss_total, divided once by batch_size.
2. Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes
from `scale_wrt_gas=True` in engine.backward()
(deepspeed/runtime/engine.py:1995-1996) and `tensor.div_(world_sz)` in
reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).
For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size
= batch_size, so DS's effective gradient buffer factor is 1/batch_size² while
the loss metric factor is 1/batch_size. Loss and gradient have asymmetric
scaling.
Fast-LLM's existing implementation used a single `divisor` for both loss and
gradient. Worse, the data_parallel × grad_scale factor in grad_output
(runner.py:318) cancels with FSDP's RS-AVG /world_size, structurally removing
DS's /(gas × world_size) factor from the gradient. So fast-LLM's gradient
buffer ended up at 1/batch_size while DS's was at 1/batch_size² — a
~batch_size = 1024× mismatch.
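A worked check of the scaling mismatch, using the numbers from this run (1024 rollouts per step, 8 data-parallel ranks, one sample per microbatch — so gas × world_size equals batch_size):

```python
batch_size = 1024
world_size = 8                   # data-parallel ranks
gas = batch_size // world_size   # gradient accumulation steps = 128

# DS: loss metric scaled once; gradient buffer scaled again by gas * world_size.
ds_loss_factor = 1 / batch_size                       # tokens_weights
ds_grad_factor = ds_loss_factor / (gas * world_size)  # scale_wrt_gas + div_(world_sz)

# fast-LLM (pre-fix): single shared divisor; the data_parallel * grad_scale
# factor in grad_output cancels FSDP's RS-AVG, so no second 1/batch_size.
fast_llm_grad_factor = 1 / batch_size

mismatch = fast_llm_grad_factor / ds_grad_factor
```

Since all factors are powers of two here, the comparison is exact: the gradient buffers differ by precisely batch_size.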
Additionally, GSPO's SDP allreduce of lrn_sum/adv_sum/tok_sum makes both SDP
ranks compute IDENTICAL per-segment loss values. When LossDef.reduce sums
over the data_group (which includes SDP ranks), the loss metric is
double-counted by sdp_size. The gradient buffer is NOT double-counted —
each SDP rank contributes gradient from its own LOCAL tokens, with different
contributions for different tokens of the same segment.
Fixes
-----
1. Add a `grad_divisor` parameter to `fused_gspo_loss_forward_backward`,
`fused_grpo_loss_forward_backward`, and `triton_grpo_loss_forward_backward`,
defaulting to `divisor` (existing behavior). Allows the gradient to use a
different divisor than the loss.
2. In `LanguageModelGRPOLoss._forward_backward`, when normalize_by_documents
is True, set:
loss divisor = num_documents_in_batch (matches DS rl/loss)
gradient divisor = num_documents_in_batch² (matches DS grad_norm)
This is independent of TP/PP/SDP/DP parallelism and microbatching schedule
because batch_size is invariant under all of these.
3. In the GSPO path, divide the loss by sdp_size when sdp_group is active
(`fused_gspo_loss_forward_backward`). This pre-cancels the SDP doubling
that LossDef.reduce's SUM over data_group introduces. The gradient is
unaffected — different SDP ranks naturally contribute gradient from
different LOCAL token positions, no double-counting at any layer.
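Fixes 1-2 amount to choosing different divisors for the loss and the gradient. A sketch with an illustrative helper (not the actual _forward_backward signature; the GSPO-only sdp_size division of fix 3 is applied separately inside the kernel):

```python
def loss_and_grad_divisors(normalize_by_documents, num_documents, num_labels):
    # normalize_by_documents=True: loss matches DS rl/loss (1/batch_size),
    # gradient matches DS's effective 1/batch_size**2 buffer scaling.
    if normalize_by_documents:
        return num_documents, num_documents ** 2
    # Otherwise keep the existing behavior: one shared token-count divisor.
    return num_labels, num_labels
```

Because batch_size (the document count) is invariant under TP/PP/SDP/DP and the microbatching schedule, the same divisors apply on every rank.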
Verification
------------
Tested on 7B math run with 4 nodes, GSPO, gradient_norm_clipping=0.3:
                     Before fix          After fix                       DS GSPO reference
step 1 grad_norm     141                 0.135                           0.145
step 1 loss          lm_head_loss=-13.7  lm_head_loss~-1.7 (sign varies  rl/loss=-1.7
                                         per data sample)
step 1 clipping      clip_coeff=0.002    clip_coeff=1.000                no clipping
newlp at step 50     trapped at -0.17    -0.103                          -0.105
newlp trajectory tracks DS step-by-step: step 1 within 3%, step 50 within 2%.
Both systems show grad_norm spikes at the same training phase (steps 14-20)
during warmup ramp-up — DS step 16 grad_norm=6.365 vs Fast-LLM 6.093.
Files changed
-------------
- fast_llm/layers/language_model/loss/grpo.py:
- LanguageModelGRPOLoss._forward_backward: split divisor and grad_divisor
based on normalize_by_documents flag, with detailed comments referencing
the corresponding lines in DeepSpeed and PipelineRL.
- fused_gspo_loss_forward_backward: add grad_divisor parameter; divide loss
by sdp_size when sdp_group is active.
- fused_grpo_loss_forward_backward: add grad_divisor parameter.
- fast_llm/functional/triton/grpo_loss.py:
- triton_grpo_loss_forward_backward: add grad_divisor parameter.
- Inline pg_metrics.py into grpo.py; rename to GRPOMetrics
- Drop entropy_chunk_size; reuse fused_softmax_base outputs for entropy
- Replace two bool flags with a single metrics: GRPOMetricsLevel enum
- Rename clamp_log_ratio_new_old_indicator -> clipped_ratio_fraction
- Raise on metrics enabled with pipeline_parallel > 1 (MAX/MIN reduce would be corrupted by the zero placeholder on empty pipeline ranks)
- Migrate tests into tests/layers/test_lm_losses.py, reusing the existing helpers and parametrization (single + distributed runner)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Drop stale "second softmax pass" overhead note from the `metrics` description (entropy now reuses the base softmax outputs)
- De-mirror max/min in reference_grpo_metrics: use advantages[loss_mask].max()/.min() instead of the implementation's -inf/+inf sentinel pattern

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Align (logits, target, advantages, old_log_probabilities, ...) order across compute_grpo_metrics, fused_grpo_loss_forward_backward, and reference_grpo_metrics
- Replace **kwargs in LanguageModelGRPOLoss.__init__ with the explicit keyword-only signature mirroring LanguageModelLoss.__init__
- num_docs -> num_documents
- Drop the comment that restated the k3 KL formula
- Give compute_grpo_metrics the same defaults as the loss kernel
- Trim the metrics field description to category-level wording
- Always exercise varying label_counts in _test_grpo_metrics so per-token denominator broadcasting is covered
- reference_grpo_metrics returns GRPOMetrics; comparison loop iterates dataclasses.fields
- Drop name = self._name micro-rebinds; use self._name inline
- defs = super()...; defs.append(...); defs.extend(...) consistently
- Tighten _register_extra_metrics losses type to dict[str, list[Tensor]]
- Split compiled tuple-returning core from outer GRPOMetrics wrapper to avoid @torch.compile graph-breaks on dataclass construction
- One-line comment on the metrics gate explaining the softmax-skip

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NamedTuple is a tuple subclass that dynamo handles natively, so the previous wrapper/inner split (added to dodge a dataclass graph-break) collapses into one @torch.compile function. Field order now lives exactly once — on the class.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Entropy under vocab-parallel TP was wrong: the dot-product term (exp_logits * logits_norm).sum(-1) summed only the local vocab slice, so dividing by the global sum_exp_logits gave a per-rank fragment instead of the full E_p[logit_norm]. All-reduce the partial sum.
- Replace the verbose pipeline-parallel guard with Assert.custom; the field description already explains the constraint.
- Drop the cryptic `# k3` comment.
- Match _register_extra_metrics losses annotation to the base class (dict | None).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
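The corrected entropy term can be illustrated with mock vocab shards: each inner list stands for one tensor-parallel rank's vocab slice, and the Python sums over shards stand in for the all-reduces. Both the normalizer and the dot-product (expected-logit) term need contributions from every shard, which is exactly the bug fixed above.

```python
import math

def entropy_vocab_parallel(logit_shards):
    # Entropy H = log Z - E_p[logit - max], with logits sharded over vocab.
    m = max(max(s) for s in logit_shards)  # global max (all-reduce MAX)
    # Normalizer Z: sum of exp over ALL shards (all-reduce SUM).
    sum_exp = sum(sum(math.exp(x - m) for x in s) for s in logit_shards)
    # Dot-product term: must also be summed over ALL shards (the fixed
    # all-reduce), not just the local slice.
    dot = sum(sum(math.exp(x - m) * (x - m) for x in s) for s in logit_shards)
    return math.log(sum_exp) - dot / sum_exp
```

Splitting the same logits across shards must not change the result, and a uniform distribution over V entries gives log V.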
# Conflicts:
#   fast_llm/layers/language_model/loss/config.py
#   fast_llm/layers/language_model/loss/grpo.py
Coarse review — pass 1 of 2

[14 numbered review items lost to extraction. Recoverable fragments mention: the fp32_lm_head manual weight-grad path; the same fp32_lm_head block re-gathering; GSPO and GRPO call sites; the SDP-loss double-counting fix; no test exercising one code path; no test covering the fp32_lm_head gradient path; plus closing notes.]
# Conflicts:
#   fast_llm/layers/language_model/loss/config.py
#   fast_llm/layers/language_model/loss/grpo.py
- Drop unused self._preprocessing_config store in Trainer.setup.
- Replace torch.ones + index_add_ with torch.bincount for tok_sum in fused_gspo_loss_forward_backward.
- Drop the load-bearing-sounding docs_per_step reference from the normalize_by_documents field description (no cross-config check exists to enforce it).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Splits the policy-gradient loss config and class hierarchy:
- LanguageModelPolicyGradientLossConfig (abstract base): shared fields (epsilon_low/high, metrics, normalize_by_documents, temperature).
- LanguageModelGRPOLossConfig: registers `type: grpo` (keeps GRPO-only use_triton).
- LanguageModelGSPOLossConfig: registers `type: gspo`.
- LanguageModelPolicyGradientLoss (abstract base): shared __init__/_forward_backward/_register_extra_metrics/get_loss_definitions/get_preprocessing_config plumbing; abstract `_call_kernel`.
- LanguageModelGRPOLoss / LanguageModelGSPOLoss: each implements `_call_kernel` against its kernel; GSPO overrides `get_preprocessing_config` to add `return_document_index`.

Drops the stringly-typed `policy_loss: str` switch and the in-method if/else dispatch, addressing review items #1 and #5 plus Note 2.

YAML migration: `type: grpo` + `policy_loss: gspo` → `type: gspo`. No checked-in YAML configs use the old form.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the abstract `_call_kernel` + per-algorithm subclass pattern with the assignment-at-init pattern used by `Normalization._forward`.
- Single LanguageModelPolicyGradientLoss class hosts both kernel calls as `_call_grpo_kernel` and `_call_gspo_kernel`.
- __init__ assigns `self._call_kernel` to the matching method based on isinstance(config, LanguageModelGSPOLossConfig).
- get_preprocessing_config dispatches inline on the same isinstance.
- Both LanguageModelGRPOLossConfig and LanguageModelGSPOLossConfig return the same loss class — the YAML-side type split (registered via @config_class(dynamic_type=...)) stays as in #1.

Drops ~30 lines net from grpo.py: removes the abstract `_call_kernel` declaration and the two single-method subclasses.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Reverts the class merge from d2c051a in favor of the assignment-at-init pattern used by Normalization._forward. Drops the per-call _call_kernel wrapper that just shuffled args.
- LanguageModelPolicyGradientLoss now hosts only shared scaffolding: _compute_divisors (token vs document), _shared_kernel_kwargs (the 9 kwargs both kernels accept), _finalize_loss (post-call register + extra metrics), and the per-token metrics machinery.
- LanguageModelGRPOLoss and LanguageModelGSPOLoss are restored. Each __init__ assigns self._forward to the actual kernel function:
  GRPO: triton_grpo_loss_forward_backward or fused_grpo_loss_forward_backward
  GSPO: fused_gspo_loss_forward_backward
- Each subclass's _forward_backward calls self._forward(...) directly with the kernel's real signature; no intermediate wrapper.
- Configs map type:grpo → LanguageModelGRPOLoss, type:gspo → LanguageModelGSPOLoss again.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
About the features
Do we know what was actually responsible for the inconsistency with DeepSpeed? In my opinion 6 is the most suspicious. 1/3 is equally problematic but looks more like a config/scoping issue (were we always using GSPO on the DeepSpeed side? Did we ever compare using GRPO?). 4 looks real but minor; 2 and 5 look negligible.
Summary
This PR adds GSPO loss to fast-LLM along with a suite of supporting fixes that together achieve full metric and training-trajectory parity with DeepSpeed's GRPO/GSPO implementation. Targets the `grpo-metrics` branch. Six logical units:

1. GSPO loss (sequence-level IS-ratio clipping)

Implements GSPO as an alternative policy-gradient loss alongside the existing per-token GRPO clipping. Controlled via `LanguageModelGRPOLossConfig.policy_loss = "gspo"`.
- `fused_gspo_loss_forward_backward` kernel: computes the per-segment geometric-mean log-ratio `R_s`, clips at `[1−ε_low, 1+ε_high]`, and applies `R_s × A_s` as a uniform per-token gradient within each segment. An `all_reduce(SUM)` over sequence-data-parallel ranks aggregates `(lrn_sum, adv_sum, tok_count)` before clipping so the ratio is correct under sequence parallelism.
- `document_index` data field and `LanguageModelKwargs.document_index` kwarg constant to route per-token segment membership through the data pipeline.
- `tests/layers/test_gspo_loss.py` (single-segment, packed sequences, ratio=1 equivalence, clipping, masking, SDP mock, gradient finite-diff, independence from per-token metrics).

2. Dynamic `docs_per_step` accumulation

Replaces static `depth_first_micro_batches` with a runtime document-count target — matching DeepSpeed's `gradient_accumulation_passes` semantics for RL (where each microbatch holds one rollout).
- `ScheduleConfig.docs_per_step`: when >0, `Trainer._prefetch_to_doc_target` fetches microbatches one at a time, all-reduces the per-microbatch document count, and stops once the global total ≥ `docs_per_step`. The final step total is broadcast to all inputs so the normalisation denominator is consistent.
- `Trainer._get_or_build_schedule` builds and caches a per-N `Schedule` with `_depth_first_override = N // breadth_first_micro_batches`, so the existing schedule machinery is reused without changes to the runner.
- `Schedule._eff_{depth_first,sequential,num_inputs}` properties expose the effective values for a given override.
- `tests/layers/test_docs_per_step.py`.

3. `normalize_by_documents`

Adds a `normalize_by_documents` flag to `LanguageModelGRPOLossConfig`. When `True`, both the GRPO and GSPO paths divide the loss by `num_documents_in_batch` (the step-level rollout count) rather than the token count. Matches DeepSpeed's normalization where `tokens_weights = 1 / batch_size`.

4. Temperature scaling for IS-ratio parity

Adds a `temperature` field to `LanguageModelGRPOLossConfig`. When set to match the actor's sampling temperature (e.g. 0.7), new log-probabilities are computed at the same temperature as the stored old log-probabilities from vLLM, so the IS ratio starts near 1.0 at step 0 instead of ~1.08. Implementation: `_effective_logits_scale = logits_scale_factor / temperature`, substituted at all three call-sites in `_forward_backward`. Default `temperature=1.0` preserves existing behaviour exactly.

5. `fp32_lm_head` precision fix (matches vLLM's `bf16_last_layer_fp32`)

Adds an `fp32_lm_head` flag (default `False`) on `LanguageModelHeadConfig`. When `True`, the LM head's logits computation upcasts both input and weight to FP32 before the linear projection, matching vLLM's `bf16_last_layer_fp32` quantization. This ensures the trainer computes log-probabilities at the same numerical precision as the actor's sampling, so `new_logprobs ≈ old_logprobs` at step 0 (IS ratio at training start ≈ 1.0, not artificially inflated by precision mismatch).
- d8cb9ef5: introduces the flag, upcasts input/weight, casts back to BF16 before downstream consumption.
- 0f90f20b: fixes the gradient flow when `fp32_lm_head=True`. The detached FP32 weight copy has `requires_grad=False`, which makes `output_parallel_linear_backward` skip writing to the original weight's `grad_buffer`. We restore the FSDP gradient contract by computing `grad_weight = grad.t() @ saved_input` explicitly and accumulating into the BF16 param's `grad_buffer` via `accumulate_gradient`.

6. Decoupled loss/gradient divisors and SDP loss double-counting fix

Even with `normalize_by_documents=true`, fast-LLM's reported `grad_norm` was ~1024× larger than DeepSpeed's, causing the default `gradient_norm_clipping=0.3` to over-clip by ~500× and making training ~10 reward points slower than DS GSPO at the same step count. Two issues, fixed in commit 557a3c4c.

Asymmetric loss/gradient scaling in DS:
- The reported loss uses `/batch_size` once (via `tokens_weights = 1/batch_size`, pipelinerl/finetune/rl/__init__.py:246).
- The gradient buffer has an additional `/(gas × world_size)` factor from `scale_wrt_gas=True` in `engine.backward()` (deepspeed/runtime/engine.py:1995-1996) and `tensor.div_(world_sz)` in `reduce_scatter_coalesced` (deepspeed/runtime/comm/coalesced_collectives.py:124).
- With `samples_per_microbatch=1` (PipelineRL standard), `gas × world_size = batch_size`, so the gradient buffer effectively has `1/batch_size²` while the loss metric has `1/batch_size`.

Fast-LLM cancels DS's `/(gas × world_size)` factor structurally via `grad_output = data_parallel × grad_scale` (runner.py:318) interacting with FSDP's RS-AVG over `data_parallel` ranks (fsdp.py:396). So we need to apply the second `1/batch_size` factor explicitly, and only to the gradient — keeping the loss metric matched to DS.

Fix: add a `grad_divisor` parameter to `fused_gspo_loss_forward_backward`, `fused_grpo_loss_forward_backward`, and `triton_grpo_loss_forward_backward`. When `normalize_by_documents=true`:
- loss divisor = `num_documents_in_batch` (matches DS `rl/loss`)
- gradient divisor = `num_documents_in_batch²` (matches DS `grad_norm`)

Independent of TP/PP/SDP/DP parallelism and microbatching schedule, because `batch_size` is invariant under all of them.

SDP loss double-counting:

After the SDP allreduce of `lrn_sum/adv_sum/tok_sum` in `fused_gspo_loss_forward_backward`, both SDP ranks compute IDENTICAL per-segment loss values. When `LossDef.reduce` SUMs across `data_group` (which includes SDP ranks), the loss metric is double-counted by `sdp_size`. The gradient is NOT double-counted — each SDP rank contributes gradient from its own LOCAL tokens, with different contributions for different tokens of the same segment.

Fix: divide the loss by `sdp_size` when `sdp_group` is active. Gradient unaffected.

Verification

End-to-end 7B math run on 4 nodes, GSPO, `gradient_norm_clipping=0.3` (default), `normalize_by_documents=true`, `fp32_lm_head=true`, `temperature=0.7`: `grad_norm`, `lm_head_loss`, `clip_coeff`, and `newlp` land on the DS reference values at step 1, and the `newlp` trajectory tracks DS step-by-step. Both systems show the same gradient-spike pattern during warmup ramp-up at steps 14-20 (DS step 16 grad_norm=6.365, fast-LLM step 15=9.005). Match within data variance.
Test plan
- `pytest tests/layers/test_gspo_loss.py` — GSPO unit tests pass
- `pytest tests/layers/test_docs_per_step.py` — docs_per_step unit tests pass
- `pytest tests/layers/test_lm_losses.py` — existing GRPO loss + per-token metrics tests unaffected (the metrics tests previously in `test_grpo_metrics.py` moved into this file on the base branch)
- End-to-end 7B math run (`docs_per_step=1024`, `temperature=0.7`, `normalize_by_documents=true`, `fp32_lm_head=true`, default `gradient_norm_clipping=0.3`) — grad_norm matches DS at step 1, training trajectory matches DS step-by-step through step 50+ (ongoing run validates through step ~410).