
[megatron]feat: Add routing replay support for Megatron-Swift GRPO #8196

Open

XianlongLi wants to merge 10 commits into modelscope:main from XianlongLi:routing_replay

Conversation

@XianlongLi XianlongLi commented Mar 4, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR adds routing replay support for GRPO in Megatron-Swift, covering both Recompute Routing Replay (R2) and Rollout Routing Replay (R3) modes. The algorithmic idea comes from research on reinforcement learning for MoE models (2510.11370).

The differences between R2 and R3 are as follows:

R2: Caches the routing distribution/mask generated during the recompute stage of the RL training pipeline (an intermediate step between rollout and policy update, where old-policy logps are recomputed on the training data). Its routing data comes from the training engine (e.g., Megatron), not the inference engine.

R3: Caches the routing distribution/mask generated during the rollout (inference) stage of the RL pipeline, where the model generates responses via the inference engine (e.g., vLLM). Its routing data comes directly from the inference engine, i.e., the model's actual deployment environment.
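The record/replay cycle described above can be sketched with a toy top-k router. This is a hypothetical standalone class for illustration only; the actual PR monkey-patches Megatron's TopKRouter and manages state globally:

```python
from enum import Enum

import numpy as np


class RouterReplayAction(Enum):
    RECORD = "record"                  # cache top-k expert indices during forward
    REPLAY_FORWARD = "replay_forward"  # reuse cached indices instead of re-routing


class ReplayableRouter:
    """Toy top-k router that can record and later replay its expert choices."""

    def __init__(self, num_experts: int, topk: int):
        self.num_experts = num_experts
        self.topk = topk
        self.action = None
        self.recorded_topk_idx = None

    def route(self, logits: np.ndarray) -> np.ndarray:
        # logits: (seq_len, num_experts)
        if self.action == RouterReplayAction.REPLAY_FORWARD:
            # R2/R3: reuse the cached routing decision verbatim,
            # ignoring whatever logits the current forward pass produced
            return self.recorded_topk_idx
        # Normal routing: pick the topk highest-scoring experts per token
        topk_idx = np.argsort(-logits, axis=-1)[:, : self.topk]
        if self.action == RouterReplayAction.RECORD:
            self.recorded_topk_idx = topk_idx
        return topk_idx
```

In R2 the RECORD pass is the recompute forward (training engine); in R3 the recorded indices come from the inference engine instead, and only REPLAY_FORWARD is used on the training side.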

Usage Tutorial

To enable router replay, add the --router_replay_mode argument to your training script:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
megatron rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen3-30B-A3B-Instruct-2507 \
    --save_safetensors true \
    --merge_lora false \
    --context_parallel_size 2 \
    --tensor_model_parallel_size 2 \
    --expert_model_parallel_size 4 \
    --pipeline_model_parallel_size 2 \
    --dataset open-r1/DAPO-Math-17k-Processed \
    --num_train_epochs 1 \
    --global_batch_size 64 \
    --micro_batch_size 2 \
    --steps_per_generation 2 \
    --num_generations 8 \
    --reward_funcs accuracy format \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.3 \
    --vllm_tensor_parallel_size 4 \
    --vllm_max_model_len 16384 \
    --max_length 8192 \
    --max_completion_length 8192 \
    --tuner_type lora \
    --lr 5e-5 \
    --bf16 true \
    --beta 0.00 \
    --importance_sampling_level sequence \
    --epsilon 3e-4 \
    --epsilon_high 4e-4 \
    --dynamic_sample false \
    --overlong_filter true \
    --loss_type grpo \
    --sleep_level 2 \
    --offload_model true \
    --offload_bridge false \
    --offload_optimizer true \
    --logging_steps 1 \
    --recompute_granularity selective \
    --finetune \
    --dataloader_num_workers 8 \
    --dataset_num_proc 8 \
    --no_save_optim \
    --no_save_rng \
    --attention_backend flash \
    --temperature 1.0 \
    --padding_free true \
    --sequence_parallel true \
    --log_completions true \
    --report_to wandb \
    --router_replay_mode R3 # Options: "R2", "R3"

For R3 in server mode, additional vLLM engine configuration is required; please refer to vllm-project/vllm#28284.
The rollout script is as follows:

swift rollout \
   --model Qwen/Qwen3-30B-A3B-Instruct-2507 \
   --vllm_tensor_parallel_size 2 \
   --vllm_max_model_len 10240 \
   --vllm_engine_kwargs '{"enable_return_routed_experts": true}'

The implementation of this PR references verl-project/verl#4101; thanks to the verl team for their excellent work.

Experiment results

(Plots: is_train_old_smooth, is_train_vs_rollout_smooth, k3_kl_smooth)

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Megatron-Swift GRPO training framework by introducing router replay capabilities. This feature aims to improve the stability and efficiency of Mixture-of-Experts (MoE) model training by either replaying routing decisions from the recompute stage (R2) or directly from the inference rollout (R3). By providing consistent routing information, the system can better manage the dynamic nature of MoE models, leading to more robust and predictable training outcomes, especially in reinforcement learning from human feedback (RLHF) scenarios.

Highlights

  • Router Replay Mode Introduction: Introduced a new command-line argument --router_replay_mode with options disabled, R2 (Recompute Routing Replay), and R3 (Rollout Routing Replay) to control how routing decisions are handled during GRPO training.
  • R2 (Recompute Routing Replay) Implementation: Implemented R2 mode, which records the routing distribution/mask during the recompute stage of the RL training pipeline and replays it for subsequent forward passes, ensuring consistent routing decisions.
  • R3 (Rollout Routing Replay) Implementation: Implemented R3 mode, which captures routing data directly from the inference engine (e.g., vLLM) during the rollout stage and uses it to replay routing decisions during training, aligning training with actual deployment behavior.
  • Megatron Core Patching for Router Replay: Applied a monkey patch to Megatron's TopKRouter to integrate the router replay mechanism, allowing dynamic recording and replaying of expert routing indices within the MoE layers.
  • vLLM Integration for R3: Modified the vLLM engine integration to enable returning routed_experts data, which is essential for R3 mode to capture routing decisions from the inference environment.
  • Numpy Array Serialization Support: Added custom serialization and deserialization logic for NumPy arrays within the inference protocol, facilitating the transfer of routed_experts data.
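The NumPy serialization scheme described in the last highlight can be sketched as follows. The field names (`data`, `shape`, `dtype`, `__ndarray__`) mirror the snippets visible in the review comments on this PR; function names here are illustrative, not the PR's actual API:

```python
import numpy as np


def serialize_ndarray(value):
    """Encode an ndarray as a JSON-safe dict tagged with '__ndarray__'."""
    if value is None:
        return None
    if isinstance(value, np.ndarray):
        return {'data': value.tolist(), 'shape': value.shape,
                'dtype': str(value.dtype), '__ndarray__': True}
    return value


def deserialize_ndarray(value):
    """Rebuild an ndarray from a dict produced by serialize_ndarray."""
    if value is None:
        return None
    if isinstance(value, dict) and value.get('__ndarray__'):
        return np.array(value['data'], dtype=value['dtype']).reshape(value['shape'])
    return value
```

This lets the `routed_experts` array travel through the JSON-based inference protocol and be reconstructed with its original shape and dtype on the training side.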


Changelog
  • docs/source/Megatron-SWIFT/Command-line-parameters.md
    • Added documentation for the new router_replay_mode parameter in Chinese.
  • docs/source_en/Megatron-SWIFT/Command-line-parameters.md
    • Added documentation for the new router_replay_mode parameter in English.
  • swift/infer_engine/grpo_vllm_engine.py
    • Included routed_experts in the chat completion response to support R3 mode.
  • swift/infer_engine/protocol.py
    • Added NumpyArray type with custom serializers and deserializers for NumPy arrays.
    • Extended ChatCompletionResponseChoice to include an optional routed_experts field.
  • swift/megatron/arguments/megatron_args.py
    • Introduced router_replay_mode as a new argument for Megatron RLHF configurations.
  • swift/megatron/model/model_config.py
    • Added enable_routing_replay attribute to MegatronModelConfig.
    • Configured enable_routing_replay to be true if router_replay_mode is not disabled.
  • swift/megatron/trainers/base.py
    • Imported router replay related classes and functions.
    • Applied the apply_router_replay_patch during trainer initialization.
    • Set and cleared global router replay actions (REPLAY_FORWARD, RECORD, REPLAY_BACKWARD) around the train_step and forward_backward_func calls.
  • swift/megatron/trainers/grpo_trainer.py
    • Imported router replay utilities.
    • Modified _get_encoded_batch to process and pad routed_experts data when router_replay_mode is R3.
    • Updated merge_output_input_data to store routed_experts from rollout outputs.
    • Integrated router replay actions (RECORD, REPLAY_FORWARD) into _maybe_compute_logps for old policy logps calculation.
    • Added routed_experts to the batch data after logps computation.
    • Managed router replay actions (REPLAY_FORWARD, REPLAY_BACKWARD) within the forward_step function.
    • Handled layers_topk_idx for router replay data in model_forward.
  • swift/megatron/trainers/rollout_mixin.py
    • Enabled enable_return_routed_experts in vLLM engine kwargs when router_replay_mode is R3.
  • swift/megatron/utils/__init__.py
    • Exported new modules router_replay_patch and router_replay_utils.
  • swift/megatron/utils/router_replay_patch.py
    • Added RouterReplayAction enum for replay states.
    • Implemented RouterReplay class to manage global router instances and replay actions.
    • Patched TopKRouter's __init__ and routing methods to incorporate router replay logic.
    • Dynamically added enable_routing_replay to TransformerConfig.
  • swift/megatron/utils/router_replay_utils.py
    • Added utility functions get_local_layer_range, get_local_topk_idx_for_current_rank, get_router_replay_data, and set_router_replay_data for managing router replay data across parallel ranks.
    • Introduced RouterReplayHelper class to query router replay states and locate local RouterReplay instances.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces routing replay support (R2 and R3 modes) for GRPO in Megatron-Swift. The implementation is comprehensive, involving changes to documentation, arguments, model configuration, and trainer logic. The core functionality is achieved by monkey-patching Megatron's TopKRouter to enable recording and replaying of routing decisions, with helper utilities for managing state and parallelism. The code is well-structured. I've identified a couple of minor issues, including a variable name typo and a code style issue, and have provided suggestions for fixes.

Comment on lines +395 to +399:

    if template.padding_free:
        gloabl_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)
    else:
        gloabl_routed_experts = torch.stack(routed_experts_list)
    encoded_batch['routed_experts'] = gloabl_routed_experts.to(device=self.device)

medium: There's a typo in the variable name gloabl_routed_experts. It should be global_routed_experts.

Suggested change:

    if template.padding_free:
        global_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)
    else:
        global_routed_experts = torch.stack(routed_experts_list)
    encoded_batch['routed_experts'] = global_routed_experts.to(device=self.device)

    offset, _ = get_local_layer_range(tf_config, vp_rank)
    for i, router in enumerate(router_instances_list):
        router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64))

medium: This blank line with trailing whitespace can be removed for cleaner code.
@hjh0119 hjh0119 self-assigned this Mar 4, 2026
Collaborator:

    if value is None:
        return None
    if isinstance(value, np.ndarray):
        return {'data': value.tolist(), 'shape': value.shape, 'dtype': str(value.dtype), '__ndarray__': True}

tolist() is inefficient:

    - 'data': value.tolist(),
    + 'data': base64.b64encode(value.tobytes()).decode('ascii'),

Collaborator:

    if value is None:
        return None
    if isinstance(value, dict) and value.get('__ndarray__'):
        return np.array(value['data'], dtype=value['dtype']).reshape(value['shape'])

Suggested change:

    - return np.array(value['data'], dtype=value['dtype']).reshape(value['shape'])
    + data = base64.b64decode(value['data'])
    + return np.frombuffer(data, dtype=value['dtype']).reshape(value['shape'])

Collaborator:

        'none'] = 'aux_loss'
    use_shared_expert_gate: bool = False

    enable_routing_replay: bool = False

I recommend using the parameter name moe_enable_routing_replay to maintain consistency with Megatron-LM.

Reference: NVIDIA/Megatron-LM#2101

Collaborator:

    load_format = vllm_engine_kwargs.pop('load_format', 'dummy')

    if self.args.router_replay_mode == 'R3':
        vllm_engine_kwargs['enable_return_routed_experts'] = True

Should we add a vLLM version check? The enable_return_routed_experts parameter is only available in recent versions.


hjh0119 commented Mar 6, 2026

Thanks for your contribution!

Looks good overall, just a few comments.


hjh0119 commented Mar 6, 2026

It seems the following Router Replay bugfixes are not included:

verl-project/verl#4986
verl-project/verl#5452


hjh0119 commented Mar 6, 2026

Would you mind running some experiments to verify the correctness of this PR?
I think comparing the importance ratio of rollout logps/old logps vs training logps would help.
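The sanity check suggested above can be sketched as follows: if routing replay works, training-time logps should closely match the old/rollout logps, so the per-sequence importance ratio should stay near 1. This is a hypothetical helper, not code from this PR:

```python
import numpy as np


def sequence_importance_ratio(train_logps, old_logps):
    """Per-sequence ratio pi_train / pi_old, computed as
    exp(sum over tokens of (train_logp - old_logp))."""
    train_logps = np.asarray(train_logps, dtype=np.float64)
    old_logps = np.asarray(old_logps, dtype=np.float64)
    return np.exp((train_logps - old_logps).sum(axis=-1))
```

A ratio drifting far from 1.0 (as can happen for MoE models without replay, since the training and inference engines may route tokens to different experts) indicates a train/rollout mismatch.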

@XianlongLi
Author

> Thanks for your contribution!
> Looks good overall, just a few comments.

Thanks for your comments, I will improve it.


hjh0119 commented Mar 11, 2026

@XianlongLi is this PR ready for review?

@XianlongLi
Author

> @XianlongLi is this PR ready for review?

@hjh0119 Not yet; the code has been added, and I'm waiting for some available GPUs to verify the new code and run experiments. I'll ping you for another review when I'm done.

@XianlongLi XianlongLi requested a review from hjh0119 March 18, 2026 09:31

hjh0119 commented Mar 20, 2026

LGTM!
Have you run any experiments to validate this feature?

@XianlongLi
Author

> LGTM! Have you run any experiments to validate this feature?

@hjh0119 I've added it to the PR description.

Collaborator:

        return routing_probs, routing_map


    def patched_routing(self, logits: torch.Tensor):

PR NVIDIA/Megatron-LM#3096 introduces the padding_mask parameter, which could break compatibility here.

I'd prefer requiring a newer Megatron version (since router replay is now natively supported), making maintenance much cleaner.

Comment on lines +95 to +111:

    def get_router_replay_data(tf_config, vp_rank=None):
        router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank)
        layers_topk_idx = []
        for router in router_instances_list:
            layers_topk_idx.append(router.recorded_topk_idx.to(torch.uint8))
        # layer_num, seq_len, topk -> 1, seq_len, layer_num, topk
        layers_topk_idx = torch.stack(layers_topk_idx).transpose(0, 1).unsqueeze(0).to(device_name)
        return layers_topk_idx


    def set_router_replay_data(layers_topk_idx, tf_config, vp_rank=None):
        # bs, seq_len, layer_num, topk -> layer_num, total_seq_len, topk
        layers_topk_idx_reshape = layers_topk_idx.flatten(0, 1).transpose(0, 1).to(device_name)
        router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank)
        offset, _ = get_local_layer_range(tf_config, vp_rank)
        for i, router in enumerate(router_instances_list):
            router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64))

Collaborator: Does get_micro_batch_router_list already return offset indices?
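The two shape transforms in the snippet above (pack per-layer indices for transport, then unpack them back per layer) are inverses of each other, which a quick NumPy check can confirm (NumPy standing in for the torch ops):

```python
import numpy as np

layer_num, seq_len, topk = 3, 5, 2
per_layer = np.arange(layer_num * seq_len * topk).reshape(layer_num, seq_len, topk)

# get_router_replay_data: (layer_num, seq_len, topk) -> (1, seq_len, layer_num, topk)
# mirrors torch.stack(...).transpose(0, 1).unsqueeze(0)
packed = per_layer.transpose(1, 0, 2)[None, ...]

# set_router_replay_data: (bs, seq_len, layer_num, topk) -> (layer_num, bs*seq_len, topk)
# mirrors .flatten(0, 1).transpose(0, 1)
unpacked = packed.reshape(-1, layer_num, topk).transpose(1, 0, 2)
```

`unpacked` recovers `per_layer` exactly, so each router gets back the same top-k indices it recorded (modulo the uint8/int64 dtype round trip, which is safe while expert counts stay below 256).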


hjh0119 commented Mar 20, 2026

> I've added it to the PR description.

Looks like the initial value of R3's RIS is approximately 3? Is this normal?

Would you be able to run experiments for R2 as well?
