[megatron] feat: Add routing replay support for Megatron-Swift GRPO #8196
XianlongLi wants to merge 10 commits into modelscope:main
Conversation
Summary of Changes

This pull request enhances the Megatron-Swift GRPO training framework by introducing router replay. The feature aims to improve the stability and efficiency of Mixture-of-Experts (MoE) model training by replaying routing decisions either from the recompute stage (R2) or directly from the inference rollout (R3). By providing consistent routing information, the system can better manage the dynamic nature of MoE models, leading to more robust and predictable training outcomes, especially in reinforcement learning from human feedback (RLHF) scenarios.
Code Review
This pull request introduces routing replay support (R2 and R3 modes) for GRPO in Megatron-Swift. The implementation is comprehensive, involving changes to documentation, arguments, model configuration, and trainer logic. The core functionality is achieved by monkey-patching Megatron's TopKRouter to enable recording and replaying of routing decisions, with helper utilities for managing state and parallelism. The code is well-structured. I've identified a couple of minor issues, including a variable name typo and a code style issue, and have provided suggestions for fixes.
```python
if template.padding_free:
    gloabl_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)
else:
    gloabl_routed_experts = torch.stack(routed_experts_list)
encoded_batch['routed_experts'] = gloabl_routed_experts.to(device=self.device)
```
There's a typo in the variable name `gloabl_routed_experts`; it should be `global_routed_experts`.
Suggested change:

```diff
 if template.padding_free:
-    gloabl_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)
+    global_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)
 else:
-    gloabl_routed_experts = torch.stack(routed_experts_list)
+    global_routed_experts = torch.stack(routed_experts_list)
-encoded_batch['routed_experts'] = gloabl_routed_experts.to(device=self.device)
+encoded_batch['routed_experts'] = global_routed_experts.to(device=self.device)
```
```python
offset, _ = get_local_layer_range(tf_config, vp_rank)
for i, router in enumerate(router_instances_list):
    router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64))
```
swift/infer_engine/protocol.py
Outdated
```python
if value is None:
    return None
if isinstance(value, np.ndarray):
    return {'data': value.tolist(), 'shape': value.shape, 'dtype': str(value.dtype), '__ndarray__': True}
```
`tolist()` is inefficient for large arrays; serializing the raw bytes is faster and produces a much smaller payload:

```diff
-    'data': value.tolist(),
+    'data': base64.b64encode(value.tobytes()).decode('ascii'),
```
swift/infer_engine/protocol.py
Outdated
```python
if value is None:
    return None
if isinstance(value, dict) and value.get('__ndarray__'):
    return np.array(value['data'], dtype=value['dtype']).reshape(value['shape'])
```
Suggested change:

```diff
-    return np.array(value['data'], dtype=value['dtype']).reshape(value['shape'])
+    data = base64.b64decode(value['data'])
+    return np.frombuffer(data, dtype=value['dtype']).reshape(value['shape'])
```
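Taken together, the two suggestions form an encode/decode round trip. A minimal sketch (the helper names `encode_ndarray`/`decode_ndarray` are illustrative, not from the PR):

```python
import base64
import numpy as np

def encode_ndarray(value: np.ndarray) -> dict:
    # Serialize raw bytes instead of a nested Python list: tolist() is slow
    # and the resulting JSON is far larger than base64-encoded raw bytes.
    return {
        'data': base64.b64encode(value.tobytes()).decode('ascii'),
        'shape': value.shape,
        'dtype': str(value.dtype),
        '__ndarray__': True,
    }

def decode_ndarray(value: dict) -> np.ndarray:
    # Reverse the encoding: decode base64, reinterpret bytes, restore shape.
    data = base64.b64decode(value['data'])
    return np.frombuffer(data, dtype=value['dtype']).reshape(value['shape'])

arr = np.arange(12, dtype=np.uint8).reshape(3, 4)
assert np.array_equal(decode_ndarray(encode_ndarray(arr)), arr)
```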
swift/megatron/model/model_config.py
Outdated
```python
    'none'] = 'aux_loss'
use_shared_expert_gate: bool = False

enable_routing_replay: bool = False
```
I recommend using the parameter name `moe_enable_routing_replay` to maintain consistency with Megatron-LM. Reference: NVIDIA/Megatron-LM#2101
```python
load_format = vllm_engine_kwargs.pop('load_format', 'dummy')

if self.args.router_replay_mode == 'R3':
    vllm_engine_kwargs['enable_return_routed_experts'] = True
```
Should we add a vLLM version check? The `enable_return_routed_experts` parameter is only available in recent vLLM versions.
Thanks for your contribution! Looks good overall, just a few comments.

It seems the following Router Replay bugfixes are not included:

Would you mind running some experiments to verify the correctness of this PR?

Thanks for your comments, I will improve it.
cd2ce89 to 75bddf2
@XianlongLi is this PR ready for review?

@hjh0119 Not yet, the code has been added, and I'm waiting for some available GPUs to verify the new code and run experiments. I will remind you to review it again when I'm done.

LGTM!
```python
    return routing_probs, routing_map


def patched_routing(self, logits: torch.Tensor):
```
PR NVIDIA/Megatron-LM#3096 introduces the padding_mask parameter, which could break compatibility here.
I'd prefer requiring a newer Megatron version (since router_replay is now natively supported), making maintenance much cleaner.
```python
def get_router_replay_data(tf_config, vp_rank=None):
    router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank)
    layers_topk_idx = []
    for router in router_instances_list:
        layers_topk_idx.append(router.recorded_topk_idx.to(torch.uint8))
    # layer_num, seq_len, topk -> 1, seq_len, layer_num, topk
    layers_topk_idx = torch.stack(layers_topk_idx).transpose(0, 1).unsqueeze(0).to(device_name)
    return layers_topk_idx


def set_router_replay_data(layers_topk_idx, tf_config, vp_rank=None):
    # bs, seq_len, layer_num, topk -> layer_num, total_seq_len, topk
    layers_topk_idx_reshape = layers_topk_idx.flatten(0, 1).transpose(0, 1).to(device_name)
    router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank)
    offset, _ = get_local_layer_range(tf_config, vp_rank)
    for i, router in enumerate(router_instances_list):
        router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64))
```
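The two helpers above are inverses of each other at the tensor-layout level: one packs per-layer top-k indices into a batched tensor for transport, the other unpacks them back per layer. A minimal round-trip sketch of that layout logic, shown with NumPy for illustration (shapes and variable names are assumptions, not taken from the PR):

```python
import numpy as np

layer_num, seq_len, topk = 4, 6, 2
# Per-layer recorded top-k expert indices: (layer_num, seq_len, topk)
per_layer = np.arange(layer_num * seq_len * topk).reshape(layer_num, seq_len, topk)

# Pack, mirroring get_router_replay_data:
# (layer_num, seq_len, topk) -> (1, seq_len, layer_num, topk)
packed = np.transpose(per_layer, (1, 0, 2))[None, ...]

# Unpack, mirroring set_router_replay_data:
# (bs, seq_len, layer_num, topk) -> (layer_num, bs*seq_len, topk)
unpacked = np.transpose(packed.reshape(-1, layer_num, topk), (1, 0, 2))

assert packed.shape == (1, seq_len, layer_num, topk)
assert np.array_equal(unpacked, per_layer)
```

With `bs == 1` the round trip recovers the original tensor exactly, which is why the pack/unpack pair in the PR can hand routing data between the rollout and training stages.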
Does get_micro_batch_router_list already return offset indices?
Looks like the initial value of R3's RIS is approximately 3? Is this normal? Would you be able to run experiments for R2 as well?
PR type
PR information
This PR adds routing replay support for GRPO in Megatron-Swift, including both Recompute Routing Replay (R2) and Rollout Routing Replay (R3) modes. The algorithmic idea comes from research on MoE reinforcement learning (arXiv:2510.11370).
The differences between R2 and R3 are as follows:
R2: Caches the routing distribution/mask generated in the Recompute stage of the RL training pipeline (a middle step between rollout and policy update, where the old policy is re-calculated on training data). Its routing data comes from the training engine (e.g., Megatron), not the inference engine.
R3: Caches the routing distribution/mask generated in the Rollout (inference) stage of the RL pipeline—this is the stage where the model generates responses for tasks via the inference engine (e.g., vLLM). Its routing data comes directly from the inference engine, the actual deployment environment of the model.
Usage Tutorial
To enable Router Replay functionality, add the router_replay_mode configuration to your training script:
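A hedged sketch of what such a script might look like; the exact command and flag names should be taken from the Megatron-Swift GRPO examples, and everything besides `--router_replay_mode` below is a placeholder:

```shell
# Illustrative only: enable Recompute Routing Replay (R2) for a Megatron-Swift
# GRPO run. Substitute your real model, dataset, and parallelism flags.
megatron rlhf \
    --rlhf_type grpo \
    --router_replay_mode R2 \
    ...
```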
For R3 in server mode, additional vLLM engine configuration is required. Please refer to the vllm-project/vllm#28284
The rollout script is as follows:
The code implementation of this PR references verl-project/verl#4101; thanks to them for their excellent work.
Experiment results