[Feature] Add guided decoding support for speculative decoding #4559

Open

windreamer wants to merge 9 commits into InternLM:main from windreamer:mtp-guided

Conversation

@windreamer (Collaborator) commented Apr 28, 2026

Motivation

Fixes #4551

When speculative decoding and guided decoding (JSON schema / regex / grammar) are both enabled, guided constraints are silently ignored — the GuidedDecodingManager is never propagated into the speculative decoding path. This is a silent correctness issue: no error, no warning, just unconstrained output.

Modification

Core change: propagate & apply grammar mask in spec decode

  1. agent.py — After build_spec_agent(), propagate GuidedDecodingManager to both SpecModelAgent and its proposer.

  2. spec_agent.py — Main integration:

    • _async_model_forward: Fork GrammarMatchers for the draft model from the original guided processors; forked matchers are advanced in-place by get_outputs() at each draft step; originals remain untouched.
    • _rejection_sampling:
      • Decoding path: Apply position-serial grammar mask via _guided_spec_logits_process() — forked matchers provide per-position bitmasks for all num_spec_tokens + 1 target logits. After rejection sampling, accept the final output tokens on original matchers to advance their state correctly.
      • Prefill path: Pass guided_decoding_manager to FusedLogitsProcessor (standard path already handles it).
    • _guided_spec_logits_process(): New method that (1) runs non-grammar logits processing (temperature, penalties), (2) applies per-position grammar bitmasks using forked matchers, advancing each fork with the known draft token for that position, (3) returns processed logits for rejection sampling. Target logits are conditioned on the draft tokens, and rejection sampling discards every position after the first rejection, so the draft-token path is the only reachable one and only accepted tokens ever reach the original matchers. (A sketch of this loop follows the list.)
  3. deepseek_mtp.py — Accept guided_processors in get_outputs(). Apply grammar bitmask to draft logits before argmax, then accept_token on each forked matcher to advance its state for the next draft position.

  4. base.py — Add guided_decoding_manager attribute to BaseSpecProposer (set by SpecModelAgent after construction). Add guided_processors parameter to get_outputs() signature.

  5. eagle3.py — Support guided decoding via draft-to-target bitmask translation. Since Eagle3's draft vocabulary differs from the target vocabulary, a target-vocab grammar mask cannot be applied directly. Instead, _translate_bitmask() converts the target-vocab bitmask into a draft-vocab bitmask using the draft_id_to_target_id mapping, then applies it to draft logits. After argmax + token mapping, accept_token advances each forked matcher's state. (A sketch of this translation follows the list.)

  6. attention/__init__.py, configurations/utils.py, graph_runner.py, attention/fa3.py — Fix speculative decoding on non-SM90 CUDA GPUs: extend FA3 capability check from == 9 (SM90 only) to >= 8 (SM80+, Ampere and above) so that speculative decoding can use FA3's multi-token decode path. Add an early check in CUDAGraphRunner.__init__ that raises a clear RuntimeError when speculative decoding is requested but FA3 is unavailable, instead of crashing in the Triton paged attention kernel.
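
The verification-side loop in item 2 can be sketched as follows. This is a minimal single-sequence sketch, not the shipped implementation: it assumes xgrammar's allocate_token_bitmask / fill_next_token_bitmask / apply_token_bitmask_inplace / accept_token calls plus the fork() that the PR's xgrammar>=0.1.33 bump requires, and it omits the batching and the temperature/penalty processing that the real _guided_spec_logits_process runs first.

```python
import torch
import xgrammar as xgr


def guided_mask_target_logits(target_logits: torch.Tensor,
                              draft_tokens: torch.Tensor,
                              matcher: xgr.GrammarMatcher,
                              vocab_size: int) -> torch.Tensor:
    """Position-serial grammar masking for one sequence's
    (num_spec_tokens + 1) target logits. The original matcher stays
    untouched; a fork is advanced with the known draft tokens, the only
    path that can survive rejection sampling."""
    fork = matcher.fork()  # per the PR, fork() needs xgrammar>=0.1.33
    bitmask = xgr.allocate_token_bitmask(1, vocab_size)
    for pos in range(draft_tokens.numel() + 1):
        fork.fill_next_token_bitmask(bitmask)
        xgr.apply_token_bitmask_inplace(target_logits[pos:pos + 1], bitmask)
        if pos < draft_tokens.numel():
            fork.accept_token(int(draft_tokens[pos]))
    # The bonus position needs no advancement: the fork is discarded here.
    # After rejection sampling, only the accepted tokens are replayed on
    # the original matcher via matcher.accept_token(...).
    return target_logits
```

Item 5's draft-to-target translation can likewise be sketched. This assumes xgrammar's 32-bit-packed bitmask layout and a d2t tensor mapping draft ids to target ids; the PR implements this with a vectorized scatter_add_, but the helper below is illustrative, not the shipped _translate_bitmask.

```python
import torch


def translate_bitmask(target_bitmask: torch.Tensor,
                      d2t: torch.Tensor) -> torch.Tensor:
    """Turn a packed target-vocab bitmask (batch, ceil(V_t/32), int32)
    into a packed draft-vocab bitmask, given d2t: draft id -> target id."""
    # Gather the 32-bit word holding each mapped target id, then pull out
    # its allow-bit (arithmetic shift is fine: we mask to the lowest bit).
    words = target_bitmask[:, d2t // 32]                   # (batch, V_d)
    bits = (words >> (d2t % 32).to(torch.int32)) & 1       # 0/1 per draft id

    ids = torch.arange(d2t.numel(), device=d2t.device)
    packed = torch.zeros(target_bitmask.size(0), (d2t.numel() + 31) // 32,
                         dtype=torch.int64, device=d2t.device)
    # Bits land in disjoint positions, so scatter-add acts as a bitwise OR.
    packed.scatter_add_(1, (ids // 32).expand_as(bits),
                        bits.to(torch.int64) << (ids % 32))
    # Truncating back to int32 keeps the low 32 bits, i.e. the packed words.
    return packed.to(torch.int32)
```

After masking, argmax over the masked draft logits picks a grammar-legal draft id, which is mapped back through d2t before accept_token advances the fork, per item 5.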

Helper changes

  • _expand_sampling_inputs / _slice_sampling_inputs: Handle additional SamplingInputs fields (response_formats, session_ctx, etc.) so that guided-decoding–related inputs survive the expand/slice round-trip during rejection sampling.
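
A minimal sketch of the expand side, assuming SamplingInputs is a dataclass and k = num_spec_tokens + 1; this is a hypothetical stand-in for _expand_sampling_inputs, not its actual code. Tensor fields repeat along the batch dim, while plain list fields (response_formats, session_ctx, etc.) must be repeated element-wise or they silently fall out of alignment, which is the bug class this helper change guards against.

```python
import dataclasses

import torch


def expand_sampling_inputs(sampling_inputs, k: int):
    """Repeat every per-sequence field k times so that each verified
    position keeps its own copy during rejection sampling."""
    updates = {}
    for field in dataclasses.fields(sampling_inputs):
        val = getattr(sampling_inputs, field.name)
        if torch.is_tensor(val):
            updates[field.name] = val.repeat_interleave(k, dim=0)
        elif isinstance(val, list):  # e.g. response_formats, session_ctx
            updates[field.name] = [v for v in val for _ in range(k)]
    return dataclasses.replace(sampling_inputs, **updates)
```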

Tests

  • test_guided_spec_decode.py — Unit tests for _expand_sampling_inputs / _slice_sampling_inputs with guided fields, _guided_spec_logits_process bitmask application, and accept_token state advancement.
  • test_guided_spec_integration.py — Integration tests (require xgrammar + GPU).
  • test_mtp_guided_decoding.py — End-to-end pipeline tests (require xgrammar + GPU).

Docs

  • Updated spec_decoding.md (EN & ZH) with guided decoding usage notes.

BC-breaking (Optional)

None. The guided_processors parameter in get_outputs() defaults to None, so existing proposers that don't override it are unaffected.
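
For illustration, the backward-compatible shape of the interface might look like the sketch below; names other than get_outputs and guided_processors are assumptions, not the actual base.py code.

```python
from typing import List, Optional


class BaseSpecProposer:
    # Set by SpecModelAgent after construction; None means unconstrained.
    guided_decoding_manager = None

    def get_outputs(self, inputs, guided_processors: Optional[List] = None):
        # Proposers that predate this change ignore the new keyword and
        # keep producing unconstrained draft tokens, so behavior is unchanged.
        raise NotImplementedError
```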

Checklist

  • Pre-commit / linting tools pass.
  • Unit tests added for core logic (_guided_spec_logits_process, expand/slice with guided fields).
  • No dependency on downstream version changes.
  • Documentation updated.

Copilot AI review requested due to automatic review settings April 28, 2026 06:58
Copilot AI (Contributor) left a comment


Pull request overview

Adds guided decoding (JSON schema / regex / grammar via xgrammar) support to the PyTorch speculative decoding (MTP) path by propagating GuidedDecodingManager into spec decoding and applying grammar bitmasks during both draft proposal and target verification/rejection sampling.

Changes:

  • Propagate GuidedDecodingManager into SpecModelAgent and spec proposers, and apply position-serial grammar masking in spec decode verification.
  • Add draft-side grammar masking support for proposers that share the target vocab (e.g., DeepseekMTP), and a supports_grammar_mask capability flag.
  • Add unit/integration/E2E tests and update EN/ZH docs for guided decoding with speculative decoding.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.

Summary per file:

  • lmdeploy/pytorch/engine/model_agent/agent.py: Propagates guided_decoding_manager into the speculative decoding agent and proposer.
  • lmdeploy/pytorch/spec_decode/spec_agent.py: Implements guided masking in spec decode verification, and expands/slices additional SamplingInputs fields.
  • lmdeploy/pytorch/spec_decode/proposers/base.py: Adds supports_grammar_mask and guided_decoding_manager plumb-through; extends the get_outputs signature.
  • lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py: Applies grammar bitmask to draft logits (when provided) and advances forked matchers.
  • lmdeploy/pytorch/spec_decode/proposers/eagle3.py: Disables draft-side grammar masking via supports_grammar_mask = False.
  • tests/pytorch/spec_decode/test_guided_spec_decode.py: Unit tests for expand/slice behavior and guided-spec decode grammar mechanics.
  • tests/pytorch/spec_decode/test_guided_spec_integration.py: Higher-level integration tests for guided masking + rejection sampling state consistency.
  • tests/test_lmdeploy/test_mtp_guided_decoding.py: GPU integration tests for pipeline + MTP + guided decoding (schema/regex/json_object + streaming).
  • docs/en/advance/spec_decoding.md: Documents guided decoding usage with speculative decoding (EN).
  • docs/zh_cn/advance/spec_decoding.md: Documents guided decoding usage with speculative decoding (ZH).


Comment thread lmdeploy/pytorch/spec_decode/proposers/eagle3.py
Comment thread lmdeploy/pytorch/spec_decode/spec_agent.py Outdated
@windreamer marked this pull request as draft April 28, 2026 07:28
@windreamer marked this pull request as ready for review April 28, 2026 10:40
@windreamer requested a review from Copilot April 28, 2026 10:40
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.



Comment thread tests/pytorch/spec_decode/test_guided_spec_integration.py Outdated
Comment thread lmdeploy/pytorch/spec_decode/spec_agent.py
Comment thread lmdeploy/pytorch/spec_decode/proposers/eagle3.py
Comment thread docs/en/advance/spec_decoding.md Outdated
Comment thread docs/zh_cn/advance/spec_decoding.md Outdated
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 15 out of 15 changed files in this pull request and generated 1 comment.



Comment thread lmdeploy/pytorch/spec_decode/spec_agent.py Outdated
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 15 out of 15 changed files in this pull request and generated 1 comment.



Comment thread lmdeploy/pytorch/spec_decode/proposers/eagle3.py
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 23 out of 23 changed files in this pull request and generated no new comments.



@lvhan028 requested a review from RunningLeon May 6, 2026 02:43
@lvhan028 added the enhancement (New feature or request) label May 6, 2026
Comment thread lmdeploy/pytorch/spec_decode/spec_agent.py Outdated
@RunningLeon requested a review from grimoire May 6, 2026 10:50
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.

Comment thread lmdeploy/pytorch/spec_decode/spec_agent.py Outdated
Comment thread lmdeploy/pytorch/configurations/utils.py
Comment thread lmdeploy/pytorch/backends/cuda/attention/__init__.py
Comment thread docs/en/advance/spec_decoding.md Outdated
windreamer and others added 9 commits May 6, 2026 19:40
- Implement position-serial grammar mask via forked GrammarMatchers
- Propagate guided_decoding_manager from ModelAgent to SpecModelAgent
- Apply grammar mask in DeepseekMTP proposer before token selection
- Advance forked matcher state in Eagle3 proposer (no mask due to vocab mismatch)
- Handle grammar state management after rejection sampling
- Expand/slice sampling_inputs for non-tensor fields (response_formats, session_ctx, etc.)
- Consolidate tests: 7 unit + 2 integration tests, 6 GPU e2e tests
- Add MTP + guided decoding usage docs (en/zh_cn)
Eagle3's draft vocabulary differs from the target vocabulary, so a
target-vocab grammar mask is inapplicable to draft logits.  Add
supports_grammar_mask class attribute to BaseSpecProposer (default
True); Eagle3 overrides to False.  spec_agent now gates the fork on
this flag, and Eagle3.get_outputs() no longer accepts or processes
guided_processors.

Co-authored-by: openhands <openhands@all-hands.dev>
…ation

- Eagle3.get_outputs() now applies grammar mask before argmax and
  accept_token after d2t mapping, matching DeepseekMTP pattern
- Add _translate_bitmask() to convert target-vocab bitmask to
  draft-vocab bitmask via scatter_add_ (vectorized, no loops)
- Remove supports_grammar_mask flag; all proposers now support it
- Fork guided processors unconditionally in spec_agent._async_model_forward
- Move session_to_cleanup handling before get_processors in forward_decode
- Bump xgrammar>=0.1.33 (fork() requirement) in all 5 runtime requirements
- Add comprehensive tests: bitmask translation, Eagle3 get_outputs,
  fork independence, multi-step draft loop
…ng on CUDA

- Update use_fa3 capability check from == 9 (SM90 only) to >= 8 (SM80+)
  in attention/__init__.py and configurations/utils.py
- Add FA3 requirement check in graph_runner.py: speculative decoding
  on CUDA now raises a clear error if FA3 is unavailable, instead of
  crashing deep in the Triton paged attention kernel
- Update docstrings/error messages to reflect SM80+ (Ampere) support
…tion

FA3 mha_fwd derives seqlen_k from page_table.shape[1] * page_size
for paged KV without cu_seqlens_k. get_scheduler_metadata must
receive the same value to produce a consistent scheduler layout.

Previously max_seqlen_k was incorrectly set to
step_context.max_kv_seqlen (runtime KV length) in op_backend.py, and
decode_query_len or attn_metadata.max_kv_seqlen in cudagraph.py.
These values differ from what FA3 computes internally, causing
scheduler_metadata to be misaligned with the actual kernel behavior.

- op_backend.py: use block_offsets.size(1) * block_size
- cudagraph.py: use graph_meta.num_blocks * graph_meta.block_size

Both now match FA3 internal: page_table.size(1) * page_size.

Co-authored-by: openhands <openhands@all-hands.dev>
In _guided_spec_logits_process, forked matchers were advanced using
argmax of the masked target logits.  In the non-greedy rejection
sampling path, the actually accepted token can differ from argmax,
causing subsequent grammar masks (especially the bonus-position mask)
to be computed from an incorrect grammar state.

Fix: advance forks using the known draft tokens for positions
0..num_spec_tokens-1.  Target logits are conditioned on draft tokens,
and rejection sampling discards positions after the first rejection,
so the draft-token path is the only reachable one.  The bonus position
needs no advancement — the fork is discarded after the loop.
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 23 out of 23 changed files in this pull request and generated 3 comments.

Comment thread lmdeploy/pytorch/spec_decode/spec_agent.py
Comment thread lmdeploy/pytorch/spec_decode/spec_agent.py
Comment thread lmdeploy/pytorch/spec_decode/proposers/eagle3.py

Labels

enhancement New feature or request


Development

Successfully merging this pull request may close these issues.

[Bug] Speculative decoding does not support guided decoding

4 participants