[Tinker] Refactor trainer and worker (to move algo to trainer and infra to worker) #859
Conversation
```diff
-        lcm_dp_size = math.lcm(lcm_dp_size, self.critic_model.actor_infos[0].rank.dp_size)
-    if self.ref_model is not None:
-        lcm_dp_size = math.lcm(lcm_dp_size, self.ref_model.actor_infos[0].rank.dp_size)
+    lcm_dp_size = self.dispatch.get_lcm_dp_size()
```
This also needs to be moved out of trainer. The trainer shouldn't know/care about dp size.
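For illustration, a minimal sketch of how the LCM computation could live behind the dispatch layer instead of the trainer. The `actor_infos[0].rank.dp_size` access and the `get_lcm_dp_size` name come from the diff above; everything else (constructor shape, `_models` attribute) is assumed:

```python
import math

class WorkerDispatch:
    """Sketch: owns the worker pool, so the trainer never sees dp sizes."""

    def __init__(self, policy_model, critic_model=None, ref_model=None):
        self._models = [m for m in (policy_model, critic_model, ref_model) if m is not None]

    def get_lcm_dp_size(self) -> int:
        # Fold math.lcm over the dp_size of every registered model's workers.
        lcm_dp_size = 1
        for model in self._models:
            lcm_dp_size = math.lcm(lcm_dp_size, model.actor_infos[0].rank.dp_size)
        return lcm_dp_size
```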
erictang000 left a comment:
looks pretty good to me! will work on the megatron worker update off of this branch
Also calling out some actively ongoing work on dynamic batch sizing that touches the relevant codepaths:
#817
#847
maybe the best thing to do would be to rebase #847 on top of the tinkerify branch? (and move the micro-batching logic updates from there down into forward_backward)
```python
scale = 1.0 / self._micro_batches_accumulated
for param in self.model.parameters():
    if param.grad is not None:
        param.grad.mul_(scale)
```
I get that this is cleaner because the forward pass no longer has to worry about loss scaling (which requires knowing up front how the micro-batches are split or how often we call optim_step).
But are you sure this is exactly numerically equivalent and compatible with autocasting/mixed precision? Scaling the loss prior to backward seems cleaner in that sense; can we verify that scaling the gradients is numerically identical to scaling the loss in PyTorch (and that there aren't overflow edge cases here)?
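As a toy illustration of the kind of check being asked for (plain fp32, a throwaway `nn.Linear`, not the project's worker code; autocast/mixed-precision behavior would still need its own test), the two scalings can be compared directly:

```python
import copy
import torch

def grads_scaling_loss(model, batches):
    # Scale each micro-batch loss by 1/N before backward.
    model.zero_grad()
    n = len(batches)
    for x, y in batches:
        loss = torch.nn.functional.mse_loss(model(x), y)
        (loss / n).backward()
    return [p.grad.clone() for p in model.parameters()]

def grads_scaling_grads(model, batches):
    # Accumulate unscaled losses, then scale the gradients once at the end.
    model.zero_grad()
    for x, y in batches:
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
    scale = 1.0 / len(batches)
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scale)
    return [p.grad.clone() for p in model.parameters()]

torch.manual_seed(0)
model_a = torch.nn.Linear(8, 1)
model_b = copy.deepcopy(model_a)
batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(3)]
for ga, gb in zip(grads_scaling_loss(model_a, batches), grads_scaling_grads(model_b, batches)):
    print(torch.allclose(ga, gb, atol=1e-6))  # close, though not necessarily bitwise identical
```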
coming back to this now that i'm implementing this for megatron - I think it overall seems cleaner and more flexible to handle loss scaling for gradient accumulation in forward_backward, where we could just set the default to weight each micro-batch evenly and scale the loss for each micro-batch by (1/num_microbatches). We could even provide an optional parameter for forward_backward that specifies a loss scaling term (default 1), so that users could flexibly scale the loss differently per mini-batch (i.e. if you have 2 mini-batches of different sizes, you might want the 2nd mini-batch's loss to be scaled in proportion to its size).
```python
def _forward_backward_micro(self, experience, microbatch_weight):
    ...
    loss = loss * microbatch_weight
    loss.backward()
    ...

def forward_backward(self, data, loss_weight=1.0):
    micro_batch_iterator = BatchIterator(data, micro_batch_size, drop_last=False)
    for micro_batch in micro_batch_iterator:
        metrics = self._forward_backward_micro(micro_batch, microbatch_weight=loss_weight / len(micro_batch_iterator))
        ...
```
This would still be compatible with the upstream tinker API, which doesn't let you specify a loss weight.
wdyt? @tyler-griggs
```python
else:
    base_log_probs = None
# Critic forward (dispatch handles offload/backload automatically)
if self.has_critic:
```
one reason that this logic was so messy before was that you could technically overlap the forward passes for critic/ref/policy, which is something you lose here. This probably wasn't obvious or used very much in most of our configs since everything was colocated, but it could matter more at scale.
Once we make the api async (and handle a request queue on the server side) we could add that functionality back - do you think that's worth adding a TODO?
Yeah I was thinking the same thing. I think it's fine to lose overlapping for a bit, especially given that critic usage is presumably quite low. But it will be possible to add it back once we update the async request handling.
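For context, a rough sketch of what the overlap could look like once the dispatch calls are async. This assumes a hypothetical async `dispatch.forward()` and is not the current API:

```python
import asyncio

async def fwd_logprobs_values_reward(dispatch, batch, has_ref=True, has_critic=True):
    # Sketch only: overlap the policy/ref/critic forward passes instead of
    # running them back to back.
    coros = [dispatch.forward("policy", batch)]
    if has_ref:
        coros.append(dispatch.forward("ref", batch))
    if has_critic:
        coros.append(dispatch.forward("critic", batch))
    return await asyncio.gather(*coros)
```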
Resolved conflicts by accepting origin/main changes:
- trainer.py: kept async/await and offload functionality
- worker.py: kept ppo_train method with loss scaling
- test files: kept origin/main versions

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Code Review
This pull request introduces a significant architectural refactoring aimed at improving the separation of concerns within the training stack. Key changes include moving algorithmic logic to the trainer and infrastructure logic to workers, and introducing a WorkerDispatch class to coordinate between them. This refactoring also removes redundant training methods from workers and shifts micro-batching logic there.
Additionally, this PR adds support for expert parallelism in skyrl-tx, including a ragged_dot wrapper that handles sharded experts and updates to LoRA layers to support this.
The changes are well-structured and accompanied by new tests for the WorkerDispatch functionality and expert parallelism layers. The dependency updates for vllm, torch, and other packages are consistent with the code modifications. The addition of project-summary.md provides an excellent overview of the refactoring.
Overall, this is a high-quality refactoring that improves the architecture and adds new capabilities. I have not found any issues that require changes.
- FSDP workers now use forward_backward + optim_step pattern instead of ppo_train. Megatron workers keep ppo_train.
- Gradient scaling: scale gradients by 1/micro_batches_accumulated in optim_step instead of scaling loss during backward
- BatchIterator now yields Experience directly (not TrainingInputBatch)
- Trainer branches on strategy: Megatron uses ppo_train via dispatch, FSDP uses forward_backward + optim_step
- Use dispatch.get_lcm_dp_size() for cleaner dp size calculation
- Remove test_ppo_train.py (was testing removed FSDP ppo_train)
- Update tests to use TrainingInputBatch and new worker API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
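A rough sketch of the strategy branch this commit describes. Only `forward_backward`, `optim_step`, and `ppo_train` come from the commit; the attribute names and dispatch call shapes are assumptions:

```python
def train_step(self, batch):
    # Sketch: Megatron keeps the monolithic ppo_train path for now,
    # FSDP accumulates via forward_backward and then steps once.
    if self.strategy_name == "megatron":
        status = self.dispatch.ppo_train("policy", data=batch)
    else:
        status = self.dispatch.forward_backward("policy", data=batch)
        self.dispatch.optim_step("policy")
    return status
```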
- Use make_dummy_training_batch instead of make_dummy_experience
- Use mesh dispatch for forward_backward with data= kwarg
- Rename experience parameter to data

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The previous implementation only looped over epochs and passed the full training batch to forward_backward, taking one optimizer step per epoch. This restores the reference behavior where we:
1. Loop over epochs
2. Within each epoch, loop over mini-batches (train_batch_size // mini_batch_size)
3. Take an optimizer step after each mini-batch

This gives the correct number of optimizer steps: update_epochs_per_batch * (train_batch_size // mini_batch_size)

Also adds TODO to rename _normalize_mini_batch_size once Megatron no longer requires mini-batch normalization.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
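A minimal sketch of the restored loop structure (the `split_batch` helper and method name are placeholders, not the actual trainer code):

```python
def update_policy(self, train_batch):
    # Sketch: one optimizer step per mini-batch, repeated for each epoch.
    num_mini_batches = self.train_batch_size // self.mini_batch_size
    for _ in range(self.update_epochs_per_batch):                        # 1. epochs
        for mini_batch in split_batch(train_batch, num_mini_batches):    # 2. mini-batches
            self.dispatch.forward_backward("policy", data=mini_batch)
            self.dispatch.optim_step("policy")                           # 3. step per mini-batch
    # total optimizer steps = update_epochs_per_batch * num_mini_batches
```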
1. Fix optimizer GPU leak during weight sync:
- trainer.py: Use backload_optimizer=False when loading model before
checkpoint or after ref sync, so prepare_for_weight_sync correctly
tracks GPU state
2. Update tests for new forward_backward API:
- test_worker_offload.py: Use make_dummy_training_batch, mesh dispatch
- test_save_load_model.py: Use make_dummy_training_batch, mesh dispatch
API change: forward_backward now takes TrainingInputBatch via "mesh"
dispatch (instead of Experience + microbatch_weight via "pass_through")
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Update fwd_logprobs_values_reward to use dispatch.forward() instead of manual backload/offload calls (~90 lines -> ~30 lines)
- Update save_checkpoints to use dispatch.save_checkpoint()
- Update load_checkpoints to use dispatch.load_checkpoint()
- Update save_models to use dispatch.save_hf_model()
- Update update_ref_with_policy to use dispatch.save_hf_model() and dispatch.init_model()
- Update _cleanup_old_checkpoints to use dispatch.get_node_ids()
- Remove redundant direct backload calls at start/end of train loop
- Remove unused imports (concatenate_outputs_after_mesh_dispatch, TrainingOutputBatch, get_node_ids)

This centralizes GPU state management in WorkerDispatch, eliminating the risk of state tracking inconsistency and simplifying the trainer.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
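A sketch of the slimmed-down trainer path this enables. The method names come from the commit; the exact `dispatch.forward()` signature and return shapes are assumptions:

```python
def fwd_logprobs_values_reward(self, batch):
    # Sketch: WorkerDispatch handles backload/offload internally, so the
    # trainer only requests forward passes.
    action_log_probs = self.dispatch.forward("policy", batch)
    base_log_probs = self.dispatch.forward("ref", batch) if self.ref_model is not None else None
    values = self.dispatch.forward("critic", batch) if self.has_critic else None
    return action_log_probs, base_log_probs, values
```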
PR Overview
This PR is a pretty significant architectural refactor. The primary changes are:
- Move algorithm logic into the trainer and infra logic into the workers.
- Remove `ppo_train` (from FSDP workers for now, @erictang000 to remove from Megatron)
- Introduce `WorkerDispatch`: an intermediate layer between the training loop and workers. Later, this is the entity that will sit beneath the Tinker API server. It is the entity that handles a "pool" of workers. Currently it dispatches requests to the pool of workers (by mesh or by pass_through) and handles onloading/offloading workers; a rough interface sketch follows below.
  i. I'm very open to renaming suggestions.
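For reference, a sketch of the dispatch surface as described above and in the commits. Method names appear elsewhere in this PR; the signatures themselves are illustrative, not the exact API:

```python
class WorkerDispatch:
    """Intermediate layer between the training loop and the worker pool.

    Routes requests to workers (by mesh or by pass_through) and handles
    onloading/offloading internally. Sketch only.
    """

    # Training / inference entry points.
    def forward(self, model: str, data): ...
    def forward_backward(self, model: str, data): ...
    def optim_step(self, model: str): ...
    def ppo_train(self, model: str, data): ...  # Megatron path for now

    # Infra helpers the trainer still calls.
    def get_lcm_dp_size(self) -> int: ...
    def save_checkpoint(self, model: str, path: str): ...
    def load_checkpoint(self, model: str, path: str): ...
    def save_hf_model(self, model: str, path: str): ...
    def init_model(self, model: str, path: str): ...
    def get_node_ids(self): ...
```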
Architecture
Here's a diagram of what the breakdown of responsibilities now looks like. Note that some of this will still need to change. E.g., the trainer will not call `dispatch` in the future; it will call into the Tinker API server.

Tests
Deleted Tests
- `test_ppo_train.py` - tested the removed `ppo_train` method

Updated Tests
- `test_training_step.py` - uses `WorkerDispatch` for policy tests
- `test_worker_offload.py` - updated to work with new interfaces
- `test_save_load_checkpoint.py` - updated imports
- `test_trainer.py` - rewrote `test_normalize_mini_batch_size`

Gradient scaling
Since the worker no longer knows about mini-batches, we scale gradients instead of the loss:
Old (scale loss during backward)
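Roughly (a sketch with placeholder helpers like `compute_loss` and `split`, not the removed worker code):

```python
def forward_backward(self, experience, num_micro_batches):
    # Sketch: each micro-batch loss is scaled before backward, so the caller
    # must know up front how many micro-batches make up the mini-batch.
    for micro_batch in self.split(experience, num_micro_batches):
        loss = self.compute_loss(micro_batch)
        (loss / num_micro_batches).backward()
```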
New (scale gradients at optim_step)
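Roughly (again a sketch; `BatchIterator` and the accumulation counter appear in the PR, `compute_loss` is a placeholder):

```python
def forward_backward(self, data):
    # Sketch: backward with the unscaled loss and count each micro-batch.
    for micro_batch in BatchIterator(data, self.micro_batch_size, drop_last=False):
        loss = self.compute_loss(micro_batch)
        loss.backward()
        self._micro_batches_accumulated += 1

def optim_step(self):
    # Rescale the accumulated gradients once (mirrors the diff above),
    # then step and reset the counter.
    scale = 1.0 / self._micro_batches_accumulated
    for param in self.model.parameters():
        if param.grad is not None:
            param.grad.mul_(scale)
    self.optimizer.step()
    self.optimizer.zero_grad()
    self._micro_batches_accumulated = 0
```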
Both produce:

`grad = (1/N) * Σ ∂loss_i/∂param`

What's next for training refactor?