
Conversation

@tyler-griggs (Member) commented Jan 20, 2026

Summary

  • Add forward_backward() and optim_step() methods to MegatronPolicyWorkerBase to match FSDP worker interface
  • Update trainer to use unified interface for both Megatron and FSDP strategies (removes strategy branching)
  • Mark ppo_train() as deprecated (kept for backward compatibility)
  • Update test_megatron_worker.py to use the new interface

This brings Megatron up to parity with FSDP following the refactoring in PR #859.
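
For context, a minimal sketch of the call pattern the trainer can now use for both strategies. The method names forward_backward() and optim_step() come from this PR; the worker stub, argument names, and return values below are assumptions, not the actual implementation.

```python
# Hedged sketch of the unified trainer loop described above. Only the method
# names forward_backward() and optim_step() come from the PR; everything else
# is a hypothetical stand-in.

class PolicyWorkerStub:
    """Stand-in for either the FSDP or the Megatron policy worker."""

    def forward_backward(self, mini_batch):
        # Forward pass, loss computation, and gradient accumulation for one
        # mini-batch; returns per-batch training metrics.
        return {"loss": 0.0}

    def optim_step(self):
        # Apply the accumulated gradients; returns the gradient norm.
        return 1.0


def train_step(worker, mini_batches):
    """One policy update using the unified interface (no strategy branching)."""
    metrics = [worker.forward_backward(mb) for mb in mini_batches]
    grad_norm = worker.optim_step()
    return metrics, grad_norm


if __name__ == "__main__":
    print(train_step(PolicyWorkerStub(), mini_batches=[{}, {}]))
```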

Test plan

  • Run test_megatron_worker.py to verify that forward_backward + optim_step work correctly
  • Verify metrics match between Megatron and FSDP implementations

Co-Authored-By: Eric Tang <erictang000@gmail.com>

…ptim_step

- Add forward_backward() and optim_step() methods to MegatronPolicyWorkerBase
- Update trainer to use unified interface for both strategies
- Remove strategy branching in train_critic_and_policy()
- Mark ppo_train() as deprecated (kept for backward compatibility)
- Update test_megatron_worker.py to use new interface

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@gemini-code-assist bot left a comment

Code Review

This pull request successfully unifies the training interfaces for Megatron and FSDP strategies by introducing forward_backward and optim_step methods in MegatronPolicyWorkerBase. The trainer.py file is updated to use this unified interface, removing strategy-specific branching, which significantly improves code maintainability and clarity. The ppo_train method is appropriately marked as deprecated, and the tests are updated to reflect these changes. Overall, this is a well-executed refactoring that aligns with the goal of creating a more consistent training pipeline.

micro_batches=micro_buffer,
seq_len=seq_len,
micro_batch_size=micro_bsz,
temperature=self.cfg.generator.sampling_params.temperature,
@gemini-code-assist bot commented on this code (severity: medium)

The temperature parameter is typically associated with sampling in generation tasks. While it might be used internally by MegatronModelWrapper.forward_backward_mini_batch for entropy calculation or similar, its presence in a training method's signature can be confusing. Consider adding a comment to clarify its specific role in the training forward/backward pass, or if possible, refactor MegatronModelWrapper to only expose parameters relevant to training loss calculation in this context.
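
One way to address this would be a short docstring or inline comment at the call site. The sketch below assumes temperature is passed so that training-time log-probs are computed under the same logit scaling the generator used; the PR does not confirm this, and all names except forward_backward_mini_batch and temperature are hypothetical.

```python
# Hypothetical illustration of the suggested clarification; the docstring text
# is an assumption about why `temperature` is needed here, not a description of
# the actual implementation.

def forward_backward_mini_batch(micro_batches, seq_len, micro_batch_size, temperature):
    """Run the training forward/backward pass over a buffer of micro-batches.

    Args:
        temperature: sampling temperature used at generation time. It is passed
            here only so that log-probs (and any entropy terms) are computed
            under the same logit scaling the generator used; no sampling
            happens in this method.
    """
    ...
```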

tyler-griggs and others added 3 commits January 20, 2026 17:44
…nterface

- Remove ppo_train from MegatronPolicyWorkerBase and WorkerDispatch
- Update test_megatron_dp, test_megatron_offload to use forward_backward + optim_step
- Update test_save_load_model.py and test_save_load_checkpoint.py for unified interface
- Simplify _normalize_mini_batch_size (no longer needs policy_mini_batch_size_per_gpu)

Both FSDP and Megatron now use the same forward_backward + optim_step interface.

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The method only set _micro_batches_accumulated = 0, which can be done
directly in __init__. This removes unnecessary indirection and the
vestigial mesh_rank guard.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
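
A hedged before/after sketch of the simplification this commit describes. Only `_micro_batches_accumulated` and the existence of a mesh_rank guard come from the commit message; the class and method names are hypothetical.

```python
# Before: a dedicated reset helper behind a guard that no longer served a purpose.
class PolicyWorkerBefore:
    def __init__(self, mesh_rank: int = 0):
        self.mesh_rank = mesh_rank
        self._reset_micro_batch_counter()

    def _reset_micro_batch_counter(self):
        # Indirection plus a vestigial guard.
        if self.mesh_rank is not None:
            self._micro_batches_accumulated = 0


# After: initialize the counter directly in __init__, no helper, no guard.
class PolicyWorkerAfter:
    def __init__(self, mesh_rank: int = 0):
        self.mesh_rank = mesh_rank
        self._micro_batches_accumulated = 0
```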
Returns:
The gradient norm (before scaling, after clipping), or None if unavailable.
"""
Collaborator commented:

I think we need the gradient scaling for grad accumulation here? I looked into it briefly and it seemed doable.

Collaborator commented:

@tyler-griggs Given the offline discussion about moving to loss sums, maybe we don't actually need any scaling here anymore? We'll need to figure out how Megatron's internal gradient accumulation works and whether we can enforce that it also only does sums.
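
The question above hinges on whether each micro-batch contributes a summed or an averaged loss. A minimal PyTorch sketch of the two conventions, using a toy linear model rather than anything from this repo:

```python
# If each micro-batch backpropagates a summed loss, accumulated gradients already
# equal the gradient of the full-batch summed loss and need no extra scaling.
# If each micro-batch uses a mean loss, recovering the full-batch mean-loss
# gradient requires scaling by 1/num_micro_batches.
import torch

model = torch.nn.Linear(4, 1)
data, target = torch.randn(8, 4), torch.randn(8, 1)
micro_batches = list(zip(data.split(2), target.split(2)))  # 4 micro-batches of 2

# Convention A: summed losses, no rescaling needed.
model.zero_grad()
for x, y in micro_batches:
    torch.nn.functional.mse_loss(model(x), y, reduction="sum").backward()
grad_sum = model.weight.grad.clone()

# Convention B: mean loss per micro-batch, scaled by 1/num_micro_batches.
model.zero_grad()
for x, y in micro_batches:
    loss = torch.nn.functional.mse_loss(model(x), y, reduction="mean")
    (loss / len(micro_batches)).backward()
grad_mean = model.weight.grad.clone()

# The two differ only by the batch-size factor between sum and mean losses.
print(torch.allclose(grad_sum, grad_mean * data.shape[0]))
```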
