Unify Megatron and FSDP training interfaces with forward_backward + optim_step #901
base: main
Conversation
…ptim_step

- Add forward_backward() and optim_step() methods to MegatronPolicyWorkerBase
- Update trainer to use unified interface for both strategies
- Remove strategy branching in train_critic_and_policy()
- Mark ppo_train() as deprecated (kept for backward compatibility)
- Update test_megatron_worker.py to use new interface

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Code Review
This pull request successfully unifies the training interfaces for Megatron and FSDP strategies by introducing forward_backward and optim_step methods in MegatronPolicyWorkerBase. The trainer.py file is updated to use this unified interface, removing strategy-specific branching, which significantly improves code maintainability and clarity. The ppo_train method is appropriately marked as deprecated, and the tests are updated to reflect these changes. Overall, this is a well-executed refactoring that aligns with the goal of creating a more consistent training pipeline.
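As a rough sketch of what the unified call site can look like once strategy branching is removed (the function and variable names below are illustrative assumptions, not the actual trainer code):

```python
# Illustrative sketch only -- `policy_worker`, `micro_batches`, and the return
# values are assumptions about the shape of the unified interface.
def train_policy_step(policy_worker, micro_batches):
    # Both FSDP and Megatron workers now expose the same two calls,
    # so the trainer no longer needs per-strategy branching.
    loss_info = policy_worker.forward_backward(micro_batches)  # accumulate gradients
    grad_norm = policy_worker.optim_step()                     # clip and apply optimizer
    return loss_info, grad_norm
```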
    micro_batches=micro_buffer,
    seq_len=seq_len,
    micro_batch_size=micro_bsz,
    temperature=self.cfg.generator.sampling_params.temperature,
The temperature parameter is typically associated with sampling in generation tasks. While it might be used internally by MegatronModelWrapper.forward_backward_mini_batch for entropy calculation or similar, its presence in a training method's signature can be confusing. Consider adding a comment to clarify its specific role in the training forward/backward pass, or if possible, refactor MegatronModelWrapper to only expose parameters relevant to training loss calculation in this context.
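For context, a common reason a sampling temperature appears in a training-time forward pass is to keep recomputed log-probabilities consistent with the distribution the generator actually sampled from. A minimal sketch of that idea, assuming standard PyTorch (the helper below is hypothetical, not the MegatronModelWrapper API):

```python
import torch
import torch.nn.functional as F

def token_logprobs(logits: torch.Tensor, tokens: torch.Tensor, temperature: float) -> torch.Tensor:
    # Dividing logits by the sampling temperature makes the training-time
    # log-probs match the distribution the generator sampled from.
    scaled = logits / max(temperature, 1e-6)
    logp = F.log_softmax(scaled, dim=-1)
    # Gather the log-prob of each observed token: logits [batch, seq, vocab], tokens [batch, seq].
    return logp.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
```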
…nterface

- Remove ppo_train from MegatronPolicyWorkerBase and WorkerDispatch
- Update test_megatron_dp, test_megatron_offload to use forward_backward + optim_step
- Update test_save_load_model.py and test_save_load_checkpoint.py for unified interface
- Simplify _normalize_mini_batch_size (no longer needs policy_mini_batch_size_per_gpu)

Both FSDP and Megatron now use the same forward_backward + optim_step interface.

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The method only set _micro_batches_accumulated = 0, which can be done directly in __init__. This removes unnecessary indirection and the vestigial mesh_rank guard.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
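A minimal sketch of the simplification, assuming the class layout implied by the commit message (everything beyond the attribute name is illustrative):

```python
class MegatronPolicyWorkerBase:
    def __init__(self, *args, **kwargs):
        # Previously a separate helper, guarded by a mesh_rank check, reset this
        # counter; initializing it here removes that indirection.
        self._micro_batches_accumulated = 0
```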
    Returns:
        The gradient norm (before scaling, after clipping), or None if unavailable.
    """
Need the gradient scaling for grad accumulation here, I think? Looked into it briefly and it seemed doable.
@tyler-griggs Given the offline discussion about moving to loss sums, maybe we don't need any scaling here anymore? We'll need to figure out how Megatron's internal gradient accumulation works and whether we can enforce that it also only does sums.
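To make the distinction concrete, here is a minimal sketch of the loss-sum convention under discussion, in which no per-micro-batch scaling is needed (the function, batch format, and loss below are illustrative assumptions, not the project's or Megatron's actual code):

```python
import torch.nn as nn

def accumulate_grads_with_loss_sums(model: nn.Module, micro_batches, total_tokens: int):
    # Sum convention: each micro-batch contributes an unscaled per-token loss sum.
    # A single global 1/total_tokens factor, applied in every backward, makes the
    # accumulated gradient equal to the full-batch token mean, so no per-micro-batch
    # 1/num_micro_batches scaling is required.
    for inputs, targets in micro_batches:
        logits = model(inputs)  # assumed shape [batch, seq, vocab]
        token_losses = nn.functional.cross_entropy(
            logits.flatten(0, 1), targets.flatten(), reduction="none"
        )
        (token_losses.sum() / total_tokens).backward()
```

Under the mean convention, each micro-batch loss would instead need an explicit 1/num_micro_batches factor (ideally weighted by token count), which is the scaling asked about above; whether Megatron's internal micro-batch loop applies its own averaging is exactly the open question in the last comment.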
Summary
- Add forward_backward() and optim_step() methods to MegatronPolicyWorkerBase to match the FSDP worker interface
- Mark ppo_train() as deprecated (kept for backward compatibility)
- Update test_megatron_worker.py to use the new interface

This brings Megatron up to parity with FSDP following the refactoring in PR #859.
Test plan
- Run test_megatron_worker.py to verify forward_backward + optim_step works correctly

Co-Authored-By: Eric Tang <erictang000@gmail.com>