Conversation
…plify parameter handling. Update GatedDeltaNetLayer to utilize the new function signature for improved clarity and performance.
Failed to load the 27B model, run with --tp 2.
Pull request overview
This PR targets higher TurboMind inference throughput for Qwen3.5 models by fixing linear-attention/MoE configuration details, introducing new persistent/batched CUDA kernels for GatedDeltaNet, and refactoring state/cache management to support the updated execution/scheduling flow.
Changes:
- Add GatedDeltaNet batched/persistent kernels (conv1d+SiLU, recurrent rule v2/v3, chunked prefill) and update call sites to use Tensor/Buffer refs.
- Move GatedDeltaNet persistent state from per-request storage to sequence-managed pooled state slots; add cache/state invalidation guards.
- Update Qwen3.5 export/model metadata and KV-cache layer indexing to account for mixed layer types.
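One of the new kernels fuses a depthwise causal conv1d with SiLU activation, and the PR adds a CPU reference checker for it. A minimal NumPy sketch of such a reference (shapes, layout, and names are illustrative assumptions, not the PR's actual API) could look like:

```python
import numpy as np

def conv1d_silu_ref(x, w):
    """CPU reference for a depthwise causal conv1d followed by SiLU.

    x: (seq_len, conv_dim) input activations
    w: (conv_dim, d_conv) per-channel filter taps, oldest tap first
    Illustrative only; the actual kernel layout may differ.
    """
    seq_len, conv_dim = x.shape
    d_conv = w.shape[1]
    # Left-pad with zeros so each output position sees only past inputs (causal).
    xp = np.vstack([np.zeros((d_conv - 1, conv_dim)), x])
    y = np.empty_like(x)
    for t in range(seq_len):
        window = xp[t : t + d_conv]          # (d_conv, conv_dim)
        y[t] = np.sum(window * w.T, axis=0)  # depthwise: one filter per channel
    # SiLU(y) = y * sigmoid(y) = y / (1 + exp(-y))
    return y / (1.0 + np.exp(-y))
```

A GPU kernel's output can then be compared element-wise against this reference, which is the role `bench_conv1d_silu.cc` plays for the CUDA implementation.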
Reviewed changes
Copilot reviewed 23 out of 23 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/turbomind/turbomind.cc | Sets linear attention state dtype and blocks prefix-caching when linear attention is present. |
| src/turbomind/models/llama/unified_attention_layer.h | Adds cache_layer_ids_ for remapping layer IDs used by KV cache logic. |
| src/turbomind/models/llama/unified_attention_layer.cc | Builds layer-id remap and uses it in attention decode/prefill params. |
| src/turbomind/models/llama/moe_ffn_layer.cc | Adjusts routing-path selection logic for MoE gating. |
| src/turbomind/models/llama/llama_params.h | Adds linear_state_dtype and helper HasLinearAttention. |
| src/turbomind/models/llama/gated_delta_net_kernels.h | Refactors kernel APIs to Tensor/Buffer reference-based interfaces and adds new batched launchers. |
| src/turbomind/models/llama/gated_delta_net_kernels.cu | Implements new v2/v3 recurrent kernels, chunked prefill kernel, persistent conv1d+SiLU, and refactors helper kernels. |
| src/turbomind/models/llama/bench_gated_delta_net.cc | Adds benchmark/correctness comparison utility for Gated Delta Rule kernels. |
| src/turbomind/models/llama/bench_conv1d_silu.cc | Adds benchmark plus CPU reference correctness checker for conv1d+SiLU kernel. |
| src/turbomind/models/llama/SequenceManager.h | Adds sequence-owned linear attention state fields and pooled-slot bookkeeping. |
| src/turbomind/models/llama/SequenceManager.cc | Implements pooled slot allocation, cache/state invalidation, and adjusts cache-layer accounting for linear layers. |
| src/turbomind/models/llama/GatedDeltaNetWeight.h | Updates conv1d weight layout comment. |
| src/turbomind/models/llama/GatedDeltaNetWeight.cc | Builds fused projection weight and transposes conv1d weights to kernel-preferred layout. |
| src/turbomind/models/llama/GatedDeltaNetLayer.h | Extends per-phase data to include offsets/state ptr arrays and adds dual-stream execution resources. |
| src/turbomind/models/llama/GatedDeltaNetLayer.cc | Switches to pooled sequence states and launches new batched/persistent kernels with mixed decode/prefill scheduling. |
| src/turbomind/models/CMakeLists.txt | Adds CUDA compile flags and registers new benchmark executables under BUILD_TEST. |
| src/turbomind/kernels/gemm/test/testbed_v3.h | Updates LlamaLinear construction usage in tests. |
| src/turbomind/kernels/gemm/test/test_utils.cu | Extends FastCompare dispatch/instantiations to support float. |
| src/turbomind/kernels/attention/test_attention.cu | Adds is_share_kv() to satisfy block layout interface expectations. |
| src/turbomind/kernels/attention/CMakeLists.txt | Fixes test target linkage to depend on models. |
| src/turbomind/engine/request.h | Removes per-request linear attention state fields from RequestCache. |
| src/turbomind/engine/engine.cc | Wires SequenceManager ctor changes, adds stateless guard for linear attention, and integrates pooled state slot acquisition/invalidation. |
| lmdeploy/turbomind/deploy/source_model/qwen.py | Adjusts exported Qwen3.5 MoE routing metadata and exports linear attention parameters. |
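Because linear-attention (GatedDeltaNet) layers keep recurrent state instead of paged KV blocks, the full-attention layers need compact KV-cache indices. A hedged sketch of what a `cache_layer_ids` remap could look like (the actual construction in `unified_attention_layer.cc` may differ; names and the `-1` sentinel are assumptions):

```python
def build_cache_layer_ids(layer_types):
    """Map each model layer index to a compact KV-cache layer slot.

    Full-attention layers get consecutive cache indices; linear-attention
    layers get -1 because they hold recurrent state, not KV blocks.
    Illustrative only.
    """
    cache_ids, next_slot = [], 0
    for kind in layer_types:
        if kind == "full_attention":
            cache_ids.append(next_slot)
            next_slot += 1
        else:
            cache_ids.append(-1)
    return cache_ids
```

With mixed layer types, the KV cache is then sized by the number of full-attention layers only, rather than the total layer count.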
```diff
@@ -367,7 +367,7 @@ def model_info(self):
     info['inter_size'] = shared_expert_size
     info['moe_shared_gate'] = True
     # Qwen3.5 uses sigmoid MoE routing (not softmax)
```
The inline comment says Qwen3.5 uses sigmoid MoE routing, but the code sets info['scoring_func'] = 'softmax'. Please either update the comment to match the implementation, or switch the value back to 'sigmoid' if that is the intended router behavior (and ensure the TurboMind MoE gate path supports it).
Suggested change:
```diff
-# Qwen3.5 uses sigmoid MoE routing (not softmax)
+# Qwen3.5 uses softmax MoE routing
```
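For context on why the comment/value mismatch matters: softmax normalizes scores across all experts, while sigmoid scores each expert independently (with the selected top-k typically renormalized afterward), so the two produce different routing weights. A toy sketch, not TurboMind's gate implementation:

```python
import numpy as np

def route(logits, top_k, scoring="softmax"):
    """Toy top-k MoE router; illustrative, not the TurboMind MoE gate path."""
    if scoring == "softmax":
        e = np.exp(logits - logits.max())
        scores = e / e.sum()                     # normalized across all experts
    elif scoring == "sigmoid":
        scores = 1.0 / (1.0 + np.exp(-logits))   # per-expert, independent
    else:
        raise ValueError(scoring)
    experts = np.argsort(scores)[::-1][:top_k]   # pick top-k experts
    weights = scores[experts]
    weights = weights / weights.sum()            # renormalize over selected experts
    return experts, weights
```

Either scoring function can be valid; the point of the review comment is that the exported metadata, the code comment, and the engine's gate path must all agree.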
```diff
@@ -448,7 +448,7 @@ def model_info(self):
     info['inter_size'] = cfg.get('shared_expert_intermediate_size', 0)
     info['moe_shared_gate'] = True
     # Qwen3.5 uses sigmoid MoE routing (not softmax)
```
The inline comment says Qwen3.5 uses sigmoid MoE routing, but the code sets info['scoring_func'] = 'softmax'. Please either update the comment to match the implementation, or switch the value back to 'sigmoid' if that is the intended router behavior (and ensure the TurboMind MoE gate path supports it).
Suggested change:
```diff
-# Qwen3.5 uses sigmoid MoE routing (not softmax)
+# Qwen3.5 MoE routing uses softmax scoring
```
```cpp
// Gated DeltaNet linear attention persistent states (e.g. Qwen3.5-MoE).
// Allocated on first request, preserved across requests for the same session,
// and freed automatically when the sequence is erased from the SequenceManager.
// conv_states: (num_linear_layers, conv_dim, d_conv) — per-channel rolling conv history
```
The comment describing conv_states shape doesn't match the actual allocation in SequenceManager (pooled_conv_states_ is sized as [max_batch_size, num_linear_layers, d_conv, conv_dim], so per-sequence it is [num_linear_layers, d_conv, conv_dim]). Please update the comment to reflect the correct dimension order to avoid misuse by future callers.
Suggested change:
```diff
-// conv_states: (num_linear_layers, conv_dim, d_conv) — per-channel rolling conv history
+// conv_states: (num_linear_layers, d_conv, conv_dim) — per-channel rolling conv history
```
Fixed in c83d2d7.
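The pooled layout under discussion can be sketched as follows — a minimal mock of slot bookkeeping where the pool shape follows the review comment and everything else (class and method names, zero-fill invalidation) is an assumption:

```python
import numpy as np

class PooledLinearStates:
    """Sketch of pooled GatedDeltaNet conv-state slots, sized once for max_batch_size.

    The per-sequence view is pool[slot] with shape
    (num_linear_layers, d_conv, conv_dim), matching the corrected comment.
    Illustrative only; the real SequenceManager bookkeeping differs.
    """
    def __init__(self, max_batch_size, num_linear_layers, d_conv, conv_dim):
        self.conv_states = np.zeros(
            (max_batch_size, num_linear_layers, d_conv, conv_dim), np.float32)
        self.free_slots = list(range(max_batch_size))

    def acquire(self):
        slot = self.free_slots.pop()
        self.conv_states[slot] = 0.0  # invalidate stale state from a prior sequence
        return slot

    def release(self, slot):
        self.free_slots.append(slot)
```

Pooling the slots this way means states survive across requests of the same session, but a slot handed to a new sequence must be invalidated first, which is what the PR's cache/state invalidation guards address.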
Cannot load Qwen3.5-27B-AWQ on a single V100-SXM2-32G card; it always runs into GPU memory overflow, even with
Try reducing --max-batch-size; currently, linear states for the full max batch size are allocated at once. --log-level INFO will print the memory usage of the linear states and the KV cache.
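Since those states are preallocated for the whole batch, their footprint scales linearly with --max-batch-size. A back-of-the-envelope estimate (all shape parameters below are hypothetical placeholders, not Qwen3.5's actual values):

```python
def linear_state_bytes(max_batch, n_linear_layers, conv_dim, d_conv,
                       n_heads, head_k_dim, head_v_dim, dtype_bytes=4):
    """Rough size of preallocated linear-attention state; illustrative only."""
    conv = max_batch * n_linear_layers * conv_dim * d_conv                     # conv history
    recurrent = max_batch * n_linear_layers * n_heads * head_k_dim * head_v_dim  # delta-rule state
    return (conv + recurrent) * dtype_bytes

# Hypothetical shapes: halving max_batch halves the preallocated footprint.
full = linear_state_bytes(256, 24, 4096, 4, 16, 128, 128)
half = linear_state_bytes(128, 24, 4096, 4, 16, 128, 128)
```

This is why lowering --max-batch-size directly frees GPU memory on a 32 GB card.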
Fixed; 35B-A3B, 122B-A10B-AWQ, and 27B now run successfully on V100. Performance is also much improved compared with the current master.
You may try the following command:
Confirmed that it can run, but there's no speed improvement compared to the current main branch. Could this be an issue specific to the Windows platform? Log
Try the MoE model; in my testing, the dense model showed little improvement.
Indeed, the MoE model achieves a significant improvement.
This PR improves TurboMind inference performance for Qwen3.5 models with recurrent/linear attention layers (GatedDeltaNet).

Bug Fixes

Kernel Optimizations
- Refactor `invokeRMSNormGated` to use Tensor references

Scheduling & State Management
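As background on what the v2/v3 recurrent kernels compute at decode time, one common formulation of the gated delta rule's per-token update is sketched below in NumPy. This is a hedged illustration of the general technique; the PR's kernels may use a different but mathematically equivalent form, and all names and gate conventions here are assumptions:

```python
import numpy as np

def gated_delta_step(S, q, k, v, alpha, beta):
    """One decode step of a gated delta rule (illustrative formulation).

    S:     (d_v, d_k) recurrent state mapping keys to values
    alpha: decay gate in (0, 1]
    beta:  write strength in (0, 1]
    Update: S <- alpha * (S - beta * (S @ k) k^T) + beta * v k^T
    Output: o = S @ q
    """
    S = alpha * (S - beta * np.outer(S @ k, k)) + beta * np.outer(v, k)
    return S, S @ q
```

Because each token only touches the fixed-size state `S` rather than a growing KV cache, decode cost per token is constant in sequence length — which is what makes a persistent, batched kernel for this step attractive.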