Skip to content

Conversation

@tyler-griggs
Copy link
Member

PR Overview

This PR has a pretty significant architectural refactor. The primary changes are:

  1. Move algorithm logic into the trainer and infra logic into the workers.

    • Remove ppo_train (from FSDP workers for now, @erictang000 to remove from Megatron)
    • Move micro-batching logic from trainer to worker (worker doesn't know mini batches, trainer doesn't know micro batches
    • Worker onloading/offloading moved out of trainer
  2. Introduce WorkerDispatch: an intermediate layer between the training loop and workers. Later, this is the entity that will sit beneath the Tinker API server. It is the entity that handles a "pool" of workers. Currently it dispatches requests to the pool of workers (by mesh or by pass_through) and handles onloading / offloading workers.
    i. I'm very open to renaming suggestions.


Architecture

Here's a diagram of what the breakdown of responsibilities now looks like. Note that some of this will still need to change. E.g., the trainer will not call dispatch in the future, it will call into the Tinker API server.

┌─────────────────────────────────────────────────────────────┐
│                    TRAINER (Algorithm)                       │
│  - PPO algorithm implementation                              │
│  - Knows only mini batches                                   │
│  - Calls dispatch.forward_backward() + dispatch.optim_step() │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│                 WORKER DISPATCH (Coordination)               │
│  - Manages all actor groups (policy, critic, ref)           │
│  - Handles GPU state (offload/backload) automatically       │
│  - Routes calls to appropriate workers                      │
│  - Handles DP sharding via MeshDispatch                     │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│                    WORKERS (Execution)                       │
│  - Execute forward/backward passes                          │
│  - Handle micro-batching internally                         │
│  - Scale gradients at optim_step                            │
│  - Model-specific implementations (FSDP, Megatron)          │
└─────────────────────────────────────────────────────────────┘

Tests

Deleted Tests

  • test_ppo_train.py - tested removed ppo_train method

Updated Tests

  • test_training_step.py - uses WorkerDispatch for policy tests
  • test_worker_offload.py - updated to work with new interfaces
  • test_save_load_checkpoint.py - updated imports
  • test_trainer.py - rewrote test_normalize_mini_batch_size

Gradient scaling

Since the worker no longer knows mini batches, we scale gradients instead of loss:

Old (scale loss during backward)

for i in 1..N:
    grad += (1/N) * ∂loss_i/∂param
optimizer.step(grad)

New (scale gradients at optim_step)

for i in 1..N:
    grad += ∂loss_i/∂param
grad *= 1/N
optimizer.step(grad)

Both produce: grad = (1/N) * Σ ∂loss_i/∂param


What's next for training refactor?

  • Remove weight sync logic from trainer. This should not be explicitly triggered by the trainer
  • Remove other “infra” calls from trainer, such as empty_cache
  • Create separate entry points for a) launching workers and b) launching training. Currently trainer.py sets up the workers (build_models) and launches training
  • Update Megatron (@erictang000) to bring to same state instead of branching on which backend is used

lcm_dp_size = math.lcm(lcm_dp_size, self.critic_model.actor_infos[0].rank.dp_size)
if self.ref_model is not None:
lcm_dp_size = math.lcm(lcm_dp_size, self.ref_model.actor_infos[0].rank.dp_size)
lcm_dp_size = self.dispatch.get_lcm_dp_size()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needs to be moved out of trainer. The trainer shouldn't know/care about dp size.

Copy link
Collaborator

@erictang000 erictang000 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks pretty good to me! will work on the megatron worker update off of this branch

Also calling out some work that's actively ongoing for adding dynamic batch sizing that's touching relevant codepaths:
#817
#847

maybe the best thing to do would be to rebase #847 on top of the tinkerify branch? (and move down the micro batching logic updates from there into into forward_backward)

scale = 1.0 / self._micro_batches_accumulated
for param in self.model.parameters():
if param.grad is not None:
param.grad.mul_(scale)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get that this is cleaner to not worry about loss scaling during forward (which requires pre-knowing the details of how we break down the microbatch or how often we call optim_step).

