
[Bug] TRTLLMGenFusedMoE + CUTLASS MoE both fail on SM120 (RTX PRO 6000) with FP4 Qwen3-Next MoE in v1.3.0rc4 #11932

@bhaktatejas922

Description

System Info

  • TensorRT-LLM version: 1.3.0rc4 (Docker: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc4)
  • GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition (SM120, 96GB GDDR7) x4
  • Driver: 570.211.01
  • OS: Ubuntu 24.04 (GCP g4-standard-192)
  • Model: Qwen3-Next 80B MoE (~3B active), NVFP4 quantized

Problem

Neither the TRTLLM nor CUTLASS MoE backend works on SM120 GPUs with FP4 MoE models. This model runs fine on B200 (SM100).

Error 1: moe_config.backend: TRTLLM

NotImplementedError: TRTLLMGenFusedMoE does not support SM120 and above.
[TRT-LLM] [E] Failed to initialize executor on rank 0: TRTLLMGenFusedMoE does not support SM120 and above.

Error 2: moe_config.backend: CUTLASS

[TRT-LLM] [W] [Autotuner] Failed when profiling runner=MoERunner, tactic=6
Error: [TensorRT-LLM][ERROR] Assertion failed: Failed to initialize cutlass TMA WS grouped gemm.
Error: Error Internal (cutlass_kernel_file_gemm_grouped_sm120_M128_BS_group2.generated.cu:39)

Total failed profiling tactics: 16 for custom_op=trtllm::fused_moe::gemm2

The CUTLASS backend loads all 1,412 weights (~46 s) and runs the autotuner, but the executor worker then crashes with RuntimeError: Executor worker returned error.
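
For reference, the compute capability each backend sees can be confirmed with a few lines of standard PyTorch (nothing TRT-LLM-specific): SM120 (RTX PRO 6000 Blackwell) reports (12, 0), while B200 reports (10, 0), which matches the "SM120 and above" check in Error 1.

# Print the compute capability of each visible GPU.
import torch

for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    print(f"GPU {i}: {torch.cuda.get_device_name(i)} -> SM{major}{minor}")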

Serving Config

tensor_parallel_size: 1
moe_expert_parallel_size: 1
max_batch_size: 4
max_num_tokens: 32768
max_seq_len: 32768
trust_remote_code: true
moe_config:
  backend: CUTLASS  # also tried TRTLLM
kv_cache_config:
  free_gpu_memory_fraction: 0.85
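
If it helps triage, here is roughly the same configuration expressed through the Python LLM API. This is a sketch only: the MoeConfig and KvCacheConfig class names and fields are my assumption based on the YAML keys above, so adjust as needed.

# Approximate Python-API equivalent of serving.yaml (class/field names assumed from the YAML keys).
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig

llm = LLM(
    model="/workspace/model",                 # NVFP4 Qwen3-Next checkpoint
    tensor_parallel_size=1,
    moe_expert_parallel_size=1,
    max_batch_size=4,
    max_num_tokens=32768,
    max_seq_len=32768,
    trust_remote_code=True,
    moe_config=MoeConfig(backend="CUTLASS"),  # "TRTLLM" fails the same way (Error 1)
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.85),
)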

Questions

  1. Is SM120 FP4 MoE expected to work in 1.3.0rc4, or is this a known regression?
  2. If not yet supported, which upcoming release will include SM120 FP4 MoE support?
  3. Are there any workarounds (e.g., a specific quantization format, env flags, or config changes)? A sketch for checking what quantization layout the checkpoint declares is included below.
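
For question 3, this is how I would dump what the checkpoint itself declares about its quantization, in case the exact NVFP4 layout matters. The file names are assumptions (standard Hugging Face config.json plus a ModelOpt-style hf_quant_config.json); adjust if the checkpoint differs.

# Print the quantization block(s) the checkpoint declares (file names assumed, see note above).
import json
import pathlib

ckpt = pathlib.Path("/workspace/model")
for name in ("config.json", "hf_quant_config.json"):
    path = ckpt / name
    if not path.exists():
        continue
    cfg = json.loads(path.read_text())
    print(f"--- {name} ---")
    # config.json usually nests it under "quantization_config"; hf_quant_config.json is the block itself.
    print(json.dumps(cfg.get("quantization_config", cfg), indent=2))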

Steps to Reproduce

# On a g4-standard-192 (4x RTX PRO 6000 SM120)
docker run --gpus '"device=0"' --shm-size=16g \
  -v /path/to/nvfp4_checkpoint:/workspace/model:ro \
  nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc4 \
  trtllm-serve /workspace/model --host 0.0.0.0 --port 8000 \
  --extra_llm_api_options serving.yaml

Here serving.yaml contains the serving config shown above.
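
When the server does come up (e.g., on B200), I verify it with a plain OpenAI-compatible request like the sketch below; on SM120 the crash happens during executor initialization, so this request is never answered. The served model name is assumed to be the mounted path; adjust if trtllm-serve registers it under a different name.

# Smoke test against the OpenAI-compatible endpoint exposed by trtllm-serve (model name assumed).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
resp = client.chat.completions.create(
    model="/workspace/model",
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=16,
)
print(resp.choices[0].message.content)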
