[Feature] Add guided decoding support for speculative decoding #4559
windreamer wants to merge 9 commits into InternLM:main
Conversation
Pull request overview
Adds guided decoding (JSON schema / regex / grammar via xgrammar) support to the PyTorch speculative decoding (MTP) path by propagating GuidedDecodingManager into spec decoding and applying grammar bitmasks during both draft proposal and target verification/rejection sampling.
Changes:
- Propagate `GuidedDecodingManager` into `SpecModelAgent` and the spec proposers, and apply position-serial grammar masking in spec decode verification.
- Add draft-side grammar masking support for proposers that share the target vocab (e.g., `DeepseekMTP`), and a `supports_grammar_mask` capability flag.
- Add unit/integration/E2E tests and update EN/ZH docs for guided decoding with speculative decoding.
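To make the masking step concrete, here is a minimal sketch of applying an xgrammar-style packed token bitmask to one row of logits. It assumes the common packing convention (bit `j` of int32 word `i` covers token `i*32 + j`, with a set bit meaning the token is allowed); the function name is illustrative, and real deployments would use xgrammar's fused apply kernel rather than this unpacked version.

```python
import torch

def apply_packed_bitmask(logits: torch.Tensor, bitmask: torch.Tensor) -> torch.Tensor:
    """Apply a packed token bitmask (int32 words, 32 tokens per word) to logits.

    Assumed convention: bit j of word i covers token i*32 + j; set bit = allowed.
    Disallowed tokens are masked to -inf so they can never be sampled.
    """
    vocab = logits.size(-1)
    words = bitmask.to(torch.int64)          # widen to avoid int32 sign issues
    idx = torch.arange(vocab)
    # Extract the per-token "allowed" bit from the packed words.
    allowed = (words[idx // 32] >> (idx % 32)) & 1
    return logits.masked_fill(allowed == 0, float('-inf'))
```

A masked row can then go straight into argmax or sampling; only grammar-legal tokens remain finite.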
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `lmdeploy/pytorch/engine/model_agent/agent.py` | Propagates `guided_decoding_manager` into the speculative decoding agent and proposer. |
| `lmdeploy/pytorch/spec_decode/spec_agent.py` | Implements guided masking in spec decode verification, and expands/slices additional `SamplingInputs` fields. |
| `lmdeploy/pytorch/spec_decode/proposers/base.py` | Adds `supports_grammar_mask` and `guided_decoding_manager` plumb-through; extends the `get_outputs` signature. |
| `lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py` | Applies the grammar bitmask to draft logits (when provided) and advances forked matchers. |
| `lmdeploy/pytorch/spec_decode/proposers/eagle3.py` | Disables draft-side grammar masking via `supports_grammar_mask = False`. |
| `tests/pytorch/spec_decode/test_guided_spec_decode.py` | Unit tests for expand/slice behavior and guided spec decode grammar mechanics. |
| `tests/pytorch/spec_decode/test_guided_spec_integration.py` | Higher-level integration tests for guided masking + rejection sampling state consistency. |
| `tests/test_lmdeploy/test_mtp_guided_decoding.py` | GPU integration tests for pipeline + MTP + guided decoding (schema/regex/json_object + streaming). |
| `docs/en/advance/spec_decoding.md` | Documents guided decoding usage with speculative decoding (EN). |
| `docs/zh_cn/advance/spec_decoding.md` | Documents guided decoding usage with speculative decoding (ZH). |
- Implement position-serial grammar mask via forked GrammarMatchers
- Propagate guided_decoding_manager from ModelAgent to SpecModelAgent
- Apply grammar mask in DeepseekMTP proposer before token selection
- Advance forked matcher state in Eagle3 proposer (no mask due to vocab mismatch)
- Handle grammar state management after rejection sampling
- Expand/slice sampling_inputs for non-tensor fields (response_formats, session_ctx, etc.)
- Consolidate tests: 7 unit + 2 integration tests, 6 GPU e2e tests
- Add MTP + guided decoding usage docs (en/zh_cn)
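The non-tensor expand/slice handling mentioned above can be sketched as follows. This is an illustrative stand-in for lmdeploy's `_expand_sampling_inputs`/`_slice_sampling_inputs`, using a plain dict in place of `SamplingInputs`; the field names are examples only.

```python
import torch

def expand_sampling_inputs(inputs: dict, repeats: int) -> dict:
    """Expand each batch entry `repeats` times (one copy per spec position).

    Tensor fields are repeat-interleaved; list fields such as
    response_formats are expanded element-wise so they survive slicing.
    """
    out = {}
    for key, val in inputs.items():
        if isinstance(val, torch.Tensor):
            out[key] = val.repeat_interleave(repeats, dim=0)
        elif isinstance(val, (list, tuple)):
            out[key] = [v for v in val for _ in range(repeats)]
        else:
            out[key] = val
    return out

def slice_sampling_inputs(inputs: dict, index: torch.Tensor) -> dict:
    """Keep only the rows at `index` (e.g., positions that survived rejection)."""
    idx = index.tolist()
    out = {}
    for key, val in inputs.items():
        if isinstance(val, torch.Tensor):
            out[key] = val[index]
        elif isinstance(val, (list, tuple)):
            out[key] = [val[i] for i in idx]
        else:
            out[key] = val
    return out
```

The round-trip property being tested is that guided fields line up with their tensor siblings after expand-then-slice.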
Eagle3's draft vocabulary differs from the target vocabulary, so a target-vocab grammar mask is inapplicable to draft logits. Add a supports_grammar_mask class attribute to BaseSpecProposer (default True); Eagle3 overrides it to False. spec_agent now gates the fork on this flag, and Eagle3.get_outputs() no longer accepts or processes guided_processors. Co-authored-by: openhands <openhands@all-hands.dev>
…ation
- Eagle3.get_outputs() now applies the grammar mask before argmax and accept_token after the d2t mapping, matching the DeepseekMTP pattern
- Add _translate_bitmask() to convert the target-vocab bitmask to a draft-vocab bitmask via scatter_add_ (vectorized, no loops)
- Remove the supports_grammar_mask flag; all proposers now support masking
- Fork guided processors unconditionally in spec_agent._async_model_forward
- Move session_to_cleanup handling before get_processors in forward_decode
- Bump xgrammar>=0.1.33 (fork() requirement) in all 5 runtime requirements files
- Add tests: bitmask translation, Eagle3 get_outputs, fork independence, multi-step draft loop
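A minimal sketch of the draft-vocab translation this commit describes, operating on an unpacked boolean mask (real xgrammar bitmasks are packed 32 tokens per int32, and the commit uses a vectorized `scatter_add_`; a gather is the simplest equivalent for illustration). `draft_to_target[d]` is assumed to hold the target-vocab id of draft token `d`.

```python
import torch

def translate_bitmask(target_mask: torch.Tensor,
                      draft_to_target: torch.Tensor) -> torch.Tensor:
    """Map a target-vocab legality mask onto the draft vocabulary.

    A draft token is legal iff its target-vocab counterpart is legal.
    """
    return target_mask.gather(0, draft_to_target)
```

The resulting draft-vocab mask can then be applied to the draft logits before argmax, exactly as a same-vocab mask would be.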
…ng on CUDA
- Update the use_fa3 capability check from == 9 (SM90 only) to >= 8 (SM80+) in attention/__init__.py and configurations/utils.py
- Add an FA3 requirement check in graph_runner.py: speculative decoding on CUDA now raises a clear error if FA3 is unavailable, instead of crashing deep in the Triton paged attention kernel
- Update docstrings/error messages to reflect SM80+ (Ampere) support
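A sketch of the capability gate and fail-fast check this commit describes; the helper names are hypothetical, and the real code also verifies that the FA3 kernels actually import.

```python
import torch

def fa3_supported() -> bool:
    """SM80+ (Ampere and newer) capability gate, mirroring the relaxed
    check (previously major == 9, now major >= 8). Sketch only."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

def check_spec_decode_backend(use_spec_decode: bool, fa3_available: bool) -> None:
    """Raise a clear error up front instead of crashing later inside the
    Triton paged-attention kernel."""
    if use_spec_decode and not fa3_available:
        raise RuntimeError('Speculative decoding on CUDA requires '
                           'FlashAttention 3 (SM80+ GPU with FA3 installed).')
```

Failing in `__init__` surfaces a readable configuration error at engine construction instead of a kernel-level crash mid-request.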
…tion
FA3 mha_fwd derives seqlen_k from page_table.shape[1] * page_size for paged KV without cu_seqlens_k. get_scheduler_metadata must receive the same value to produce a consistent scheduler layout. Previously max_seqlen_k was incorrectly set to step_context.max_kv_seqlen (the runtime KV length) in op_backend.py, and to decode_query_len or attn_metadata.max_kv_seqlen in cudagraph.py. These values differ from what FA3 computes internally, causing scheduler_metadata to be misaligned with the actual kernel behavior.
- op_backend.py: use block_offsets.size(1) * block_size
- cudagraph.py: use graph_meta.num_blocks * graph_meta.block_size
Both now match the FA3-internal value: page_table.size(1) * page_size. Co-authored-by: openhands <openhands@all-hands.dev>
… device-to-host sync
In _guided_spec_logits_process, forked matchers were advanced using the argmax of the masked target logits. In the non-greedy rejection sampling path, the actually accepted token can differ from the argmax, causing subsequent grammar masks (especially the bonus-position mask) to be computed from an incorrect grammar state. Fix: advance the forks using the known draft tokens for positions 0..num_spec_tokens-1. Target logits are conditioned on the draft tokens, and rejection sampling discards all positions after the first rejection, so the draft-token path is the only reachable one. The bonus position needs no advancement; its fork is discarded after the loop.
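The fixed advancement order can be sketched with a toy stand-in for xgrammar's GrammarMatcher (only `fork`/`fill_bitmask`/`accept_token` are assumed, with illustrative semantics: one legal token per state).

```python
import torch

class ToyMatcher:
    """Stand-in grammar matcher: in state s, only token s % vocab is legal."""
    def __init__(self):
        self.state = 0

    def fork(self):
        f = ToyMatcher()
        f.state = self.state
        return f

    def fill_bitmask(self, vocab_size: int) -> torch.Tensor:
        mask = torch.zeros(vocab_size, dtype=torch.bool)
        mask[self.state % vocab_size] = True
        return mask

    def accept_token(self, tok: int):
        self.state = tok + 1  # toy transition

def guided_spec_logits_process(matcher, target_logits, draft_tokens):
    """Mask [num_spec_tokens + 1, vocab] target logits position-serially.

    The fork advances with the *known draft token* at each speculative
    position (not the argmax of the masked logits); the bonus (last)
    position needs no advancement since the fork is discarded afterwards.
    The original matcher's state is never touched here.
    """
    fork = matcher.fork()
    vocab = target_logits.size(-1)
    out = target_logits.clone()
    for pos in range(target_logits.size(0)):
        out[pos].masked_fill_(~fork.fill_bitmask(vocab), float('-inf'))
        if pos < len(draft_tokens):          # all but the bonus position
            fork.accept_token(int(draft_tokens[pos]))
    return out
```

After rejection sampling, only the accepted tokens are replayed onto the original matcher, which is why the fork can be thrown away.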
Motivation
Fixes #4551
When speculative decoding and guided decoding (JSON schema / regex / grammar) are both enabled, guided constraints are silently ignored because the GuidedDecodingManager is never propagated into the speculative decoding path. This is a silent correctness issue: no error, no warning, just unconstrained output.
Modification
Core change: propagate & apply the grammar mask in spec decode
- agent.py: After build_spec_agent(), propagate GuidedDecodingManager to both SpecModelAgent and its proposer.
- spec_agent.py: Main integration.
  - _async_model_forward: Fork GrammarMatchers for the draft model from the original guided processors; the forked matchers are advanced in place by get_outputs() at each draft step, while the originals remain untouched.
  - _rejection_sampling: _guided_spec_logits_process() lets the forked matchers provide per-position bitmasks for all num_spec_tokens + 1 target logits. After rejection sampling, the final output tokens are accepted on the original matchers to advance their state correctly. Pass guided_decoding_manager to FusedLogitsProcessor (the standard path already handles it).
  - _guided_spec_logits_process(): New method that (1) runs non-grammar logits processing (temperature, penalties), (2) applies per-position grammar bitmasks using the forked matchers, advancing each fork with argmax as a greedy approximation, and (3) returns the processed logits for rejection sampling. The actually sampled token may differ from the argmax, but rejection sampling ensures only accepted tokens reach the original matchers.
- deepseek_mtp.py: Accept guided_processors in get_outputs(). Apply the grammar bitmask to the draft logits before argmax, then call accept_token on each forked matcher to advance its state for the next draft position.
- base.py: Add a guided_decoding_manager attribute to BaseSpecProposer (set by SpecModelAgent after construction). Add a guided_processors parameter to the get_outputs() signature.
- eagle3.py: Support guided decoding via draft-to-target bitmask translation. Since Eagle3's draft vocabulary differs from the target vocabulary, a target-vocab grammar mask cannot be applied directly. Instead, _translate_bitmask() converts the target-vocab bitmask into a draft-vocab bitmask using the draft_id_to_target_id mapping, then applies it to the draft logits. After argmax + token mapping, accept_token advances each forked matcher's state.
- attention/__init__.py, configurations/utils.py, graph_runner.py, attention/fa3.py: Fix speculative decoding on non-SM90 CUDA GPUs by extending the FA3 capability check from == 9 (SM90 only) to >= 8 (SM80+, Ampere and above), so that speculative decoding can use FA3's multi-token decode path. Add an early check in CUDAGraphRunner.__init__ that raises a clear RuntimeError when speculative decoding is requested but FA3 is unavailable, instead of crashing in the Triton paged attention kernel.
Helper changes
- _expand_sampling_inputs / _slice_sampling_inputs: Handle additional SamplingInputs fields (response_formats, session_ctx, etc.) so that guided-decoding-related inputs survive the expand/slice round-trip during rejection sampling.
Tests
- test_guided_spec_decode.py: Unit tests for _expand_sampling_inputs / _slice_sampling_inputs with guided fields, _guided_spec_logits_process bitmask application, and accept_token state advancement.
- test_guided_spec_integration.py: Integration tests (require xgrammar + GPU).
- test_mtp_guided_decoding.py: End-to-end pipeline tests (require xgrammar + GPU).
Docs
- spec_decoding.md (EN & ZH): updated with guided decoding usage notes.
BC-breaking (Optional)
None. The guided_processors parameter in get_outputs() defaults to None, so existing proposers that don't override it are unaffected.
Checklist
_guided_spec_logits_process, expand/slice with guided fields).