But are you sure that this is exactly numerically equivalent + compatible with autocasting/mixed precision? Scaling the loss prior to backward seems cleaner in that sense, can we verify that this is numerically identical (and that there aren't overflow edge cases here) to scaling the loss in pytorch?

Copy link
Collaborator

@erictang000 erictang000 Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

coming back to this now that i'm implementing this for megatron - I think it overall seems cleaner and more flexible to handle scaling loss for gradient accumulation in forward_backward, where we could just set the default to weight each micro batch evenly and scale the loss for each microbatch by (1/num_microbatches). We could even theoretically provide an optional parameter for forward_backward that specifies a loss scaling term (default 1), so that users could flexibly scale the loss for different mini-batches differently (i.e. if you have 2 mini batches of different sizes, you might want to make sure that the 2nd mini batch loss is scaled to size).

def _forward_backward_micro(experience, microbatch_weight):
     ...
     loss = loss * microbatch_weight
     loss.backward()
     ...
def forward_backward(data, loss_weight=1.0):
     micro_batch_iterator = BatchIterator(data, micro_batch_size, drop_last=False)
     for micro_batch in micro_batch_iterator:
            metrics = self._forward_backward_micro(micro_batch, microbatch_weight=loss_weight / len(micro_batch_iterator))
      ...

This would still be compatible with the upstream tinker API which doesn't let you specify a loss weight.

wdyt? @tyler-griggs

else:
base_log_probs = None
# Critic forward (dispatch handles offload/backload automatically)
if self.has_critic:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one reason that this logic was so messy before was that you could technically overlap the forward pass for critic/ref/policy, which is something you lose here. This probably wasn't obvious or used very much in most of our configs since everything was colocated, but could matter more at scale.

Once we make the api async (and handle some requests queue on the server side) we could add that functionality back - do you think that's worth adding a TODO?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was thinking the same thing. I think it's fine to lose overlapping for a bit, especially given that critic usage is presumably quite low. But it will be possible to add it back once we update the async request handling.

tyler-griggs and others added 2 commits January 19, 2026 02:25
Resolved conflicts by accepting origin/main changes:
- trainer.py: kept async/await and offload functionality
- worker.py: kept ppo_train method with loss scaling
- test files: kept origin/main versions

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@tyler-griggs tyler-griggs marked this pull request as ready for review January 19, 2026 19:50
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a significant architectural refactoring aimed at improving the separation of concerns within the training stack. Key changes include moving algorithmic logic to the trainer and infrastructure logic to workers, and introducing a WorkerDispatch class to coordinate between them. This refactoring also removes redundant training methods from workers and shifts micro-batching logic there.

Additionally, this PR adds support for expert parallelism in skyrl-tx, including a ragged_dot wrapper that handles sharded experts and updates to LoRA layers to support this.

The changes are well-structured and accompanied by new tests for the WorkerDispatch functionality and expert parallelism layers. The dependency updates for vllm, torch, and other packages are consistent with the code modifications. The addition of project-summary.md provides an excellent overview of the refactoring.

Overall, this is a high-quality refactoring that improves the architecture and adds new capabilities. I have not found any issues that require changes.

tyler-griggs and others added 2 commits January 19, 2026 19:59
- FSDP workers now use forward_backward + optim_step pattern instead of
  ppo_train. Megatron workers keep ppo_train.
- Gradient scaling: scale gradients by 1/micro_batches_accumulated in
  optim_step instead of scaling loss during backward
- BatchIterator now yields Experience directly (not TrainingInputBatch)
- Trainer branches on strategy: Megatron uses ppo_train via dispatch,
  FSDP uses forward_backward + optim_step
- Use dispatch.get_lcm_dp_size() for cleaner dp size calculation
- Remove test_ppo_train.py (was testing removed FSDP ppo_train)
- Update tests to use TrainingInputBatch and new worker API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use make_dummy_training_batch instead of make_dummy_experience
- Use mesh dispatch for forward_backward with data= kwarg
- Rename experience parameter to data

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@tyler-griggs tyler-griggs changed the base branch from tinkerify to main January 19, 2026 20:41
tyler-griggs and others added 5 commits January 19, 2026 20:49
The previous implementation only looped over epochs and passed the full
training batch to forward_backward, taking one optimizer step per epoch.

This restores the reference behavior where we:
1. Loop over epochs
2. Within each epoch, loop over mini-batches (train_batch_size // mini_batch_size)
3. Take an optimizer step after each mini-batch

This gives the correct number of optimizer steps:
  update_epochs_per_batch * (train_batch_size // mini_batch_size)

Also adds TODO to rename _normalize_mini_batch_size once Megatron no
longer requires mini-batch normalization.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1. Fix optimizer GPU leak during weight sync:
   - trainer.py: Use backload_optimizer=False when loading model before
     checkpoint or after ref sync, so prepare_for_weight_sync correctly
     tracks GPU state

2. Update tests for new forward_backward API:
   - test_worker_offload.py: Use make_dummy_training_batch, mesh dispatch
   - test_save_load_model.py: Use make_dummy_training_batch, mesh dispatch

API change: forward_backward now takes TrainingInputBatch via "mesh"
dispatch (instead of Experience + microbatch_weight via "pass_through")

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Update fwd_logprobs_values_reward to use dispatch.forward() instead of
  manual backload/offload calls (~90 lines -> ~30 lines)
- Update save_checkpoints to use dispatch.save_checkpoint()
- Update load_checkpoints to use dispatch.load_checkpoint()
- Update save_models to use dispatch.save_hf_model()
- Update update_ref_with_policy to use dispatch.save_hf_model() and
  dispatch.init_model()
- Update _cleanup_old_checkpoints to use dispatch.get_node_ids()
- Remove redundant direct backload calls at start/end of train loop
- Remove unused imports (concatenate_outputs_after_mesh_dispatch,
  TrainingOutputBatch, get_node_ids)

This centralizes GPU state management in WorkerDispatch, eliminating
the risk of state tracking inconsistency and simplifying the trainer.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@tyler-griggs tyler-griggs merged commit 6a9ab82 into NovaSky-AI:main Jan 19, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants