
[Pipeline RL] Add support for PipelineRL #428

Merged
jlamypoirier merged 134 commits into main from jlp_pipeline_rl on Mar 20, 2026
Conversation

@jlamypoirier jlamypoirier commented Dec 17, 2025

This PR provides the initial integration with PipelineRL, using the GRPO loss.

It introduces:

  • Streaming Redis Dataset — capable of consuming documents such as rollouts from an external Redis stream.
  • Trainer Callback System — supports callbacks for events like training_started, step_finished, and training_finished.
  • Redis-based Callback Implementation with Weights Broadcast Mechanism — uses a separate external NCCL rendezvous point to broadcast updated model weights in real time to inference servers, and a Redis stream to broadcast training events.

This enables seamless coordination between Fast-LLM training and PipelineRL-based inference or orchestration components.
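The callback system described above could look roughly like the following sketch. This is an illustration only: the class and method names (`TrainerCallback`, `CallbackList`) are assumptions, not Fast-LLM's actual API; only the event names (`training_started`, `step_finished`, `training_finished`) come from the PR description.

```python
# Hypothetical sketch of a trainer callback interface; names other than the
# three event hooks are assumptions, not Fast-LLM's real API.
class TrainerCallback:
    """Receives trainer lifecycle events; every hook is an optional no-op."""

    def training_started(self, step: int) -> None:
        pass

    def step_finished(self, step: int, metrics: dict) -> None:
        pass

    def training_finished(self, step: int) -> None:
        pass


class CallbackList:
    """Fans each trainer event out to every registered callback in order."""

    def __init__(self, callbacks: list[TrainerCallback]):
        self._callbacks = callbacks

    def step_finished(self, step: int, metrics: dict) -> None:
        for cb in self._callbacks:
            cb.step_finished(step, metrics)
```

A Redis-based callback would then subclass `TrainerCallback` and, in `step_finished`, publish the event to a Redis stream and trigger the NCCL weights broadcast.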

Base automatically changed from jlp_entropy_loss_tweaks to main March 17, 2026 23:42
jlamypoirier and others added 15 commits March 17, 2026 20:23
… forward

Move num_labels_in_seq computation from _compute_num_labels_in_seq (called
inside forward_backward on the already-packed sequence) to _get_model_input,
where document boundaries are available via cropped_lengths. Per-document
response token counts are trivially computed and broadcast to token positions,
eliminating the need for span-finding on the packed sequence.

Also fixes new_logprobs metric scaling with cross_entropy_splits > 1, and
updates test_lm_head to properly handle list-indexed advantages/old_log_probs
and verify the new_logprobs extra metric.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
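The "broadcast to token positions" step can be sketched in plain Python (a minimal illustration of the idea, not Fast-LLM's implementation; the function name is hypothetical): with document boundaries in hand, each document's response-token count is computed once and repeated across that document's token positions, so no span search on the packed sequence is needed.

```python
# Illustrative sketch: per-document label counts broadcast to token positions.
def broadcast_label_counts(doc_lengths: list[int],
                           label_masks: list[list[int]]) -> list[int]:
    """doc_lengths[i]: number of tokens in document i of the packed sequence.
    label_masks[i]: 0/1 mask marking that document's label (response) tokens.
    Returns one value per token: the label count of the token's document."""
    out = []
    for length, mask in zip(doc_lengths, label_masks):
        num_labels = sum(mask)             # label tokens in this document
        out.extend([num_labels] * length)  # broadcast to every token position
    return out
```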
- Use local_world_size=1 (not world_size) since each process sees exactly
  one GPU via CUDA_VISIBLE_DEVICES in PipelineRL's setup
- Switch from torch.distributed.broadcast_object_list/broadcast to
  fast_llm.core.distributed.broadcast_object/broadcast, which work
  directly on ProcessGroupNCCL backend objects (ProcessGroupPool returns
  unregistered backends that torch.distributed ops cannot accept)
- Use process_group.shutdown() instead of torch.distributed.destroy_process_group
  for the same reason

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Two bugs when micro-batches are padded to pad_to_size in
LanguageModelBatch.from_documents:

1. advantages and old_log_probabilities (TokenDataBatch) were not padded
   to match the token batch size. get_cropped_data(label_begin, label_end)
   then returned fewer elements than logits, causing a shape mismatch in
   fused_grpo_loss_forward_backward.

2. num_labels_in_seq used cropped_lengths from (begin, label_end) which
   spans end-begin+prediction_distance tokens, one more than the model
   input length. Now uses (label_begin, label_end) so segment lengths
   sum to end-begin, matching new_log_probs shape.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
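The first fix amounts to padding the per-token side data along with the tokens themselves. A minimal sketch (function name and list representation are assumptions for illustration):

```python
# Sketch of the padding fix: per-token data such as advantages must be padded
# to the same size as the token batch, or later slicing returns fewer
# elements than the logits and shapes mismatch in the loss.
def pad_token_data(values: list[float], pad_to_size: int,
                   pad_value: float = 0.0) -> list[float]:
    assert len(values) <= pad_to_size
    return values + [pad_value] * (pad_to_size - len(values))
```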
rafapi previously requested changes Mar 20, 2026
bigximik and others added 3 commits March 20, 2026 08:46
start_time was set once at the start of iterate() and never reset,
causing TimeoutError after 600s of total training time regardless of
whether documents were actively flowing. Reset start_time on each
successful XREADGROUP response so the timeout only fires when no
new documents have arrived for the configured duration.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
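The timeout-reset logic can be sketched as follows. This is a simplified stand-in, not the actual dataset code: `read_batch` represents the XREADGROUP call, and the loop structure is an assumption.

```python
import time

def iterate(read_batch, timeout_s: float = 600.0, poll_s: float = 1.0):
    """Yield documents from a stream source, timing out only after timeout_s
    seconds with no new documents. read_batch stands in for an XREADGROUP
    call and returns a (possibly empty) list of documents."""
    start_time = time.monotonic()
    while True:
        batch = read_batch()
        if batch:
            start_time = time.monotonic()  # reset: documents are flowing
            yield from batch
        elif time.monotonic() - start_time > timeout_s:
            raise TimeoutError("no new documents within timeout")
        else:
            time.sleep(poll_s)
```

The key line is the reset of `start_time` on every non-empty read; the original bug set it once, so the timeout measured total training time rather than idle time.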
searchsorted requires a sorted haystack, but labels_per_document is an
unsorted array of per-document label counts. Using it directly caused
incorrect doc-index lookups, resulting in wrong (often zero) label counts
and nan in the grpo_new_logprobs metric.

Fix: use length_cumsum[1:] (sorted) to map each token to its document
index, then index labels_per_document with that result.
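The fixed lookup can be illustrated with the stdlib `bisect` (a sketch of the idea; the real code uses a tensor `searchsorted`, and the helper name here is hypothetical):

```python
import bisect
from itertools import accumulate

def token_to_doc_index(doc_lengths: list[int]) -> list[int]:
    """Map each token position to its document index by searching the sorted
    cumulative-length array. Searching the unsorted per-document label-count
    array instead gives meaningless indices, which was the bug."""
    cumsum = list(accumulate(doc_lengths))  # sorted haystack, e.g. [3, 5]
    return [bisect.bisect_right(cumsum, t) for t in range(cumsum[-1])]
```

With the document index per token in hand, `labels_per_document` can then be indexed correctly.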
Padded tokens and fully masked documents have num_labels_in_seq=0 and
loss_mask=0. Without clamping, 0/0=nan poisons the sum even though those
positions contribute nothing to the loss. Clamp to min=1 so masked
positions produce 0/1=0 instead.

@bigximik bigximik left a comment

I’ve made several additional changes and addressed @rafapi’s feedback. @jlamypoirier, could you review and confirm? Otherwise, we can merge.

@jlamypoirier jlamypoirier marked this pull request as ready for review March 20, 2026 21:17
@jlamypoirier jlamypoirier merged commit 24b0f0c into main Mar 20, 2026
1 of 2 checks passed
@jlamypoirier jlamypoirier deleted the jlp_pipeline_rl branch March 20, 2026 21:18