
[Feature] Add Multi-Token Prediction (MTP) module implementation#1572

Open

HAOCHENYE wants to merge 3 commits into InternLM:main from HAOCHENYE:mtp

Conversation

@HAOCHENYE
Collaborator

No description provided.

@HAOCHENYE
Collaborator Author

@claude review

Comment on lines 112 to 116
def _forward( # type: ignore[override]
self,
seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch
loss_ctx: CELossContext | None,
return_router_logits: bool = False,

Claude: Critical: When deepstack_visual_embeds is not None, this code falls through to the original _forward body which still references self.balancing_loss and self.z_loss. These attributes were removed from MoE.__init__ in commit 1 of this PR, so this will raise AttributeError at runtime for any VL model with deepstack visual embeds.

Additionally, loss_ctx here is typed as CELossContext | None, but after commit 1 the parent's _forward now expects MoELossContextDict | None (a dict). When deepstack_visual_embeds is not None, line 186 passes loss_ctx (a dict) directly to self.lm_head(hidden_states, loss_ctx) which expects LMHeadLossContext, not a dict. This will also fail at runtime.

This entire _forward override needs to be updated to work with the new dict-based loss context API and the loss-context-based auxiliary losses (instead of self.balancing_loss / self.z_loss).
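A minimal sketch of what the dict-based aggregation could look like (the key names `"lm"`, `"balancing"`, and `"z"` are taken from this PR's review; `compute` is a hypothetical per-loss hook, not an actual xtuner API):

```python
# Hedged sketch: auxiliary losses are read from the loss-context dict
# rather than from self.balancing_loss / self.z_loss attributes.
# `compute(key, ctx, hidden_states)` is an illustrative stand-in.
def combine_losses(loss_ctx_dict, hidden_states, compute):
    total = 0.0
    for key in ("lm", "balancing", "z"):
        ctx = loss_ctx_dict.get(key)
        if ctx is not None:  # only sum losses whose context is present
            total += compute(key, ctx, hidden_states)
    return total
```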

Comment on lines +536 to +543
embed_tokens_fn=self.embed_tokens,
position_embeddings=position_embeddings_list[micro_idx],
seq_ctx=seq_ctx,
)

# Compute MTP losses for each depth
mtp_losses = []
for mtp_hidden, mtp_ctx in zip(mtp_outputs, mtp_ctx_list):

Claude: Critical: In _forward, MTP labels are properly shifted for each depth via roll_packed_tensor(shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1). But here in _micro_batch_forward, no such shifting is applied — the MTP loss contexts are used as-is. This means all MTP depths would compute loss against the same (unshifted) labels, which is incorrect.

The same roll_packed_tensor logic needs to be applied here for each MTP depth before computing the loss.
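The per-depth shift described here can be sketched in plain Python (the function name and `cu_seq_lens` boundary convention mirror the PR's `roll_packed_tensor`, but this is an illustrative list-based stand-in, not the real tensor implementation):

```python
# Roll each packed subsequence independently by `shifts`, padding the
# vacated positions with `pad`. Boundaries come from cumulative
# sequence lengths, e.g. [0, 3, 6] for two length-3 sequences.
def roll_packed(tokens, cu_seq_lens, shifts, pad=0):
    out = list(tokens)
    for start, end in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        seq = tokens[start:end]
        if shifts < 0:
            rolled = seq[-shifts:] + [pad] * (-shifts)
        else:
            rolled = [pad] * shifts + seq[:len(seq) - shifts]
        out[start:end] = rolled
    return out
```

For MTP depth `idx`, labels would be rolled with `shifts=-idx - 1`, so each depth predicts tokens one step further ahead.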

Comment on lines +93 to +99
Example:
For packed sequences [1,2,3] and [4,5,6] with shifts=-1:
Original input_ids: [1, 2, 3, 4, 5, 6]
Rolled input_ids: [2, 3, 0, 5, 6, 0]
Original position_ids: [0, 1, 2, 0, 1, 2]
Rolled position_ids: [1, 2, 0, 1, 2, 0]
"""

Claude: Critical: roll_sequence_context mutates the input seq_ctx in-place by reassigning seq_ctx.input_ids and seq_ctx.inputs_embeds. In _forward, this is safe because mtp_seq_ctx is a clone. But in _micro_batch_forward, the seq_ctx from the micro-batch loop is passed directly into roll_sequence_context via MTPBlock.forward without cloning, which corrupts the original training data.

Either:

  1. This function should create and return a new SequenceContext instead of mutating the input, or
  2. _micro_batch_forward needs to clone seq_ctx before passing it to MTP, similar to what _forward does.

Also, the docstring says "Both input_ids and position_ids are rolled" but position_ids is never actually rolled in the implementation.
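Option 1 above, returning a new context instead of mutating the input, could look like this minimal sketch (`SeqCtx` is a stand-in dataclass, not the real `SequenceContext`; field names are assumptions for illustration):

```python
from dataclasses import dataclass, replace

# Stand-in for SequenceContext; frozen so accidental in-place
# mutation raises instead of silently corrupting training data.
@dataclass(frozen=True)
class SeqCtx:
    input_ids: list
    position_ids: list

def roll_sequence_context(ctx, rolled_ids, rolled_pos):
    # Build and return a fresh context; the caller's ctx is untouched.
    return replace(ctx, input_ids=rolled_ids, position_ids=rolled_pos)
```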

Comment on lines +793 to +800
seq_ctx = data["seq_ctx"].to(DEVICE)
if self.sp_mesh.size() > 1:
seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
seq_ctx_list.append(seq_ctx)
loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=self.sp_mesh)
loss_ctx_list.append(loss_ctx)

# 2. Compute cu_seq_lens_list (for calibration)
# 3. Call model's interface to build and calibrate all loss_ctx (done in one shot)
loss_ctx_dict_list = self._engine.model.build_loss_ctx_batch(data_batch, sp_mesh=self.sp_mesh)

Claude: Warning: loss_cfg is guaranteed to be non-None at this point (it's set to CELossConfig() three lines above if it was None). So the if loss_cfg is not None check always passes and the deprecation warning always fires, even if the user correctly sets lm_loss_cfg directly in the model config.

Consider only warning when loss_cfg was explicitly provided by the user (i.e., when the original loss_cfg argument was not None before the default assignment).
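The suggested fix could be sketched as follows (the function name and dict config are illustrative stand-ins for the trainer's `_resolve_model_loss_cfg` and `CELossConfig`):

```python
import warnings

# Hypothetical sketch: record whether the caller explicitly passed a
# deprecated loss_cfg *before* applying the default, and warn only
# in that case.
def resolve_model_loss_cfg(loss_cfg=None):
    user_provided = loss_cfg is not None
    if loss_cfg is None:
        loss_cfg = {"type": "CELoss"}  # stand-in for CELossConfig()
    if user_provided:
        warnings.warn(
            "loss_cfg is deprecated; set lm_loss_cfg in the model config",
            DeprecationWarning,
        )
    return loss_cfg
```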

Comment on lines +95 to +97
shape [batch, seq_len, hidden_size].
position_embeddings (tuple[torch.Tensor, torch.Tensor]): Rotary position
embeddings (cos, sin).

Claude: Warning: The return type annotation says tuple[torch.Tensor, torch.Tensor, torch.Tensor] but the docstring (lines 96-97) says:

Returns:
torch.Tensor: Output hidden states

The docstring should document the full 3-tuple return (hidden_states, router_weights, router_results).

# MTP needs to shift labesl multiple times. Since rebuild the `shifted_labels` in data_batch
mtp_loss_ctx_list = self._build_loss_ctx(self.config.lm_loss_cfg, _data_batch, sp_mesh)
if mtp_loss_ctx_list is not None:
loss_ctx_cls = mtp_loss_ctx_list[0].__class__

Claude: Nit: Typo: "labesl" → "labels"

Suggested change:
# MTP needs to shift labels multiple times. Since rebuild the `shifted_labels` in data_batch
loss_ctx_cls = mtp_loss_ctx_list[0].__class__

# Check if any micro-batch has MTP loss contexts
mtp_loss_ctx_lists = [loss_ctx_dict.get("mtp") for loss_ctx_dict in loss_ctx_list]
if any(ctx is not None for ctx in mtp_loss_ctx_lists):
# Verify all micro-batches have input_ids

Claude: Warning: MTPBlock.forward returns list[torch.Tensor] per its type annotation, but MTPLayer.forward actually returns tuple[torch.Tensor, torch.Tensor, torch.Tensor] (hidden_states, router_weights, router_results). So mtp_outputs is actually a list[tuple[...]], not list[torch.Tensor].

This is correctly handled downstream in _forward (line 741: hidden_states, router_results, router_weights = mtp_hidden), but the type annotations in MTPBlock.forward are misleading and should be updated to match.

@claude

claude bot commented Mar 12, 2026

Claude: ## Summary

This PR adds Multi-Token Prediction (MTP) support and refactors the loss context API from a single-object to a dict-based design, enabling multiple loss types per model (LM, balancing, z-loss, MTP). It also renames CELossContext to LMHeadLossContext for semantic clarity and makes BaseLossContext more abstract.

Issues

Critical

  • [xtuner/v1/model/moe/qwen3vl_text.py:112-221] When deepstack_visual_embeds is not None, the _forward override still references removed self.balancing_loss/self.z_loss attributes and passes a dict-typed loss_ctx directly to self.lm_head (which expects LMHeadLossContext). This will crash at runtime for VL models with deepstack.
  • [xtuner/v1/model/moe/moe.py:536-543] In _micro_batch_forward, MTP labels are not shifted per depth (unlike _forward which uses roll_packed_tensor). All MTP depths compute loss against the same unshifted labels.
  • [xtuner/v1/module/mtp/utils.py:93-99] roll_sequence_context mutates the input seq_ctx in-place. Safe in _forward (which clones), but corrupts training data in _micro_batch_forward (which doesn't clone).

Warning

  • [xtuner/v1/train/trainer.py:793-800] _resolve_model_loss_cfg always fires the deprecation warning since loss_cfg is guaranteed non-None at the call site.
  • [xtuner/v1/module/mtp/mtp_layer.py:95-97] Docstring says return type is torch.Tensor but actual return is tuple[torch.Tensor, torch.Tensor, torch.Tensor].
  • [xtuner/v1/model/moe/moe.py:527] MTPBlock.forward return type annotation says list[torch.Tensor] but actual elements are 3-tuples.

Nit

  • [xtuner/v1/model/moe/moe.py:332] Typo: "labesl" → "labels"

Verdict

REQUEST_CHANGES

@HAOCHENYE
Collaborator Author

@claude I agree with your comments; please update the fork branch accordingly. Please make sure the modified code passes lint.

@HAOCHENYE
Collaborator Author

@claude I almost agree with your comments, except for the _micro_batch_forward one. I do not plan to support MTP in _micro_batch_forward; please raise a NotImplementedError instead.

Please update the fork branch accordingly, and make sure the modified code passes lint.

Note: THIS IS A STACKED PR; please keep the linear commit history (3 commits).
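The guard requested here could be sketched as follows (the class name is a stand-in, and the argument shapes are assumed from the review's description of the dict-based loss context API):

```python
# Hedged sketch: MTP is only supported in _forward, so the
# micro-batch path fails fast when any MTP loss context is present.
class MoEModelSketch:
    def _micro_batch_forward(self, seq_ctx_list, loss_ctx_dict_list):
        if any(d.get("mtp") is not None for d in loss_ctx_dict_list):
            raise NotImplementedError(
                "MTP is not supported in _micro_batch_forward; "
                "use _forward instead."
            )
        ...  # remainder of the micro-batch forward pass
```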

@HAOCHENYE HAOCHENYE force-pushed the mtp branch 3 times, most recently from e6be73b to 271c089 Compare March 16, 2026 15:57
Change loss context from single object to dict-based API:
- Update loss_cfg.build() to accept data parameter as dict
- Change ModelItem.loss_ctx to dict with loss type keys (e.g. 'lm')
- Update model forward pass to accept loss_ctx_dict parameter
- Update all tests to use new dict-based loss context API


ghstack-source-id: e27a1aa
Pull-Request: InternLM#1569
…s context base class

- Rename CELossContext to LMHeadLossContext for better semantic clarity
- Refactor BaseLossContext to be more abstract by removing LM-specific logic
- Move eager_mode and chunk_mode implementations from base class to LMHeadLossContext
- Make loss_ctx_cls and _loss_kwargs_cls abstract properties in BaseLossConfig
- Remove sp_split() and to() implementations from BaseLossKwargs base class
- Move sp_split() and to() to CELossKwargs subclass
- Update BaseRLLossKwargs to properly inherit and extend sp_split() and to() methods
- Add deprecation alias: CELossContext = LMHeadLossContext for backward compatibility
- Export LMHeadLossContext in __init__.py


ghstack-source-id: 67744a7
Pull-Request: InternLM#1571