51 changes: 51 additions & 0 deletions examples/deepseek/int4_kernel.py
@@ -0,0 +1,51 @@
import torch
import triton
import triton.language as tl


@triton.jit
def int4_dequant_kernel(
x_ptr, # pointer to int32 input [M, N]
s_ptr, # pointer to bf16 scale [M, 8*N//BLOCK_SIZE]
    y_ptr, # pointer to bf16 output [M, 8*N]
M: tl.constexpr,
N: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
m = tl.program_id(axis=0)
n_block = tl.program_id(axis=1)

    # Load the int32 values and unpack them into int4.
NUM_INT32_PER_BLOCK = BLOCK_SIZE // 8
int32_vals = tl.load(
x_ptr + m * N + n_block * NUM_INT32_PER_BLOCK + tl.arange(0, BLOCK_SIZE) // 8
)

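    # Field i of each int32 occupies bits [4*i, 4*i + 4); mask it out and shift the
    # unsigned nibble [0, 15] into the signed range [-8, 7].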
offset = (tl.arange(0, BLOCK_SIZE) % 8) * 4
vals = ((int32_vals >> offset) & 0xF) - 8

    # Load the per-block scale: each scale covers BLOCK_SIZE contiguous output elements.
scale = tl.load(s_ptr + m * 8 * N // BLOCK_SIZE + n_block)

vals = vals.to(tl.float32) * scale.to(tl.float32)
tl.store(y_ptr + m * N * 8 + n_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE), vals)


def int4_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 32) -> torch.Tensor:
"""
Dequantizes a packed int4 tensor to bf16.

Args:
x: int32 tensor of shape [M, N] with packed int4.
s: bf16 tensor of shape [M, 8 * N // BLOCK_SIZE].
block_size: number of output columns per block.

    Returns: dequantized tensor of shape [M, 8 * N], allocated in torch's default dtype (bf16 in the intended usage).
"""
m, n = x.shape
y = torch.empty((m, 8 * n), dtype=torch.get_default_dtype(), device=x.device)

grid = (m, 8 * n // block_size)
int4_dequant_kernel[grid](x, s, y, m, n, BLOCK_SIZE=block_size)
return y
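For readers less familiar with Triton, here is a plain-PyTorch sketch of the same unpacking (a hypothetical reference helper, not part of this PR) that mirrors the kernel's nibble layout and per-block scaling:

```python
import torch


def int4_dequant_reference(x: torch.Tensor, s: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    """Reference dequantization mirroring int4_dequant_kernel (illustration only)."""
    m, n = x.shape
    # Each int32 packs eight 4-bit fields; field i occupies bits [4*i, 4*i + 4).
    shifts = torch.arange(8, device=x.device, dtype=torch.int32) * 4
    nibbles = (x.unsqueeze(-1) >> shifts) & 0xF                # [M, N, 8], values in [0, 15]
    vals = (nibbles - 8).reshape(m, 8 * n).to(torch.float32)   # shift to signed [-8, 7]
    # One scale per block_size contiguous output columns.
    scales = s.to(torch.float32).repeat_interleave(block_size, dim=1)
    return (vals * scales).to(s.dtype)
```

The kernel computes the same thing with one program per (row, BLOCK_SIZE-column block), loading each packed int32 once per output element.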
17 changes: 17 additions & 0 deletions examples/deepseek/test.py
@@ -0,0 +1,17 @@
import pdb

import safetensors
import torch
from int4_kernel import int4_dequant

tensors = safetensors.safe_open("model-00001-of-00527.safetensors", framework="pt", device="cuda")

bf16 = tensors.get_tensor("model.layers.1.mlp.experts.0.down_proj.weight")
int32 = tensors.get_tensor("model.layers.1.mlp.experts.0.down_proj.weight_packed")
ws = tensors.get_tensor("model.layers.1.mlp.experts.0.down_proj.weight_scale")
torch.set_default_dtype(torch.bfloat16)
bf16_2 = int4_dequant(int32, ws, block_size=32)


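# Compare the Triton dequantization against the reference bf16 tensor; drop into pdb on mismatch.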
if not torch.allclose(bf16_2, bf16, rtol=1e-4, atol=1e-4):
pdb.set_trace()
10 changes: 8 additions & 2 deletions examples/llm_ptq/example_utils.py
@@ -327,6 +327,12 @@ def get_model(
device_map=device_map,
**model_kwargs,
)
elif hf_config.quantization_config.get("format", None) == "pack-quantized":
model = AutoModelForCausalLM.from_pretrained(
ckpt_path,
device_map="auto",
trust_remote_code=trust_remote_code,
)
else:
architecture = hf_config.architectures[0]

@@ -346,9 +352,9 @@ def get_model(
from_config = auto_model_module._from_config

with init_empty_weights():
# When computing the device_map, assuming half precision by default,
# When computing the device_map, assuming bfloat16 precision by default,
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
model_kwargs2 = model_kwargs.copy()
if auto_model_module != AutoModelForCausalLM:
model_kwargs2.pop("trust_remote_code", None)
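For context, the new `pack-quantized` branch above keys off the quantization metadata that compressed-tensors checkpoints carry in `config.json`. A rough illustration of the dict it matches follows; keys other than `format` and `quant_method` vary by checkpoint and are shown here only as an assumption:

```python
# Illustrative sketch (assumption): the kind of quantization_config dict that the
# new "pack-quantized" branch in get_model() matches for compressed-tensors checkpoints.
quantization_config = {
    "quant_method": "compressed-tensors",
    "format": "pack-quantized",  # hf_config.quantization_config.get("format") == "pack-quantized"
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 4, "type": "int", "symmetric": True, "group_size": 32},
        }
    },
}
```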
12 changes: 10 additions & 2 deletions examples/llm_ptq/hf_ptq.py
@@ -511,7 +511,12 @@ def main(args):
][0:1]

# Generate preview before quantization
if is_nemotron_vl_model and tokenizer is not None:
if model_type == "deepseek":
        print(
            "DeepSeek models may hit OOM during preview generation; skipping it."
        )
generated_ids_before_ptq = None
elif is_nemotron_vl_model and tokenizer is not None:
generated_ids_before_ptq = run_nemotron_vl_preview(
full_model,
tokenizer,
@@ -523,6 +528,7 @@
else:
# Standard generation for non-Nemotron VL models
generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100)

if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
print("Applying nvfp4 quantization (MoE only) for gpt-oss")

@@ -542,7 +548,9 @@
# Run some samples
torch.cuda.empty_cache()
generated_ids_after_ptq = None
if model_type != "llama4" and not is_nemotron_vl_model:
if generated_ids_before_ptq is None:
pass
elif model_type != "llama4" and not is_nemotron_vl_model:
# Our fake quantizer may not be fully compatible with torch.compile.
generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100)
elif is_nemotron_vl_model and tokenizer is not None:
4 changes: 2 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only]" >&2
exit 1
;;
esac
28 changes: 28 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -22,6 +22,8 @@
from typing import TYPE_CHECKING

import torch
from torch import Tensor
from torch.nn.functional import linear

try:
from torch.distributed.tensor import Shard
@@ -501,6 +503,22 @@ def top_k(self, value):
self.router.moe_top_k = value


class _QuantCompressedLinear(QuantModule):
def _setup(self):
self.input_quantizer = TensorQuantizer()
self.weight_quantizer = TensorQuantizer()

def forward(self, input: Tensor) -> Tensor:
from compressed_tensors.quantization import QuantizationStatus

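        # If the module still holds packed weights, unpack them so the weight quantizer sees dense values.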
if self.quantization_status == QuantizationStatus.COMPRESSED:
weight_data = self.compressor.decompress_module(self)
else:
weight_data = self.weight

return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias)


try:
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe

@@ -576,6 +594,16 @@ def top_k(self, value):
except ImportError:
pass

try:
from compressed_tensors.linear.compressed_linear import CompressedLinear

if CompressedLinear not in QuantModuleRegistry:
QuantModuleRegistry.register({CompressedLinear: "hf.CompressedLinear"})(
_QuantCompressedLinear
)
except ImportError:
pass


class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.
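With `CompressedLinear` registered in the `QuantModuleRegistry`, a compressed-tensors checkpoint can flow through the usual ModelOpt PTQ path. A minimal sketch follows; the checkpoint path and calibration prompt are placeholders, and `mtq.NVFP4_DEFAULT_CFG` is assumed from ModelOpt's stock config names:

```python
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint path; any compressed-tensors "pack-quantized" checkpoint applies.
ckpt = "path/to/pack-quantized-ckpt"
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(ckpt)
calib_batch = tokenizer("Hello world", return_tensors="pt").to(model.device)


def forward_loop(m):
    # A single tiny batch, just enough to exercise the calibration path in this sketch.
    m(**calib_batch)


# CompressedLinear modules are converted to _QuantCompressedLinear via the registry,
# so they pick up input/weight quantizers like any other Linear during quantize().
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=forward_loop)
```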