
Conversation

@grimoire (Collaborator) opened the pull request:

fill kv add scale fmt

@grimoire changed the title Fix dsv32 Moe Reduce kernel (Dec 23, 2025)
@lvhan028 (Collaborator) commented:

Please resolve the conflict.

Copilot AI (Contributor) left a comment:

Pull request overview

This PR introduces a new MOE reduce kernel for better performance and adds scale-format configuration to the FP8 KV-cache path. The main changes refactor the weight multiplication in MOE operations out of the fused kernels and into a dedicated reduction kernel, and add scale-rounding support for FP8 quantization.

  • Implements a new moe_reduce kernel that handles the weighted reduction of expert outputs separately from the main MOE computation (a reference sketch follows this list)
  • Adds a scale_fmt parameter to the FP8 quantization functions to support different scale formats (specifically 'ue8m0', i.e. power-of-two rounded scales)
  • Removes the inline weight multiplication from the MOE kernels, delegating it to the new reduction kernel
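
As context for the first bullet above, here is a minimal PyTorch sketch of the reduction semantics the new kernel appears to implement. The shapes, names, and the fp32_acc handling are assumptions for illustration; the actual Triton kernel in fused_moe.py is not reproduced here.

```python
import torch


def moe_reduce_reference(hidden_states: torch.Tensor,
                         topk_weights: torch.Tensor,
                         fp32_acc: bool = False) -> torch.Tensor:
    """Weighted reduction over the top-k expert outputs (reference semantics).

    Assumed shapes (not confirmed by the diff shown here):
      hidden_states: (num_tokens, topk, hidden_dim)  -- per-expert outputs
      topk_weights:  (num_tokens, topk)              -- routing weights
    Returns a (num_tokens, hidden_dim) tensor in the input dtype.
    """
    out_dtype = hidden_states.dtype
    if fp32_acc:
        # Accumulate in float32 to reduce rounding error, then cast back.
        hidden_states = hidden_states.float()
        topk_weights = topk_weights.float()
    out = (hidden_states * topk_weights.unsqueeze(-1)).sum(dim=1)
    return out.to(out_dtype)
```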

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| lmdeploy/pytorch/kernels/cuda/fused_moe.py | Adds the new moe_reduce kernel and a helper function; removes the weights/enable_weights parameters from fused_moe_kernel and fused_moe_kernel_launcher; replaces the .sum() reduction with a moe_reduce() call |
| lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py | Removes the weights/enable_weights parameters from the kernel and launcher; imports and uses moe_reduce instead of .sum() |
| lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py | Same changes as w8a8_fused_moe.py: removes the weights parameters and uses moe_reduce |
| lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py | Adds scale-rounding helpers (fast_log2_ceil, fast_pow2, fast_round_scale); adds a ROUND_SCALE parameter to quantization; adds scale_fmt support (see the sketch after this table) |
| lmdeploy/pytorch/backends/cuda/nsa.py | Adds scale_fmt configuration and passes it to the quant_fp8 and fill_kv_cache_blocked_fp8 calls |
| lmdeploy/pytorch/backends/cuda/blockedf8_modules.py | Removes the incorrect scale_fmt argument from the blocked_gemm_fp8 call (the function does not accept this parameter) |
| lmdeploy/pytorch/backends/cuda/attention.py | Adds scale_fmt='ue8m0' to the fill_kv_cache_blocked_fp8 call |
| tests/pytorch/kernel/test_fused_moe.py | Removes the weights and enable_weights fixtures; updates the test ground truth to not apply weights (the kernel no longer does this) |
| tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py | Same test updates as test_fused_moe.py |
| tests/pytorch/kernel/test_fill_kv_cache.py | Adds a scale_fmt fixture parametrized with [None, 'ue8m0']; threads scale_fmt through the gt fixture and the test calls |
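
The 'ue8m0' scale format stores only an unsigned 8-bit exponent, so each quantization scale must be a power of two. Below is a hedged PyTorch sketch of what the scale-rounding helpers listed for fill_kv_cache.py plausibly compute; the constant, function name, and exact rounding direction are assumptions, not taken from the diff.

```python
import torch

FP8_E4M3_MAX = 448.0  # assumed quantization range; the kernel may use a different constant


def round_scale_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    """Round a per-block FP8 scale up to the nearest power of two.

    Rounding up keeps |x / scale| within the FP8 range. This mirrors what
    fast_log2_ceil / fast_pow2 / fast_round_scale presumably do inside the
    Triton kernel, expressed here in plain PyTorch.
    """
    return torch.exp2(torch.ceil(torch.log2(scale)))


# Example: derive rounded block scales from per-block absolute maxima.
amax = torch.tensor([3.7, 120.0])
scales = round_scale_ue8m0(amax / FP8_E4M3_MAX)  # power-of-two scales
```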


@lvhan028 requested a review from CUHKSZzxy on December 25, 2025 at 12:23.


```python
def moe_reduce(hidden_states: torch.Tensor, topk_weights: torch.Tensor, fp32_acc: bool = False) -> torch.Tensor:
    """Moe reduce."""
```
A Collaborator commented on this definition:
Should we set fp32_acc=True to align with the previous behavior?

@grimoire (Collaborator, author) replied:
There is no such "before"; the existing MOE kernels already use different accumulation precisions.
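
For context on the exchange above (an illustration, not code from the PR): the flag plausibly controls whether partial sums are kept in float32 or in the activation dtype. Even with identical bf16 products, the two accumulation paths can diverge slightly:

```python
import torch

torch.manual_seed(0)
topk, hidden = 8, 4096
expert_out = torch.rand(topk, hidden).to(torch.bfloat16)  # one token's expert outputs
weights = torch.softmax(torch.rand(topk), dim=0).to(torch.bfloat16)

acc_lo = torch.zeros(hidden, dtype=torch.bfloat16)  # every partial sum rounds back to bf16
acc_hi = torch.zeros(hidden, dtype=torch.float32)   # same products, fp32 accumulator
for k in range(topk):
    prod = expert_out[k] * weights[k]  # bf16 product used by both paths
    acc_lo = acc_lo + prod
    acc_hi = acc_hi + prod.float()

print((acc_lo.float() - acc_hi).abs().max())  # typically small but nonzero
```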

@RunningLeon (Collaborator) left a comment:

LGTM

@lvhan028 merged commit b0befc3 into InternLM:main on Dec 31, 2025 (5 checks passed).
