69 changes: 66 additions & 3 deletions tensorrt_llm/_torch/models/modeling_nemotron_h.py
@@ -218,13 +218,28 @@ def __init__(
# when loading NVFP4/W4A8_NVFP4_FP8 quantized expert weights.
# Look up the per-expert quant config from quant_config_dict and use it for create_moe.
moe_model_config = model_config
self._nvfp4_dequant_moe = False
if model_config.quant_config_dict is not None:
experts_prefix = f"model.layers.{layer_idx}.mixer.experts."
for key, cfg in model_config.quant_config_dict.items():
if key.startswith(experts_prefix):
moe_model_config = replace(model_config, quant_config=cfg)
# On GPUs without FP4 tensor cores (sm < 100), dequant NVFP4
# expert weights to BF16 at load time and run MoE unquantized.
if cfg.quant_mode.has_nvfp4() and get_sm_version() < 100:
self._nvfp4_dequant_moe = True
else:
moe_model_config = replace(model_config,
quant_config=cfg)
break

# When the global quant config (not per-layer dict) indicates NVFP4 on a GPU
# without FP4 tensor cores (sm < 100), also enable the dequant path.
if (not self._nvfp4_dequant_moe
and model_config.quant_config is not None
and model_config.quant_config.quant_mode.has_nvfp4()
and get_sm_version() < 100):
self._nvfp4_dequant_moe = True

# Setup MoE experts.
self.experts = create_moe(
routing_method=self.gate.routing_method,
@@ -241,6 +256,9 @@ def __init__(
activation_type=self.activation_type,
)

if self._nvfp4_dequant_moe:
self._wrap_moe_load_weights_for_dequant()

if reduce_output:
self.allreduce = AllReduce(
mapping=model_config.mapping,
@@ -295,6 +313,49 @@ def __init__(
for key in [EventType.Main, EventType.MoeShared]
}

def _wrap_moe_load_weights_for_dequant(self):
"""Wrap self.experts.load_weights to dequant NVFP4 expert weights to BF16."""
import functools
import types

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils

original_load_weights = self.experts.load_weights.__func__
target_dtype = torch.bfloat16

@functools.wraps(original_load_weights)
def _dequant_load_weights(self_moe, weights, **kwargs):
if not weights:
return original_load_weights(self_moe, weights, **kwargs)
dequanted = []
for w_dict in weights:
new_dict = dict(w_dict)
for key in list(new_dict.keys()):
if not key.endswith('.weight'):
continue
scale_key = key.replace('.weight', '.weight_scale')
scale2_key = key.replace('.weight', '.weight_scale_2')
tensor = new_dict[key]
if (tensor[...].dtype == torch.uint8
and scale_key in new_dict
and scale2_key in new_dict):
packed = tensor[...].cuda()
ws = new_dict[scale_key][...].cuda()
ws2 = new_dict[scale2_key][...].cuda()
N, K = packed.shape[0], packed.shape[1] * 2
new_dict[key] = fp4_utils.dequantize_nvfp4(
packed, ws, ws2, N, K, target_dtype)
del new_dict[scale_key]
del new_dict[scale2_key]
for suffix in ('.input_scale', '.input_quantizer'):
k = key.replace('.weight', suffix)
new_dict.pop(k, None)
dequanted.append(new_dict)
return original_load_weights(self_moe, dequanted, **kwargs)

self.experts.load_weights = types.MethodType(_dequant_load_weights,
self.experts)

def forward(
self,
hidden_states: torch.Tensor
@@ -403,10 +464,12 @@ def __init__(

quant_mode = (model_config.quant_config.quant_mode
if model_config.quant_config is not None else None)
self.is_nvfp4 = quant_mode is not None and quant_mode.has_nvfp4()
_has_fp4_hw = get_sm_version() >= 100
self.is_nvfp4 = (quant_mode is not None and quant_mode.has_nvfp4()
and _has_fp4_hw)
# For MIXED_PRECISION models, the global quant_mode is QuantMode(0). Check per-layer
# quant_config_dict to see if this specific layer is NVFP4-quantized.
if not self.is_nvfp4 and model_config.quant_config_dict is not None:
if not self.is_nvfp4 and _has_fp4_hw and model_config.quant_config_dict is not None:
layer_prefix = f"model.layers.{layer_idx}."
for key, cfg in model_config.quant_config_dict.items():
if key.startswith(layer_prefix) and cfg.quant_mode.has_nvfp4():
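The key mechanism in _wrap_moe_load_weights_for_dequant above is per-instance method rebinding: the wrapper is attached to this one experts object via types.MethodType rather than by patching the class, which is why the original function is fetched through __func__. A minimal, self-contained sketch of that pattern (ToyMoE and its weight dicts are illustrative stand-ins, not part of this PR):

import functools
import types


class ToyMoE:
    """Stand-in for the fused MoE experts module; only load_weights matters here."""

    def load_weights(self, weights):
        print(f"loading {len(weights)} expert weight dict(s)")


moe = ToyMoE()
# __func__ retrieves the plain function behind the bound method so the wrapper
# can later be rebound as a method of this single instance.
original_load_weights = moe.load_weights.__func__


@functools.wraps(original_load_weights)
def dequant_load_weights(self, weights):
    # A real wrapper would rewrite `weights` here (e.g. dequantize NVFP4
    # tensors to BF16) before delegating to the original implementation.
    rewritten = [dict(w) for w in weights]
    return original_load_weights(self, rewritten)


# Rebinding through types.MethodType affects only `moe`; other ToyMoE
# instances keep the unwrapped class method.
moe.load_weights = types.MethodType(dequant_load_weights, moe)
moe.load_weights([{"w1.weight": None}])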
46 changes: 46 additions & 0 deletions tensorrt_llm/_torch/modules/linear.py
@@ -1861,6 +1861,50 @@ def post_load_weights(self, module: Linear):
module.rebuild_tensor_metadata)


class NVFP4W4A16LinearMethod(NVFP4LinearMethod):
"""W4A16 fallback for NVFP4 weights on GPUs without FP4 tensor cores.

Inherits NVFP4LinearMethod's weight storage and loading (FP4 packed weights
+ per-block FP8 scales + global scale). Overrides apply() to dequantize
weights to BF16 and use standard matmul instead of nvfp4_gemm.
Activations remain in BF16 throughout -- no FP4 activation quantization.
"""

supports_nccl_symmetric_memory_window_output: ClassVar[bool] = False

def _dequantize_weight(self, module: Linear) -> torch.Tensor:
"""Dequantize FP4 packed weights to module.dtype (typically BF16)."""
weight = module.weight.data
K = weight.shape[1] * 2
N = module.out_features
sf_vec_size = module.scaling_vector_size

# Un-interleave per-block scales from CUTLASS swizzled layout to flat 2D
scale_rows = fp4_utils.pad_up(N, 128)
ws_2d = unswizzle_sf(module.weight_scale.data, scale_rows, K,
sf_vec_size)
ws_flat = ws_2d[:weight.shape[0], :K // sf_vec_size]

return fp4_utils.dequantize_nvfp4(weight, ws_flat,
module.weight_scale_2.data, N, K,
module.dtype, sf_vec_size)

def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
w_bf16 = self._dequantize_weight(module)
output = F.linear(input, w_bf16, bias)
return output

def apply_linear_allreduce(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor], tp_rank: int,
tp_group: List[int]):
output = self.apply(module, input, bias)
return output

def post_load_weights(self, module: Linear):
"""Skip FP4 GEMM alignment padding -- not needed for BF16 matmul."""


class W4A8NVFP4FP8LinearMethod(LinearMethodBase):

def create_weights(self, module: Linear, in_features: int,
@@ -2716,6 +2760,8 @@ def get_quant_method(quant_config: Optional[QuantConfig] = None):
if quant_config.layer_quant_mode.has_fp8_block_scales():
return FP8BlockScalesLinearMethod()
if quant_config.layer_quant_mode.has_nvfp4():
if get_sm_version() < 100:
return NVFP4W4A16LinearMethod()
if quant_config.quant_algo == QuantAlgo.NVFP4_ARC:
return NVFP4ARCLinearMethod()
else:
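The W4A16 fallback above amounts to: keep the NVFP4-packed weight and its scales resident, dequantize to BF16 inside apply(), and run an ordinary matmul with BF16 activations. A rough standalone sketch of that forward path, reusing the dequantize_nvfp4 helper this PR adds to fp4_utils (unlike the module's _dequantize_weight, this sketch assumes the per-block scales are already flat rather than in the swizzled CUTLASS layout):

import torch
import torch.nn.functional as F

from tensorrt_llm.quantization.utils.fp4_utils import dequantize_nvfp4


def w4a16_linear(x, packed_w, weight_scale, weight_scale_2, bias=None):
    # x:              [*, K] BF16 activations (never quantized)
    # packed_w:       [N, K/2] uint8, two E2M1 nibbles per byte
    # weight_scale:   flat per-block scales, [N, K/16]
    # weight_scale_2: per-tensor FP32 global scale
    n, k = packed_w.shape[0], packed_w.shape[1] * 2
    w_bf16 = dequantize_nvfp4(packed_w, weight_scale, weight_scale_2, n, k,
                              torch.bfloat16)
    # Plain BF16 GEMM; no FP4 tensor cores required.
    return F.linear(x, w_bf16, bias)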
54 changes: 53 additions & 1 deletion tensorrt_llm/quantization/utils/fp4_utils.py
@@ -16,7 +16,13 @@
float4_sf_dtype = SF_DTYPE
fp4_buckets = FP4_BUCKETS

__all__ = ['float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets']
E2M1_VALUES = torch.tensor(
[0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6])

__all__ = [
'float4_e2m1x2', 'float4_sf_dtype', 'pad_up', 'fp4_buckets', 'E2M1_VALUES',
'dequantize_nvfp4'
]


def pad_up(x: int, y: int) -> int:
@@ -204,3 +210,49 @@ def shuffle_matrix_sf_a(

# 128x4
return torch.ops.trtllm.block_scale_interleave(w_shuffled)


def dequantize_nvfp4(
packed_weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
out_features: int,
in_features: int,
target_dtype: torch.dtype = torch.bfloat16,
block_size: int = 16,
) -> torch.Tensor:
"""Dequantize NVFP4 packed weights to a floating-point dtype.

Args:
packed_weight: [N, K/2] uint8 with two E2M1 nibbles per byte.
weight_scale: Per-block scales (FP8 or FP32), flat or shaped.
weight_scale_2: Per-tensor global scale (FP32 scalar).
out_features: Unpadded number of output rows (N).
in_features: Unpadded number of input columns (K).
target_dtype: Output dtype (default bfloat16).
block_size: Number of elements per scale block (default 16).

Returns:
Dequantized weight tensor of shape [out_features, in_features].
"""
packed_uint8 = (packed_weight.view(torch.uint8)
if packed_weight.dtype != torch.uint8 else packed_weight)
N_stored = packed_uint8.shape[0]
K = packed_uint8.shape[1] * 2
device = packed_uint8.device

low = (packed_uint8 & 0x0F).long()
high = ((packed_uint8 >> 4) & 0x0F).long()
idx = torch.empty(N_stored, K, dtype=torch.long, device=device)
idx[:, 0::2] = low
idx[:, 1::2] = high
vals = E2M1_VALUES.to(device)[idx]

num_blocks = N_stored * (K // block_size)
ws = weight_scale.to(torch.float32).reshape(-1)[:num_blocks]
s2 = weight_scale_2.to(torch.float32)
block_scales = (ws * s2).view(N_stored, K // block_size, 1)
vals = vals.view(N_stored, K // block_size, block_size) * block_scales
vals = vals.view(N_stored, K)

return vals[:out_features, :in_features].to(target_dtype)
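
A small round-trip check for dequantize_nvfp4 above: quantize a random BF16 matrix into packed E2M1 nibbles with per-block scales using a simple reference quantizer (nearest-value rounding and FP32 block scales, an illustrative assumption rather than the production NVFP4 quantizer), then dequantize and inspect the error:

import torch

from tensorrt_llm.quantization.utils.fp4_utils import E2M1_VALUES, dequantize_nvfp4


def quantize_ref(w, block_size=16):
    n, k = w.shape
    blocks = w.float().view(n, k // block_size, block_size)
    # Per-block scale chosen so the block max maps to the E2M1 max value (6.0).
    scales = blocks.abs().amax(dim=-1, keepdim=True) / 6.0
    # Round each scaled element to the nearest representable E2M1 value.
    dists = (blocks / scales).reshape(-1, 1).sub(E2M1_VALUES.view(1, -1)).abs()
    idx = dists.argmin(dim=-1).view(n, k)
    # Pack two 4-bit codes per byte: even columns in the low nibble, odd in the
    # high nibble, matching the unpacking order in dequantize_nvfp4.
    packed = (idx[:, 0::2] | (idx[:, 1::2] << 4)).to(torch.uint8)
    return packed, scales.view(n, k // block_size), torch.tensor(1.0)


w = torch.randn(4, 64, dtype=torch.bfloat16)
packed, ws, ws2 = quantize_ref(w)
w_dq = dequantize_nvfp4(packed, ws, ws2, out_features=4, in_features=64)
print("max abs error:", (w_dq.float() - w.float()).abs().max().item())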