docs/source/Megatron-SWIFT/Command-line-parameters.md (1 addition, 0 deletions)
@@ -367,6 +367,7 @@ Megatron training parameters inherit from the Megatron parameters and the basic parameters (**shared with ms-swift
- rollout_importance_sampling_threshold: Threshold for importance sampling weights, used to truncate or mask extreme weights. Default is 2.0.
- log_rollout_offpolicy_metrics: Whether to log training-inference mismatch diagnostic metrics (KL, PPL, χ², etc.) when `rollout_importance_sampling_mode` is not set. When `rollout_importance_sampling_mode` is set, the metrics are logged automatically. Default is False.
- off_policy_sequence_mask_delta: Off-Policy Sequence Masking threshold, from the DeepSeek-V3.2 paper. When set, `mean(old_policy_logps - policy_logps)` is computed for each sequence; if this value exceeds the threshold and the sequence's advantage is negative, the sequence is masked out of the loss computation. Default is None (disabled). For details, see the [documentation](../Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md#off-policy-sequence-masking).
- router_replay_mode: Router replay mode. Options are `disabled`, `R2`, `R3`. Default is `disabled` (router replay is not enabled).

For built-in reward function parameters, refer to the [documentation](../Instruction/Command-line-parameters.md#奖励函数参数).

docs/source_en/Megatron-SWIFT/Command-line-parameters.md (1 addition, 0 deletions)
@@ -393,6 +393,7 @@ In addition to inheriting the training parameters, the following parameters are
- rollout_importance_sampling_threshold: Threshold for importance sampling weights, used for truncating or masking extreme weights. Default is 2.0.
- log_rollout_offpolicy_metrics: Whether to log training-inference mismatch diagnostic metrics (KL, PPL, χ², etc.) when `rollout_importance_sampling_mode` is not set. When `rollout_importance_sampling_mode` is set, metrics are always logged. Default is False.
- off_policy_sequence_mask_delta: Off-Policy Sequence Masking threshold from the [DeepSeek-V3.2 paper](https://arxiv.org/abs/2512.02556). When set, `mean(old_policy_logps - policy_logps)` is computed for each sequence; if this value exceeds the threshold and the sequence's advantage is negative, the sequence is masked out of the loss computation (a minimal sketch follows this list). For details, refer to the [documentation](../Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md#off-policy-sequence-masking).
- router_replay_mode: Router replay mode. Options are `disabled`, `R2`, `R3`. Default is `disabled` (router replay is not enabled).

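A minimal sketch of the sequence-masking rule above, assuming `(batch, seq_len)` log-prob tensors, a per-sequence advantage vector, and a 0/1 valid-token mask (all names are illustrative, not the trainer's actual API):

```python
import torch


def off_policy_sequence_mask(old_logps, logps, advantages, mask, delta):
    """Zero out sequences whose mean logp gap exceeds delta and whose advantage is negative."""
    # Per-sequence mean of (old_policy_logps - policy_logps) over valid tokens.
    per_seq_delta = ((old_logps - logps) * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    # Drop a sequence only when the gap is large AND its advantage is negative.
    drop = (per_seq_delta > delta) & (advantages < 0)
    return mask * (~drop).unsqueeze(-1).to(mask.dtype)
```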
Built-in reward function parameters refer to the [documentation](../Instruction/Command-line-parameters.md#reward-function-parameters).

swift/infer_engine/grpo_vllm_engine.py (1 addition, 1 deletion)
@@ -122,7 +122,7 @@ def _create_chat_completion_response(self, result, inputs, request_config, reque
finish_reason=output.finish_reason,
logprobs=logprobs,
token_ids=token_ids,
)
routed_experts=getattr(output, 'routed_experts', None))
choices.append(choice)
prompt_token_ids = None
images_size = None
swift/infer_engine/protocol.py (29 additions, 2 deletions)
@@ -2,19 +2,45 @@
import base64
import io
import json
import numpy as np
import os
import time
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass, field, fields
from PIL import Image
from pydantic import BaseModel, Field, field_validator
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import AfterValidator, BaseModel, Field, PlainSerializer, field_validator
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union

from swift.template import Messages, Tool
from swift.utils import remove_response


def serialize_ndarray(value):
if value is None:
return None
if isinstance(value, np.ndarray):
return {
'data': base64.b64encode(value.tobytes()).decode('ascii'),
'shape': value.shape,
'dtype': str(value.dtype),
'__ndarray__': True
}
return value


def deserialize_ndarray(value):
if value is None:
return None
if isinstance(value, dict) and value.get('__ndarray__'):
data = base64.b64decode(value['data'])
return np.frombuffer(data, dtype=value['dtype']).reshape(value['shape'])
return value


NumpyArray = Annotated[Any, PlainSerializer(serialize_ndarray, return_type=Dict), AfterValidator(deserialize_ndarray)]


@dataclass
class InferRequest:
"""
@@ -392,6 +418,7 @@ class ChatCompletionResponseChoice:
finish_reason: Literal['stop', 'length', None]
logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None
token_ids: Optional[List[int]] = None
routed_experts: Optional[NumpyArray] = None

def to_cmpl_choice(self) -> 'CompletionResponseChoice':
self = deepcopy(self)
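As a usage note, the `NumpyArray` annotated type above lets a payload carry a numpy array through JSON-safe dicts. A minimal round-trip sketch, assuming pydantic v2; `Payload` is a hypothetical model that only mirrors how `routed_experts` is typed in `ChatCompletionResponseChoice`:

```python
import base64
from typing import Annotated, Any, Dict, Optional

import numpy as np
from pydantic import AfterValidator, BaseModel, PlainSerializer


# Same helpers as in the diff above.
def serialize_ndarray(value):
    if isinstance(value, np.ndarray):
        return {'data': base64.b64encode(value.tobytes()).decode('ascii'),
                'shape': value.shape, 'dtype': str(value.dtype), '__ndarray__': True}
    return value


def deserialize_ndarray(value):
    if isinstance(value, dict) and value.get('__ndarray__'):
        return np.frombuffer(base64.b64decode(value['data']),
                             dtype=value['dtype']).reshape(value['shape'])
    return value


NumpyArray = Annotated[Any, PlainSerializer(serialize_ndarray, return_type=Dict),
                       AfterValidator(deserialize_ndarray)]


class Payload(BaseModel):  # hypothetical model for illustration
    routed_experts: Optional[NumpyArray] = None


arr = np.arange(12, dtype=np.int32).reshape(3, 4)
dumped = Payload(routed_experts=arr).model_dump()          # ndarray -> base64 dict
restored = Payload.model_validate(dumped).routed_experts   # dict -> ndarray
assert np.array_equal(restored, arr)
```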
swift/megatron/arguments/megatron_args.py (2 additions, 0 deletions)
@@ -155,6 +155,8 @@ class RLHFMegatronArgumentsMixin:
# Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
top_entropy_quantile: float = 1.0

router_replay_mode: Literal['disabled', 'R2', 'R3'] = 'disabled'

# ─────────────────────────── Not Supported Yet ───────────────────────────

# reward model
swift/megatron/model/model_config.py (4 additions, 0 deletions)
@@ -583,6 +583,10 @@ def get_mcore_model_config(args, hf_config):
if num_moe_experts is None:
kwargs['expert_model_parallel_size'] = 1
kwargs['expert_tensor_parallel_size'] = 1

if args.router_replay_mode != 'disabled':
kwargs['moe_enable_routing_replay'] = True

config = MegatronModelConfig(**kwargs)
config.hf_config = hf_config
config.args = args
swift/megatron/trainers/base.py (19 additions, 3 deletions)
@@ -27,9 +27,9 @@
from swift.megatron.callbacks import megatron_callbacks_map
from swift.megatron.model import get_mcore_model
from swift.megatron.tuners import LoraParallelLinear
from swift.megatron.utils import (copy_original_module_weight, disable_forward_pre_hook, enable_forward_pre_hook,
get_optimizer_param_scheduler, get_padding_to, init_persistent_async_worker,
initialize_tp_communicators, load_mcore_checkpoint,
from swift.megatron.utils import (apply_router_replay_patch, copy_original_module_weight, disable_forward_pre_hook,
enable_forward_pre_hook, get_optimizer_param_scheduler, get_padding_to,
init_persistent_async_worker, initialize_tp_communicators, load_mcore_checkpoint,
logical_and_across_model_parallel_group, maybe_finalize_async_save,
prepare_mcore_model, reduce_max_stat_across_model_parallel_group,
save_mcore_checkpoint, should_disable_forward_pre_hook, warmup_jit_function,
@@ -44,8 +44,11 @@

try:
from megatron.core.optimizer import param_group_identifier_keys
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
except ImportError:
param_group_identifier_keys = None
RouterReplay = None
RouterReplayAction = None

mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0')

@@ -55,6 +58,11 @@
class BaseMegatronTrainer(ABC):

def __init__(self, args, template: Template):
# validate mcore version and patch routing_replay
self.enable_routing_replay = args.router_replay_mode != 'disabled'
if self.enable_routing_replay:
apply_router_replay_patch()

self.args = args
self.template = template
self.prepare_model()
@@ -839,6 +847,10 @@ def train_step(self, train_data_iterator):
self.optimizer.zero_grad()
# TODO: refactor _replace_data_iterator
data_iterator = self._replace_data_iterator(train_data_iterator)

if self.enable_routing_replay:
RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)

metrics = forward_backward_func(
forward_step_func=self.forward_step,
data_iterator=data_iterator,
@@ -855,6 +867,10 @@
if update_successful:
self.opt_param_scheduler.step(increment=args.global_batch_size)

if self.enable_routing_replay:
RouterReplay.clear_global_router_replay_action()
RouterReplay.clear_global_indices()

return metrics, grad_norm, update_successful

def _aggregated_metrics(self, metrics, total_metrics):
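Because `set_global_router_replay_action` mutates process-wide state, the clear calls after the optimizer step are what keep subsequent steps clean. A sketch of the pairing, using the same `RouterReplay` API imported above (the try/finally is an illustrative hardening, not the PR's exact control flow):

```python
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction


def run_step_with_replay(forward_backward_func, **fb_kwargs):
    # Replay the recorded rollout routing decisions during this training forward pass.
    RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
    try:
        return forward_backward_func(**fb_kwargs)
    finally:
        # Reset the global replay state so the next step starts clean.
        RouterReplay.clear_global_router_replay_action()
        RouterReplay.clear_global_indices()
```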
swift/megatron/trainers/grpo_trainer.py (76 additions, 3 deletions)
@@ -23,7 +23,7 @@
from swift.dataset import RowPreprocessor
from swift.infer_engine.protocol import RequestConfig, RolloutInferRequest, RolloutOutput
from swift.megatron.arguments import MegatronArguments, MegatronRLHFArguments
from swift.megatron.utils import forward_step_helper, get_padding_to, set_random_seed
from swift.megatron.utils import RouterReplayHelper, get_padding_to, set_random_seed, set_router_replay_data
from swift.rewards import orms
from swift.rlhf_trainers.grpo_trainer import DataType
from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context,
@@ -38,6 +38,12 @@
from .utils import gather, gather_object
from .vocab_parallel_utils import compute_logps_and_entropy_from_logits

try:
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
except ImportError:
RouterReplay = None
RouterReplayAction = None

logger = get_logger()


@@ -271,6 +277,8 @@ def _batch_encode(self, infer_requests: List[Dict], template: Template, strict:
return batched_inputs, error_list

def _get_encoded_batch(self, encoded_list, rollout_batch, template):
original_seq_lengths = [item['length'] for item in encoded_list]

args = self.args
encoded_batch = to_device(template.data_collator(encoded_list, padding_to=get_padding_to(args)), self.device)

@@ -361,6 +369,43 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
flat_lps, dtype=torch.float32, device=self.device)

encoded_batch['rollout_per_token_logps'] = rollout_per_token_logps

# Validating and processing routed_experts data in R3 mode
if self.args.router_replay_mode == 'R3':
routed_experts_list = []
cur_seq_lengths = seq_lengths
            if seq_lengths.size(0) > batch_size:
cur_seq_lengths = seq_lengths[:batch_size].clone()
cur_seq_lengths[batch_size - 1] = seq_lengths[batch_size - 1:].sum()
for data, original_seq_len, cur_seq_len in zip(rollout_batch, original_seq_lengths, cur_seq_lengths):
routed_experts = data.get('routed_experts')
assert routed_experts is not None, (
'When router_replay_mode = R3, routed_experts must be in rollout data')
routed_experts = torch.tensor(routed_experts)
# The number of experts in the output can be 1 less than (prompt_length + response_token_count)
# This gap of 1 is expected
                # For more details, please refer to PR https://github.com/vllm-project/vllm/pull/28284
                experts_seq_len = routed_experts.shape[0]
                assert (experts_seq_len == original_seq_len
                        or experts_seq_len + 1 == original_seq_len), (
                    f'The seq_len of routed_experts ({experts_seq_len}) in the output '
                    f'does not match the seq_len of the data ({original_seq_len}); '
                    'it should be equal to, or 1 less than, the seq_len of the data')
# Padding routed_experts(seq_len, layer_num, topk) seq_len to match the seq_len of the input_ids
padding_routed_experts = routed_experts
padding_to = cur_seq_len if template.padding_free else max_seq_len
padding_len = padding_to - experts_seq_len
if padding_len > 0:
padding_right = template.padding_side == 'right'
padding_routed_experts = nn.functional.pad(routed_experts,
(0, 0, 0, 0, 0, padding_len) if padding_right else
(0, 0, 0, 0, padding_len, 0), 'constant', 0)
routed_experts_list.append(padding_routed_experts)
if template.padding_free:
global_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)
else:
global_routed_experts = torch.stack(routed_experts_list)
encoded_batch['routed_experts'] = global_routed_experts.to(device=self.device)

return encoded_batch

def _generate_and_score_completions(self, batch):
@@ -558,6 +603,9 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out
if 'content' in choice.logprobs:
rollout_logprobs = [item['logprob'] for item in choice.logprobs['content']]
input_data['rollout_logprobs'] = [rollout_logprobs]

# Step 6: Store rollout routed_experts for routing replay
input_data['routed_experts'] = choice.routed_experts
return input_data

assert len(batch) == len(outputs)
@@ -938,7 +986,7 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
with self.null_ref_context() as ref_models:
assert len(ref_models) == 1, 'GRPO currently does not support VPP.'
ref_model = ref_models[0]
ref_per_token_logps_packed = self.compute_per_token_logps(
ref_per_token_logps_packed, _ = self.compute_per_token_logps(
ref_model, iter([deepcopy(inputs)]), temperature=self.temperature)
if self.template.padding_free:
ref_per_token_logps, _ = pad_logps_back_to_batch(
Expand All @@ -950,7 +998,13 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
ref_per_token_logps = ref_per_token_logps_packed
batch['ref_per_token_logps'] = ref_per_token_logps

old_per_token_logps_packed = self.compute_per_token_logps(
if self.enable_routing_replay:
if self.args.router_replay_mode == 'R2':
RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)
if self.args.router_replay_mode == 'R3':
RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)

old_per_token_logps_packed, routing_topk_idx = self.compute_per_token_logps(
self.unwrapped_models[0], iter([deepcopy(inputs)]), temperature=self.temperature)
if self.template.padding_free:
old_per_token_logps, _ = pad_logps_back_to_batch(
Expand All @@ -962,6 +1016,11 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
old_per_token_logps = old_per_token_logps_packed
batch['old_per_token_logps'] = old_per_token_logps

if self.enable_routing_replay:
batch['routed_experts'] = routing_topk_idx
RouterReplay.clear_global_indices()
RouterReplay.clear_global_router_replay_action()

return batch

def _compute_kl_from_batches(self, mini_batch_data: List[Dict[str, Any]]) -> torch.Tensor:
@@ -1041,6 +1100,15 @@ def forward_step(self, data_iterator, model):
'seq_lengths': seq_lengths,
})
data.pop('loss_scale', None)

if self.enable_routing_replay and RouterReplayHelper.is_replay_backward_action(model.config):
router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config)
for router in router_instance_list:
router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
if self.enable_routing_replay and RouterReplayHelper.is_replay_forward_action(model.config):
layers_topk_idx = data.pop('routed_experts', None)
set_router_replay_data(layers_topk_idx, model.config)

inputs = self._prepare_model_inputs(data)

labels = data['labels']
Expand Down Expand Up @@ -1091,6 +1159,11 @@ def forward_step(self, data_iterator, model):
output_tensor = per_token_logps
data['per_token_entropy'] = per_token_entropy

if self.enable_routing_replay and RouterReplayHelper.is_replay_forward_action(model.config):
router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config)
for router in router_instance_list:
router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)

return output_tensor, partial(self.loss_func, data=data)

@profiling_decorator
swift/megatron/trainers/rlhf_mixin.py (16 additions, 3 deletions)
@@ -6,7 +6,8 @@
from transformers.utils import ContextManagers

from swift.megatron.model import get_mcore_model
from swift.megatron.utils import forward_step_helper, load_mcore_checkpoint
from swift.megatron.utils import (RouterReplayHelper, forward_step_helper, get_local_topk_idx_for_current_rank,
get_router_replay_data, load_mcore_checkpoint, set_router_replay_data)
from swift.rlhf_trainers.utils import identity_data_collator
from swift.utils import get_logger
from .base import BaseMegatronTrainer
@@ -107,18 +108,30 @@ def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperatur

Returns:
per_token_logps tensor, or None if on a non-last PP stage
            routing_topk_idx tensor, or None if router replay is disabled
"""
data = self.get_batch(data_iterator)
data.pop('loss_scale', None)
labels = data.get('labels')

routing_topk_idx = None
global_topk_idx = data.pop('routed_experts', None)
if self.enable_routing_replay and RouterReplayHelper.is_replay_forward_action(model.config):
assert global_topk_idx is not None, 'When router_replay_mode = R3, routed_experts must be in data'
routing_topk_idx = get_local_topk_idx_for_current_rank(global_topk_idx, model.config,
data.get('packed_seq_params'))
set_router_replay_data(routing_topk_idx, model.config)

data_for_forward = {k: v for k, v in data.items() if k != 'labels'}
context = torch.no_grad() if no_grad else nullcontext()
with context:
output_tensor = forward_step_helper(self.args, model, data_for_forward)

if self.enable_routing_replay and RouterReplayHelper.is_r2_record_action(model.config):
routing_topk_idx = get_router_replay_data(model.config)

if labels is None or output_tensor is None:
return None
return None, routing_topk_idx

if temperature != 1.0:
output_tensor.div_(temperature)
Expand All @@ -133,7 +146,7 @@ def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperatur

if self.args.context_parallel_size > 1:
per_token_logps = self._postprocess_packed_tensor_cp(per_token_logps, packed_seq_params, num_samples)
return per_token_logps
return per_token_logps, routing_topk_idx

def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
"""
swift/megatron/trainers/rollout_mixin.py (5 additions, 0 deletions)
@@ -245,6 +245,11 @@ def _prepare_vllm_engine(self):
vllm_engine_kwargs = args.vllm_engine_kwargs or {}
load_format = vllm_engine_kwargs.pop('load_format', 'dummy')

if self.args.router_replay_mode == 'R3':
assert check_vllm_version_ge('0.14.0'), \
'The enable_return_routed_experts attribute is not supported. Please upgrade vllm to 0.14.0 or higher'
vllm_engine_kwargs['enable_return_routed_experts'] = True
> **Collaborator:** Should we add a vLLM version check? The `enable_return_routed_experts` parameter is only available in recent versions.

engine = GRPOVllmEngine(
args.model_info.model_dir,
torch_dtype=args.torch_dtype,
swift/megatron/utils/__init__.py (1 addition, 0 deletions)
@@ -9,5 +9,6 @@
from .parallel_utils import (logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group,
split_cp_inputs)
from .patcher import patch_merge_fn, patch_torch_dist_shard
from .router_replay_utils import *
from .utils import (copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to,
prepare_mcore_model, tuners_sharded_state_dict)