17 changes: 17 additions & 0 deletions examples/llm_ptq/README.md
@@ -161,6 +161,23 @@ scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|int8_sq|int4_

[PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM.

#### VLM calibration with image-text pairs (e.g., Nemotron VL)

For vision-language models, calibration quality can often be improved by using image-text pairs instead of text-only data, especially for visual understanding tasks:

```bash
python hf_ptq.py \
--pyt_ckpt_path <huggingface_model_card> \
--qformat nvfp4 \
--export_path <quantized_ckpt_path> \
--trust_remote_code \
--calib_with_images \
  --calib_size 512
```

Collaborator:

qq: Can the user choose which VLM dataset to use, or do we just provide one option?

Contributor Author:

When `--calib_with_images` is used, the calibration dataset is hardcoded to `nemotron_vlm_dataset_v2`. It is a very large dataset, so we select a few subsets from it.

Collaborator:

Could you document the dataset name in the above description?

> Note: when `--calib_with_images` is set, calibration uses the `nemotron_vlm_dataset_v2` dataset with a few default subsets, and `--calib_size` must be a single value.
> This functionality is currently in beta and has been tested on `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`.

### NeMo Example Script

NeMo 2.0 framework PTQ and TensorRT-LLM deployment examples are maintained in the NeMo GitHub repo. Please refer to the [NeMo PTQ documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/quantization/quantization.html) for more details.
48 changes: 48 additions & 0 deletions examples/llm_ptq/example_utils.py
@@ -15,6 +15,7 @@

import copy
import glob
import inspect
import os
import shutil
import sys
@@ -131,6 +132,53 @@ def is_nemotron_vl(model_or_config):
return any("nemotron" in arch.lower() for arch in architectures)


def create_vlm_calibration_loop(full_model, calib_dataloader):
"""Create a calibration loop for VLM models that handles multimodal inputs.

This function inspects the model's forward signature and filters batch kwargs
to only include supported parameters, then calls the appropriate forward method.

Args:
full_model: The full VLM model
calib_dataloader: DataLoader yielding multimodal batches

Returns:
A calibration function that can be passed to mtq.quantize()
"""
# Import here to avoid circular dependency
from nemotron_vl_calib import safe_nemotron_vl_forward

def calibrate_loop(_model):
# Inspect model's forward signature to determine what parameters it accepts
forward_params = inspect.signature(full_model.forward).parameters
accepts_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in forward_params.values()
)
allowed_keys = set(forward_params.keys())

full_model.eval()
with torch.no_grad():
for batch in calib_dataloader:
# Filter batch to only include parameters the model accepts
if accepts_kwargs:
call_kwargs = batch
else:
call_kwargs = {k: v for k, v in batch.items() if k in allowed_keys}
# Remove None values
call_kwargs = {k: v for k, v in call_kwargs.items() if v is not None}

# Use safe_nemotron_vl_forward for Nemotron Nano VL (embedding-injection style)
# For other VLMs (like Nemotron-Parse), use standard forward
if hasattr(full_model, "img_context_token_id"):
# Nemotron Nano VL style
safe_nemotron_vl_forward(full_model, call_kwargs)
else:
# Standard encoder-decoder or other VLM architectures
full_model(**call_kwargs)

return calibrate_loop


def build_quant_cfg(
qformat,
kv_cache_qformat,
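The kwarg-filtering step inside `create_vlm_calibration_loop` can be illustrated with a small, self-contained sketch (the toy model and batch below are illustrative only and are not part of this PR): the batch is pruned to the keys that the model's `forward` signature accepts, unless the signature takes `**kwargs`, and `None` values are dropped before the model is called.

```python
import inspect

import torch
import torch.nn as nn


class ToyVLM(nn.Module):
    """Stand-in model whose forward accepts text ids and optional pixel values."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.vision = nn.Linear(32, 16)
        self.head = nn.Linear(16, 100)

    def forward(self, input_ids, pixel_values=None, attention_mask=None):
        hidden = self.embed(input_ids).mean(dim=1)
        if pixel_values is not None:
            hidden = hidden + self.vision(pixel_values)
        return self.head(hidden)


model = ToyVLM().eval()

# A batch as a multimodal dataloader might yield it; "labels" is not a forward parameter.
batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "pixel_values": torch.randn(2, 32),
    "attention_mask": None,
    "labels": torch.randint(0, 100, (2,)),
}

params = inspect.signature(model.forward).parameters
accepts_var_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())

# Keep only kwargs the forward signature accepts, then drop None values.
kwargs = dict(batch) if accepts_var_kwargs else {k: v for k, v in batch.items() if k in params}
kwargs = {k: v for k, v in kwargs.items() if v is not None}

with torch.no_grad():
    out = model(**kwargs)
print(out.shape)  # torch.Size([2, 100])
```

This is why extra dataloader fields such as `labels` do not break calibration even when the model's forward does not accept them.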
155 changes: 130 additions & 25 deletions examples/llm_ptq/hf_ptq.py
@@ -25,6 +25,7 @@
from example_utils import (
build_quant_cfg,
copy_custom_model_files,
create_vlm_calibration_loop,
get_model,
get_processor,
get_tokenizer,
@@ -97,6 +98,39 @@
mto.enable_huggingface_checkpointing()


def extract_and_prepare_language_model_from_vl(full_model):
"""Extract language model from VL model and disable quantization for non-language components.

Args:
full_model: The full VLM model

Returns:
tuple: (language_model, model_type) or (None, None) if not a VLM
"""
language_model_lineage = get_language_model_from_vl(full_model)
if language_model_lineage is not None:
language_model = language_model_lineage.pop(-1)
ancestors = language_model_lineage
# Apply disabled quant to all modules that are not part of language_model
# This excludes them during HF export
disabled_quant_cfg = {
"quant_cfg": {"default": {"enable": False}},
"algorithm": "max",
}

memo = set(ancestors) | {language_model}
for ancestor in ancestors:
for _, module in ancestor.named_children():
if module not in memo:
mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
memo.add(module)

model_type = get_model_type(language_model)
return language_model, model_type

return None, None


def make_calib_dataloader(
args: argparse.Namespace,
language_model: torch.nn.Module,
@@ -107,7 +141,30 @@ def make_calib_dataloader(
) -> tuple[DataLoader, str | None]:
calib_dataloader = None
first_text_speech_dataset = None
if model_type == "mllama":
if args.calib_with_images:
# VLM image-text calibration path: assume Nemotron VLM dataset by default.
assert processor is not None, (
"Please provide a processor (e.g., AutoProcessor) for image calibration."
)
assert len(args.calib_size) == 1, (
"Image calibration currently supports a single dataset. "
"Please pass --calib_size with one value (e.g., --calib_size 256)."
)
calib_dataloader = get_vlm_dataset_dataloader(
dataset_name="nemotron_vlm_dataset_v2",
processor=processor,
batch_size=args.batch_size,
num_samples=args.calib_size[0],
device=device,
max_length=args.calib_seq,
require_image=True,
subsets=["sparsetables", "plotqa_cot", "wiki_en"],
shuffle_buffer_size=10_000,
seed=42,
use_media_shards=True,
max_shards=1,
)
elif model_type == "mllama":
Collaborator (@cjluo-nv, Jan 22, 2026):

Can this new dataset be used for mllama too? If yes, maybe we can remove this branch.

assert processor is not None and isinstance(processor, MllamaImageProcessor), (
"The MllamaImageProcessor must be set."
)
@@ -164,6 +221,12 @@ def auto_quantize(
):
"""Auto search quantization of multiple formats."""

if args.calib_with_images:
raise NotImplementedError(
"AutoQuantize with image-text calibration is not supported yet. "
"Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images."
)

assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), (
"Auto Quantization is not supported for pipeline parallel size > 1"
)
@@ -291,7 +354,9 @@ def load_model(args: argparse.Namespace):
tokenizer = None
language_model = full_model
default_padding_side = None
default_pad_token = None

is_nemotron_vl_model = is_nemotron_vl(full_model)
if model_type == "mllama":
processor = get_processor(
args.pyt_ckpt_path,
@@ -307,6 +372,31 @@
device,
trust_remote_code=args.trust_remote_code,
)
elif is_nemotron_vl_model and args.calib_with_images:
Collaborator:

Does `calib_with_images` only work with `is_nemotron_vl_model`? Can it not be used for other VLMs?

# For Nemotron VL image calibration, we need an AutoProcessor to build multimodal inputs.
processor = AutoProcessor.from_pretrained(
args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code, padding_side="left"
)

if hasattr(processor, "tokenizer") and processor.tokenizer is not None:
tokenizer = processor.tokenizer
else:
tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

default_pad_token = tokenizer.pad_token
# Some Nemotron tokenizers may not define pad_token by default, but we use padding=True during calibration, so a pad token is required.
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
assert tokenizer.pad_token is not None, f"Pad token for {args.pyt_ckpt_path} cannot be set!"

default_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"

# Quantize only the language model, but keep the full_model for calibration forward.
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
if extracted_lm is not None:
language_model = extracted_lm
model_type = extracted_model_type
else:
if args.dataset is None:
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
@@ -320,29 +410,15 @@
tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

default_padding_side = tokenizer.padding_side
default_pad_token = tokenizer.pad_token
# Left padding usually provides better calibration result.
tokenizer.padding_side = "left"

# We only quantize the language model for VLMs other than the type supported above.
language_model_lineage = get_language_model_from_vl(full_model)
if language_model_lineage is not None:
language_model = language_model_lineage.pop(-1)
ancestors = language_model_lineage
# Apply disabled quant to all modules that are not part of language_model so we can exclude them during
# HF export.
disabled_quant_cfg = {
"quant_cfg": {"default": {"enable": False}},
"algorithm": "max",
}

memo = set(ancestors) | {language_model}
for ancestor in ancestors:
for _, module in ancestor.named_children():
if module not in memo:
mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
memo.add(module)

model_type = get_model_type(language_model)
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
if extracted_lm is not None:
language_model = extracted_lm
model_type = extracted_model_type

if model_type == "phi4mm":
warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.")
@@ -355,6 +431,7 @@
processor,
tokenizer,
default_padding_side,
default_pad_token,
device,
)

@@ -432,9 +509,15 @@ def mono_quantize(

if not use_calibration:
warnings.warn("Dynamic quantization. Calibration skipped.")
calibrate_loop = (
create_forward_loop(dataloader=calib_dataloader) if use_calibration else None
)
calibrate_loop = None
if use_calibration:
base_forward_loop = create_forward_loop(dataloader=calib_dataloader)
Collaborator:

nit: you could combine 514 and 520.

# For Nemotron VL image calibration, the dataloader yields multimodal kwargs (e.g., pixel_values).
# Those kwargs must be consumed by the *full* VLM model, not the extracted language_model.
if args.calib_with_images and is_nemotron_vl_model:
calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader)
else:
calibrate_loop = base_forward_loop

if calibration_only:
language_model = mtq.calibrate(
@@ -461,6 +544,7 @@ def export_quantized(
model_type: str | None,
tokenizer: PreTrainedTokenizerBase | None,
default_padding_side,
default_pad_token,
):
with torch.inference_mode():
if model_type is None:
@@ -546,6 +630,8 @@ def export_quantized(
# Restore default padding and export the tokenizer as well.
if tokenizer is not None:
tokenizer.padding_side = default_padding_side
if default_pad_token is not None:
tokenizer.pad_token = default_pad_token
tokenizer.save_pretrained(export_path)

end_time = time.time()
@@ -690,6 +776,7 @@ def quantize_main(
processor: BaseImageProcessor | ProcessorMixin | None,
tokenizer: PreTrainedTokenizerBase | None,
default_padding_side,
default_pad_token,
device: torch.device,
):
if args.batch_size == 0:
@@ -805,7 +892,15 @@
is_nemotron_vl_model,
first_text_speech_dataset,
)
export_quantized(args, full_model, language_model, model_type, tokenizer, default_padding_side)
export_quantized(
args,
full_model,
language_model,
model_type,
tokenizer,
default_padding_side,
default_pad_token,
)


def parse_args() -> argparse.Namespace:
@@ -856,6 +951,14 @@ def parse_args() -> argparse.Namespace:
type=str,
default=None,
)
parser.add_argument(
"--calib_with_images",
action="store_true",
help=(
"Calibrate with image-text pairs (for VLMs). "
"This uses nemotron_vlm_dataset_v2 with default subsets (sparsetables, plotqa_cot, wiki_en)."
),
)
parser.add_argument("--inference_tensor_parallel", type=int, default=1)
parser.add_argument("--inference_pipeline_parallel", type=int, default=1)
parser.add_argument("--awq_block_size", default=0, type=int)
@@ -993,6 +1096,7 @@ def main(args: argparse.Namespace):
processor,
tokenizer,
default_padding_side,
default_pad_token,
device,
) = load_model(args)

@@ -1010,6 +1114,7 @@
processor,
tokenizer,
default_padding_side,
default_pad_token,
device,
)

@@ -1020,6 +1125,6 @@
if args.export_fmt != "hf":
warnings.warn("Deprecated. --export_fmt forced to hf.")

args.dataset = args.dataset.split(",") if args.dataset else None
args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset
args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]
main(args)