[Feature] Add Multi-Token Prediction (MTP) module implementation #1572
HAOCHENYE wants to merge 3 commits into InternLM:main
Conversation
@claude review
xtuner/v1/model/moe/qwen3vl_text.py (Outdated)

    def _forward(  # type: ignore[override]
        self,
        seq_ctx: SequenceContext,  # todo(@yehaochen): support intra layer micro-batch
        loss_ctx: CELossContext | None,
        return_router_logits: bool = False,
Claude: Critical: When deepstack_visual_embeds is not None, this code falls through to the original _forward body which still references self.balancing_loss and self.z_loss. These attributes were removed from MoE.__init__ in commit 1 of this PR, so this will raise AttributeError at runtime for any VL model with deepstack visual embeds.
Additionally, loss_ctx here is typed as CELossContext | None, but after commit 1 the parent's _forward now expects MoELossContextDict | None (a dict). When deepstack_visual_embeds is not None, line 186 passes loss_ctx (a dict) directly to self.lm_head(hidden_states, loss_ctx) which expects LMHeadLossContext, not a dict. This will also fail at runtime.
This entire _forward override needs to be updated to work with the new dict-based loss context API and the loss-context-based auxiliary losses (instead of self.balancing_loss / self.z_loss).
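A minimal sketch of the dispatch the new API requires, assuming the dict keys the LM-head context under "lm" (key names are taken from the diffs elsewhere in this PR; the helper name is hypothetical): the lm_head must receive only the "lm" entry, never the whole dict.

```python
def select_lm_loss_ctx(loss_ctx_dict):
    """Pull the LM-head loss context out of the dict-based API.

    `loss_ctx_dict` maps loss-type keys (e.g. "lm", "mtp") to contexts;
    passing the whole dict to lm_head is exactly the bug described above.
    """
    if loss_ctx_dict is None:
        return None
    return loss_ctx_dict.get("lm")
```

With this, the override would call `self.lm_head(hidden_states, select_lm_loss_ctx(loss_ctx))` instead of forwarding the dict directly.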
xtuner/v1/model/moe/moe.py (Outdated)

        embed_tokens_fn=self.embed_tokens,
        position_embeddings=position_embeddings_list[micro_idx],
        seq_ctx=seq_ctx,
    )

    # Compute MTP losses for each depth
    mtp_losses = []
    for mtp_hidden, mtp_ctx in zip(mtp_outputs, mtp_ctx_list):
Claude: Critical: In _forward, MTP labels are properly shifted for each depth via roll_packed_tensor(shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1). But here in _micro_batch_forward, no such shifting is applied — the MTP loss contexts are used as-is. This means all MTP depths would compute loss against the same (unshifted) labels, which is incorrect.
The same roll_packed_tensor logic needs to be applied here for each MTP depth before computing the loss.
    Example:
        For packed sequences [1,2,3] and [4,5,6] with shifts=-1:
        Original input_ids:    [1, 2, 3, 4, 5, 6]
        Rolled input_ids:      [2, 3, 0, 5, 6, 0]
        Original position_ids: [0, 1, 2, 0, 1, 2]
        Rolled position_ids:   [1, 2, 0, 1, 2, 0]
    """
Claude: Critical: roll_sequence_context mutates the input seq_ctx in-place by reassigning seq_ctx.input_ids and seq_ctx.inputs_embeds. In _forward, this is safe because mtp_seq_ctx is a clone. But in _micro_batch_forward, the seq_ctx from the micro-batch loop is passed directly into MTPBlock.forward → roll_sequence_context without cloning, which corrupts the original training data.
Either:
- this function should create and return a new `SequenceContext` instead of mutating the input, or
- `_micro_batch_forward` needs to clone `seq_ctx` before passing it to MTP, similar to what `_forward` does.
Also, the docstring says "Both input_ids and position_ids are rolled" but position_ids is never actually rolled in the implementation.
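One way to implement the first option, as a sketch: build a new context with `dataclasses.replace` instead of reassigning fields on the caller's object. The `SequenceContext` below is a simplified stand-in for the real class, and the inline roll is a hypothetical last-dim-only version of the rolling helper.

```python
from dataclasses import dataclass, replace
import torch

@dataclass
class SequenceContext:  # simplified stand-in for xtuner's real SequenceContext
    input_ids: torch.Tensor
    cu_seq_lens_k: torch.Tensor

def roll_sequence_context(seq_ctx: SequenceContext, shifts: int) -> SequenceContext:
    """Return a NEW context with rolled input_ids; the input is left untouched."""
    rolled = torch.zeros_like(seq_ctx.input_ids)
    bounds = seq_ctx.cu_seq_lens_k.tolist()
    for start, end in zip(bounds[:-1], bounds[1:]):
        if shifts < 0:
            rolled[..., start:end + shifts] = seq_ctx.input_ids[..., start - shifts:end]
        else:
            rolled[..., start + shifts:end] = seq_ctx.input_ids[..., start:end - shifts]
    return replace(seq_ctx, input_ids=rolled)  # original seq_ctx is not mutated

ctx = SequenceContext(torch.tensor([1, 2, 3, 4, 5, 6]), torch.tensor([0, 3, 6]))
new_ctx = roll_sequence_context(ctx, -1)
assert ctx.input_ids.tolist() == [1, 2, 3, 4, 5, 6]      # caller's data intact
assert new_ctx.input_ids.tolist() == [2, 3, 0, 5, 6, 0]
```

With this shape, `_micro_batch_forward` can pass its `seq_ctx` to MTP without a defensive clone, since the function never writes through the reference it receives.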
        seq_ctx = data["seq_ctx"].to(DEVICE)
        if self.sp_mesh.size() > 1:
            seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh)
        seq_ctx_list.append(seq_ctx)
        loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=self.sp_mesh)
        loss_ctx_list.append(loss_ctx)

        # 2. Compute cu_seq_lens_list (for calibration)
        # 3. Call model's interface to build and calibrate all loss_ctx (done in one shot)
        loss_ctx_dict_list = self._engine.model.build_loss_ctx_batch(data_batch, sp_mesh=self.sp_mesh)
Claude: Warning: loss_cfg is guaranteed to be non-None at this point (it's set to CELossConfig() three lines above if it was None). So the if loss_cfg is not None check always passes and the deprecation warning always fires, even if the user correctly sets lm_loss_cfg directly in the model config.
Consider only warning when loss_cfg was explicitly provided by the user (i.e., when the original loss_cfg argument was not None before the default assignment).
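A sketch of the suggested gating, with a stand-in config class (`CELossConfig` here is a dummy, not the real one): remember whether the user passed `loss_cfg` before the default assignment overwrites the distinction.

```python
import warnings

class CELossConfig:  # dummy stand-in for the real config class
    pass

def build_loss_ctx(loss_cfg=None):
    user_provided = loss_cfg is not None  # capture BEFORE applying the default
    if loss_cfg is None:
        loss_cfg = CELossConfig()
    if user_provided:
        # Warn only when the deprecated argument was explicitly given,
        # not when the default kicked in.
        warnings.warn("`loss_cfg` is deprecated; set `lm_loss_cfg` in the model "
                      "config instead", DeprecationWarning, stacklevel=2)
    return loss_cfg
```

Checking `loss_cfg is not None` after the default assignment (as the current code does) is always true, which is exactly why the warning fires unconditionally.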
            shape [batch, seq_len, hidden_size].
        position_embeddings (tuple[torch.Tensor, torch.Tensor]): Rotary position
            embeddings (cos, sin).
Claude: Warning: The return type annotation says tuple[torch.Tensor, torch.Tensor, torch.Tensor] but the docstring (lines 96-97) says:
Returns:
torch.Tensor: Output hidden states
The docstring should document the full 3-tuple return (hidden_states, router_weights, router_results).
    # MTP needs to shift labesl multiple times. Since rebuild the `shifted_labels` in data_batch
    mtp_loss_ctx_list = self._build_loss_ctx(self.config.lm_loss_cfg, _data_batch, sp_mesh)
    if mtp_loss_ctx_list is not None:
        loss_ctx_cls = mtp_loss_ctx_list[0].__class__
Claude: Nit: Typo: "labesl" → "labels"
Suggested change:
    loss_ctx_cls = mtp_loss_ctx_list[0].__class__
    # MTP needs to shift labels multiple times. Since rebuild the `shifted_labels` in data_batch
xtuner/v1/model/moe/moe.py
Outdated
    # Check if any micro-batch has MTP loss contexts
    mtp_loss_ctx_lists = [loss_ctx_dict.get("mtp") for loss_ctx_dict in loss_ctx_list]
    if any(ctx is not None for ctx in mtp_loss_ctx_lists):
        # Verify all micro-batches have input_ids
Claude: Warning: MTPBlock.forward returns list[torch.Tensor] per its type annotation, but MTPLayer.forward actually returns tuple[torch.Tensor, torch.Tensor, torch.Tensor] (hidden_states, router_weights, router_results). So mtp_outputs is actually a list[tuple[...]], not list[torch.Tensor].
This is correctly handled downstream in _forward (line 741: hidden_states, router_results, router_weights = mtp_hidden), but the type annotations in MTPBlock.forward are misleading and should be updated to match.
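A sketch of the annotation fix, using a type alias so the per-layer 3-tuple is documented in one place (class body and tensor shapes are illustrative, not the real implementation):

```python
import torch

# (hidden_states, router_weights, router_results) as returned by each MTP layer
MTPLayerOutput = tuple[torch.Tensor, torch.Tensor, torch.Tensor]

class MTPBlock:
    def forward(self, depth: int) -> list[MTPLayerOutput]:
        """Return one (hidden, router_weights, router_results) tuple per MTP depth."""
        h = torch.zeros(2, 4)
        return [(h, torch.ones(2), torch.arange(2)) for _ in range(depth)]

outs = MTPBlock().forward(depth=2)
hidden, router_weights, router_results = outs[0]  # unpacks per depth, as in _forward
```

Annotating the return as `list[MTPLayerOutput]` instead of `list[torch.Tensor]` makes the downstream tuple-unpacking type-check correctly.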
Claude: ## Summary
This PR adds Multi-Token Prediction (MTP) support and refactors the loss context API from a single-object to a dict-based design, enabling multiple loss types per model (LM, balancing, z-loss, MTP). It also renames CELossContext to LMHeadLossContext.

## Issues
Critical, Warning, and Nit issues are detailed in the inline comments above.

## Verdict
REQUEST_CHANGES
@claude Agreed with your comment; please update the fork branch as you say. Please make sure the modified code can pass the lint.
@claude Almost agree with your comment, except for one point. Please update the fork branch as you say, and make sure the modified code can pass the lint. Note: THIS IS A STACKED PR, please keep the linear commit history (3 commits).
Force-pushed from e6be73b to 271c089.
Commit 1: Change loss context from single object to dict-based API
- Update loss_cfg.build() to accept data parameter as dict
- Change ModelItem.loss_ctx to dict with loss type keys (e.g. 'lm')
- Update model forward pass to accept loss_ctx_dict parameter
- Update all tests to use new dict-based loss context API
ghstack-source-id: e27a1aa
Pull-Request: InternLM#1569

Commit 2:
ghstack-source-id: 2d84ad6
Pull-Request: InternLM#1570

Commit 3: …s context base class
- Rename CELossContext to LMHeadLossContext for better semantic clarity
- Refactor BaseLossContext to be more abstract by removing LM-specific logic
- Move eager_mode and chunk_mode implementations from base class to LMHeadLossContext
- Make loss_ctx_cls and _loss_kwargs_cls abstract properties in BaseLossConfig
- Remove sp_split() and to() implementations from BaseLossKwargs base class
- Move sp_split() and to() to CELossKwargs subclass
- Update BaseRLLossKwargs to properly inherit and extend sp_split() and to() methods
- Add deprecation alias: CELossContext = LMHeadLossContext for backward compatibility
- Export LMHeadLossContext in __init__.py
ghstack-source-id: 67744a7
Pull-Request: InternLM#1571
No description provided.