57 changes: 57 additions & 0 deletions benchmarks/python/benchmark_inference.py
@@ -48,9 +48,13 @@
from thunder.dynamo.compiler import thunderfx
from layers_for_inference_benchmark import (
GroupedSwiGLU,
NVFP4InferenceSwiGLU,
NVFP4InferenceLinear,
SwiGLU,
Llama4MoE,
NVFP4InferenceGroupedSwiGLU,
nvfuser_f16a_nvfp4weight_scaled_grouped_mm,
nvfuser_f16a_nvfp4weight_scaled_mm,
)
from thunder.tests.distributed.test_moe import GroupedLinearColwiseParallel, GroupedLinearRowwiseParallel
from thunder.transforms.cudagraph import CUDAGraphTransform
@@ -81,6 +85,7 @@ def _register_nvfp4_ops():
"""Register nvfp4 custom operations with Thunder."""
# Register f16a_nvfp4weight_scaled_grouped_mm with nvfuser translator
_nvfp4_grouped_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm)
_nvfp4_scaled_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_mm)

def nvfp4_grouped_mm_translator(
activation,
@@ -120,7 +125,39 @@ def nvfp4_grouped_mm_translator(
)
return out

def nvfp4_scaled_mm_translator(
activation,
fp4_weight,
weight_scaling_factor,
weight_global_scale,
*,
fd,
lc_to_nv_map,
):
from nvfuser_direct import DataType
from thunder.executors.nvfuserex_impl import getnv

nv_act = getnv(activation, fd, lc_to_nv_map)
nv_fp4_w = getnv(fp4_weight, fd, lc_to_nv_map)
nv_sf_w = getnv(weight_scaling_factor, fd, lc_to_nv_map)
nv_alpha = getnv(weight_global_scale, fd, lc_to_nv_map)

quantized_activation, activation_scale = fd.ops.nv_block_quantize(nv_act, nv_alpha, True, 16)
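# nv_block_quantize is expected to quantize the high-precision activation to NVFP4 with
# a block size of 16 (matching the weight's scaling-factor blocks); scaled_mm then
# consumes the two quantized operands plus their scales and emits a bf16 result.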
out, _, _ = fd.ops.scaled_mm(
quantized_activation,
nv_fp4_w,
activation_scale,
nv_sf_w,
nv_alpha,
bias=None,
beta=None,
dtype=DataType.BFloat16,
)
return out


_register_nvfuser_translator(_nvfp4_grouped_mm_symbol, nvfp4_grouped_mm_translator)
_register_nvfuser_translator(_nvfp4_scaled_mm_symbol, nvfp4_scaled_mm_translator)
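# With both translators registered, thunderfx can lower these custom ops into nvFuser
# fusions instead of falling back to their eager (dequantize-based) implementations.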


# The logic is based on https://github.com/pytorch/ao/blob/b34c1037/torchao/quantization/quant_api.py#L230
@@ -181,6 +218,26 @@ def _quantize_llama4(model: nn.Module) -> None:
lambda model, cur_fqn: isinstance(model, GroupedSwiGLU),
)

_replace_with_custom_fn_if_matches_filter_with_name(
model,
NVFP4InferenceSwiGLU.from_swiglu,
lambda model, cur_fqn: isinstance(model, SwiGLU),
)

# Find and return all submodules of the model that are instances of Llama4MoE
def _find_llama4moe_recursive(module):
found = []
for child in module.children():
if isinstance(child, Llama4MoE):
found.append(child)
found.extend(_find_llama4moe_recursive(child))
return found

llama4moe_modules = _find_llama4moe_recursive(model)
assert len(llama4moe_modules) == 1, f"Expected exactly one Llama4MoE module, found {len(llama4moe_modules)}"

# Quantize the gate (router) projection layer
llama4moe_modules[0].gate = NVFP4InferenceLinear.from_linear(llama4moe_modules[0].gate)

@contextmanager
def timer():
154 changes: 154 additions & 0 deletions benchmarks/python/layers_for_inference_benchmark.py
@@ -39,6 +39,9 @@
"NVFP4InferenceGroupedLinear",
"NVFP4InferenceGroupedSwiGLU",
"nvfuser_f16a_nvfp4weight_scaled_grouped_mm",
"nvfuser_f16a_nvfp4weight_scaled_mm",
"NVFP4InferenceLinear",
"NVFP4InferenceSwiGLU",
]


@@ -236,6 +239,60 @@ def nvfuser_f16a_nvfp4weight_scaled_grouped_mm(
)
return grouped_mm(activation, hp_weight.transpose(2, 1), offsets)

@torch.library.custom_op("nvf_cutlass::f16a_nvfp4weight_scaled_mm", mutates_args=())
def nvfuser_f16a_nvfp4weight_scaled_mm(
activation: torch.Tensor,
fp4_weight: torch.Tensor,
weight_scaling_factor: torch.Tensor,
weight_global_scale: torch.Tensor,
) -> torch.Tensor:
# fp4_weight shape: (in_features // 2, out_features)
# Transpose and dequantize to get (out_features, in_features)
hp_weight = dequantize_to_dtype(
fp4_weight.t(),
weight_scaling_factor,
weight_global_scale,
activation.dtype,
fp4_weight.device,
16
)
# hp_weight is now (out_features, in_features) - ready for F.linear
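# Note: this eager path falls back to a dequantized bf16 F.linear and mainly serves as
# a reference implementation; when compiled with thunderfx, the nvFuser translator in
# benchmark_inference.py lowers this op to a block-quantized scaled_mm instead.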
return torch.nn.functional.linear(activation, hp_weight.to(torch.bfloat16))


@torch.library.register_fake("nvf_cutlass::f16a_nvfp4weight_scaled_mm")
def _(
activation: torch.Tensor,
fp4_weight: torch.Tensor,
weight_scaling_factor: torch.Tensor,
weight_global_scale: torch.Tensor,
) -> torch.Tensor:
# fp4_weight shape: (in_features // 2, out_features)
# Validate that activation has at least 1 dimension
if activation.ndim == 0:
raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}")


if (
len(
{
t.device
for t in [
activation,
fp4_weight,
weight_scaling_factor,
weight_global_scale,
]
}
)
!= 1
):
raise ValueError("Expected all inputs to be on the same device.")


# The fake (meta) kernel only reports output metadata: fp4_weight is
# (in_features // 2, out_features), so fp4_weight.t().shape[0] is out_features and the
# result is (num_tokens, out_features) in bf16 on the activation's device.
return torch.empty((activation.shape[0], fp4_weight.t().shape[0]), device=activation.device, dtype=torch.bfloat16)


@torch.library.register_fake("nvf_cutlass::f16a_nvfp4weight_scaled_grouped_mm")
def _(
@@ -524,6 +581,103 @@ def from_grouped_swiglu(grouped_swiglu: GroupedSwiGLU, fqn: str | None = None) -
return NVFP4InferenceGroupedSwiGLU(gate_proj, up_proj, down_proj)


class NVFP4InferenceLinear(nn.Module):
"""NVFP4 Linear layer for inference using nvf_cutlass.nvfp4_scaled_mm."""

def __init__(
self,
fp4_weight: torch.Tensor,
weight_scaling_factor: torch.Tensor,
weight_global_scale: torch.Tensor,
) -> None:
super().__init__()
self.register_buffer("fp4_weight", fp4_weight)
self.register_buffer("weight_scaling_factor", weight_scaling_factor)
self.register_buffer("weight_global_scale", weight_global_scale)


@property
def out_features(self) -> int:
return self.fp4_weight.size(1)

@property
def in_features(self) -> int:
return self.fp4_weight.size(0) * 2

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Forward pass using nvfp4_scaled_mm.

Args:
hidden_states: Input tensor of shape [batch, seq_len, in_features]
"""
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

# Use nvfp4_scaled_mm which handles the full computation
output = torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm(
hidden_states,
self.fp4_weight,
self.weight_scaling_factor,
self.weight_global_scale,
)

return output

@staticmethod
def from_linear(linear: nn.Linear, fqn: str | None = None) -> NVFP4InferenceLinear:
"""Create an NVFP4InferenceLinear from a standard Linear layer.

Args:
linear (nn.Linear): The source Linear layer.
fqn (str or None): Fully qualified name. Currently unused; reserved for future use or compatibility.
"""
weight_fp4, weight_scale, global_scale = quantize_linear_weight_to_nvfp4(linear.weight)
return NVFP4InferenceLinear(weight_fp4.t(), weight_scale, global_scale)
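
# A minimal usage sketch (illustrative only; assumes an NVFP4-capable GPU and bf16 inputs):
#
#     linear = nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
#     nvfp4_linear = NVFP4InferenceLinear.from_linear(linear)
#     x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.bfloat16)
#     out = nvfp4_linear(x)  # shape: [2 * 16, 4096]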


class NVFP4InferenceSwiGLU(nn.Module):
"""NVFP4 SwiGLU for inference using NVFP4InferenceLinear."""

def __init__(
self,
gate_proj: NVFP4InferenceLinear,
up_proj: NVFP4InferenceLinear,
down_proj: NVFP4InferenceLinear,
):
super().__init__()
self.gate_proj = gate_proj
self.up_proj = up_proj
self.down_proj = down_proj

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Forward pass through SwiGLU.

Args:
hidden_states: Input tensor

Returns:
Output tensor after SwiGLU transformation
"""
gate_out = self.gate_proj(hidden_states)
up_out = self.up_proj(hidden_states)

intermediate = torch.nn.functional.silu(gate_out) * up_out

return self.down_proj(intermediate)

@staticmethod
def from_swiglu(swiglu, fqn: str | None = None) -> NVFP4InferenceSwiGLU:
"""Create an NVFP4InferenceSwiGLU from a SwiGLU module.

Args:
swiglu: The source SwiGLU module (should have gate_proj, up_proj, down_proj).
fqn (str or None): Fully qualified name. Currently unused; reserved for future use or compatibility.
"""
gate_proj = NVFP4InferenceLinear.from_linear(swiglu.gate_proj)
up_proj = NVFP4InferenceLinear.from_linear(swiglu.up_proj)
down_proj = NVFP4InferenceLinear.from_linear(swiglu.down_proj)
return NVFP4InferenceSwiGLU(gate_proj, up_proj, down_proj)
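
# `_quantize_llama4` in benchmark_inference.py swaps bf16 SwiGLU modules for this class
# via `from_swiglu`, and quantizes the MoE router's gate with `NVFP4InferenceLinear.from_linear`.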


# Slightly modified version of `thunder.tests.test_networks.Llama4MoE`
# to have the same signature as transformers' Llama4TextMoe -- in this file
# return values include `router_logits`.
6 changes: 5 additions & 1 deletion csrc/runtime/allocations.cpp
@@ -662,7 +662,11 @@ class BackwardTraverseFromAllocToLogical {
}
}

if (areDimsToBeMergedContiguous(tensor_, new_shape)) {
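// An indivisible split rounds the outer extent up (ceilDiv), so merging the dims back
// with view() would not reproduce the logical extent; fall back to the strided path below.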
bool is_divisible = split->in()->extent()->evaluate().as<int64_t>() %
split->factor()->evaluate().as<int64_t>() ==
0;

if (is_divisible && areDimsToBeMergedContiguous(tensor_, new_shape)) {
tensor_ = tensor_.view(new_shape);
} else {
auto [tensor_new_shape, tensor_new_strides] =