Support VLM calibration with image-text data #755
base: main
**README.md**

@@ -161,6 +161,23 @@ scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|int8_sq|int4_

[PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM.

#### VLM calibration with image-text pairs (e.g., Nemotron VL)

For vision-language models, calibration quality can often be improved by using image-text pairs instead of text-only data, especially on visual understanding tasks:
```bash
python hf_ptq.py \
    --pyt_ckpt_path <huggingface_model_card> \
    --qformat nvfp4 \
    --export_path <quantized_ckpt_path> \
    --trust_remote_code \
    --calib_with_images \
    --calib_size 512
```

> **Collaborator** (on `--calib_with_images`): qq: Can the user choose which VLM dataset to use, or do we provide just one option?
>
> **Contributor (Author):** When
>
> **Collaborator:** Could you document the dataset name in the above description?
> Note: when `--calib_with_images` is set, `--calib_size` must be a single value.
> This functionality is currently in beta and has been tested on `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`.

### NeMo Example Script

NeMo 2.0 framework PTQ and TensorRT-LLM deployment examples are maintained in the NeMo GitHub repo. Please refer to the [NeMo PTQ documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/quantization/quantization.html) for more details.
---

**hf_ptq.py**

```diff
@@ -25,6 +25,7 @@
 from example_utils import (
     build_quant_cfg,
     copy_custom_model_files,
+    create_vlm_calibration_loop,
     get_model,
     get_processor,
     get_tokenizer,
```
```diff
@@ -97,6 +98,39 @@
 mto.enable_huggingface_checkpointing()


+def extract_and_prepare_language_model_from_vl(full_model):
+    """Extract language model from VL model and disable quantization for non-language components.
+
+    Args:
+        full_model: The full VLM model
+
+    Returns:
+        tuple: (language_model, model_type) or (None, None) if not a VLM
+    """
+    language_model_lineage = get_language_model_from_vl(full_model)
+    if language_model_lineage is not None:
+        language_model = language_model_lineage.pop(-1)
+        ancestors = language_model_lineage
+        # Apply disabled quant to all modules that are not part of language_model.
+        # This excludes them during HF export.
+        disabled_quant_cfg = {
+            "quant_cfg": {"default": {"enable": False}},
+            "algorithm": "max",
+        }
+
+        memo = set(ancestors) | {language_model}
+        for ancestor in ancestors:
+            for _, module in ancestor.named_children():
+                if module not in memo:
+                    mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
+                    memo.add(module)
+
+        model_type = get_model_type(language_model)
+        return language_model, model_type
+
+    return None, None


 def make_calib_dataloader(
     args: argparse.Namespace,
     language_model: torch.nn.Module,
```
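For context, a minimal usage sketch of the new helper, assuming `full_model` is an already-loaded VLM checkpoint (`get_model_type` is the utility this script already imports); this snippet is illustrative and not part of the diff:

```python
# Sketch only: quantize just the language tower of a VLM. The helper returns
# (None, None) for non-VLM checkpoints, so callers fall back to the full model.
language_model, model_type = extract_and_prepare_language_model_from_vl(full_model)
if language_model is None:
    language_model, model_type = full_model, get_model_type(full_model)
```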
```diff
@@ -107,7 +141,30 @@ def make_calib_dataloader(
 ) -> tuple[DataLoader, str | None]:
     calib_dataloader = None
     first_text_speech_dataset = None
-    if model_type == "mllama":
+    if args.calib_with_images:
+        # VLM image-text calibration path: assume Nemotron VLM dataset by default.
+        assert processor is not None, (
+            "Please provide a processor (e.g., AutoProcessor) for image calibration."
+        )
+        assert len(args.calib_size) == 1, (
+            "Image calibration currently supports a single dataset. "
+            "Please pass --calib_size with one value (e.g., --calib_size 256)."
+        )
+        calib_dataloader = get_vlm_dataset_dataloader(
+            dataset_name="nemotron_vlm_dataset_v2",
+            processor=processor,
+            batch_size=args.batch_size,
+            num_samples=args.calib_size[0],
+            device=device,
+            max_length=args.calib_seq,
+            require_image=True,
+            subsets=["sparsetables", "plotqa_cot", "wiki_en"],
+            shuffle_buffer_size=10_000,
+            seed=42,
+            use_media_shards=True,
+            max_shards=1,
+        )
+    elif model_type == "mllama":
         assert processor is not None and isinstance(processor, MllamaImageProcessor), (
             "The MllamaImageProcessor must be set."
         )
```

> **Collaborator** (on `elif model_type == "mllama":`): Can this new dataset be used for mllama too? If yes, maybe we can remove this branch.
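For orientation, batches from `get_vlm_dataset_dataloader` are expected to be dicts of processor outputs that can be splatted into the model's forward pass; the key names below are illustrative, since they depend on the model's `AutoProcessor`:

```python
# Illustrative only: peek at one calibration batch. Exact keys depend on the
# processor; input_ids / attention_mask / pixel_values are typical for HF VLMs.
batch = next(iter(calib_dataloader))
for name, value in batch.items():
    print(name, getattr(value, "shape", type(value)))
```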
```diff
@@ -164,6 +221,12 @@ def auto_quantize(
 ):
     """Auto search quantization of multiple formats."""

+    if args.calib_with_images:
+        raise NotImplementedError(
+            "AutoQuantize with image-text calibration is not supported yet. "
+            "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images."
+        )
+
     assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), (
         "Auto Quantization is not supported for pipeline parallel size > 1"
     )
```
```diff
@@ -291,7 +354,9 @@ def load_model(args: argparse.Namespace):
     tokenizer = None
     language_model = full_model
     default_padding_side = None
+    default_pad_token = None

+    is_nemotron_vl_model = is_nemotron_vl(full_model)
     if model_type == "mllama":
         processor = get_processor(
             args.pyt_ckpt_path,
```
```diff
@@ -307,6 +372,31 @@
             device,
             trust_remote_code=args.trust_remote_code,
         )
+    elif is_nemotron_vl_model and args.calib_with_images:
+        # For Nemotron VL image calibration, we need an AutoProcessor to build multimodal inputs.
+        processor = AutoProcessor.from_pretrained(
+            args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code, padding_side="left"
+        )
+
+        if hasattr(processor, "tokenizer") and processor.tokenizer is not None:
+            tokenizer = processor.tokenizer
+        else:
+            tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)
+
+        default_pad_token = tokenizer.pad_token
+        # Some Nemotron tokenizers may not define pad_token by default, but we use padding=True during calibration.
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        assert tokenizer.pad_token is not None, f"Pad token for {args.pyt_ckpt_path} cannot be set!"
+
+        default_padding_side = tokenizer.padding_side
+        tokenizer.padding_side = "left"
+
+        # Quantize only the language model, but keep the full_model for calibration forward.
+        extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
+        if extracted_lm is not None:
+            language_model = extracted_lm
+            model_type = extracted_model_type
     else:
         if args.dataset is None:
             args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
```

> **Collaborator** (on `elif is_nemotron_vl_model and args.calib_with_images:`): Is `calib_with_images` only working with `is_nemotron_vl_model`, and can it not be used for other VLMs?
```diff
@@ -320,29 +410,15 @@
         tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

         default_padding_side = tokenizer.padding_side
+        default_pad_token = tokenizer.pad_token
         # Left padding usually provides better calibration result.
         tokenizer.padding_side = "left"

     # We only quantize the language model for VLMs other than the type supported above.
-    language_model_lineage = get_language_model_from_vl(full_model)
-    if language_model_lineage is not None:
-        language_model = language_model_lineage.pop(-1)
-        ancestors = language_model_lineage
-        # Apply disabled quant to all modules that are not part of language_model so we can exclude them during
-        # HF export.
-        disabled_quant_cfg = {
-            "quant_cfg": {"default": {"enable": False}},
-            "algorithm": "max",
-        }
-
-        memo = set(ancestors) | {language_model}
-        for ancestor in ancestors:
-            for _, module in ancestor.named_children():
-                if module not in memo:
-                    mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
-                    memo.add(module)
-
-        model_type = get_model_type(language_model)
+    extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
+    if extracted_lm is not None:
+        language_model = extracted_lm
+        model_type = extracted_model_type

     if model_type == "phi4mm":
         warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.")
```
```diff
@@ -355,6 +431,7 @@
         processor,
         tokenizer,
         default_padding_side,
+        default_pad_token,
         device,
     )
```
```diff
@@ -432,9 +509,15 @@ def mono_quantize(

     if not use_calibration:
         warnings.warn("Dynamic quantization. Calibration skipped.")
-    calibrate_loop = (
-        create_forward_loop(dataloader=calib_dataloader) if use_calibration else None
-    )
+    calibrate_loop = None
+    if use_calibration:
+        base_forward_loop = create_forward_loop(dataloader=calib_dataloader)
+
+        # For Nemotron VL image calibration, the dataloader yields multimodal kwargs (e.g., pixel_values).
+        # Those kwargs must be consumed by the *full* VLM model, not the extracted language_model.
+        if args.calib_with_images and is_nemotron_vl_model:
+            calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader)
+        else:
+            calibrate_loop = base_forward_loop

     if calibration_only:
         language_model = mtq.calibrate(
```

> **Collaborator:** nit: you can combine lines 514 and 520.
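`create_vlm_calibration_loop` itself lives in `example_utils` and is not shown in this diff; a minimal sketch of the idea, assuming the dataloader yields dicts of model kwargs, might look like this:

```python
import torch


def create_vlm_calibration_loop(full_model, dataloader):
    """Sketch only (the real helper is in example_utils).

    mtq.quantize/mtq.calibrate call forward_loop(model) with the extracted
    language model; we ignore that argument and forward the full VLM instead,
    so multimodal kwargs such as pixel_values are actually consumed.
    """

    def forward_loop(model):
        with torch.no_grad():
            for batch in dataloader:
                full_model(**batch)

    return forward_loop
```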
```diff
@@ -461,6 +544,7 @@ def export_quantized(
     model_type: str | None,
     tokenizer: PreTrainedTokenizerBase | None,
     default_padding_side,
+    default_pad_token,
 ):
     with torch.inference_mode():
         if model_type is None:
```
```diff
@@ -546,6 +630,8 @@
         # Restore default padding and export the tokenizer as well.
         if tokenizer is not None:
             tokenizer.padding_side = default_padding_side
+            if default_pad_token is not None:
+                tokenizer.pad_token = default_pad_token
             tokenizer.save_pretrained(export_path)

     end_time = time.time()
```
```diff
@@ -690,6 +776,7 @@ def quantize_main(
     processor: BaseImageProcessor | ProcessorMixin | None,
     tokenizer: PreTrainedTokenizerBase | None,
     default_padding_side,
+    default_pad_token,
     device: torch.device,
 ):
     if args.batch_size == 0:
```
```diff
@@ -805,7 +892,15 @@
         is_nemotron_vl_model,
         first_text_speech_dataset,
     )
-    export_quantized(args, full_model, language_model, model_type, tokenizer, default_padding_side)
+    export_quantized(
+        args,
+        full_model,
+        language_model,
+        model_type,
+        tokenizer,
+        default_padding_side,
+        default_pad_token,
+    )


 def parse_args() -> argparse.Namespace:
```
```diff
@@ -856,6 +951,14 @@ def parse_args() -> argparse.Namespace:
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--calib_with_images",
+        action="store_true",
+        help=(
+            "Calibrate with image-text pairs (for VLMs). "
+            "This uses nemotron_vlm_dataset_v2 with default subsets (sparsetables, plotqa_cot, wiki_en)."
+        ),
+    )
     parser.add_argument("--inference_tensor_parallel", type=int, default=1)
     parser.add_argument("--inference_pipeline_parallel", type=int, default=1)
     parser.add_argument("--awq_block_size", default=0, type=int)
```
```diff
@@ -993,6 +1096,7 @@ def main(args: argparse.Namespace):
         processor,
         tokenizer,
         default_padding_side,
+        default_pad_token,
         device,
     ) = load_model(args)
```
```diff
@@ -1010,6 +1114,7 @@
         processor,
         tokenizer,
         default_padding_side,
+        default_pad_token,
         device,
     )
```
```diff
@@ -1020,6 +1125,6 @@
     if args.export_fmt != "hf":
         warnings.warn("Deprecated. --export_fmt forced to hf.")

-    args.dataset = args.dataset.split(",") if args.dataset else None
+    args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset
     args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]
     main(args)
```
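The `isinstance` guard matters because `load_model` can assign a list default to `args.dataset` (e.g., `["cnn_dailymail", "nemotron-post-training-dataset-v2"]`), and calling `.split(",")` on a list raises `AttributeError`. A small illustration of the new behavior:

```python
# CLI strings are split into lists; already-parsed lists and None pass through.
for dataset in ("a,b", ["a", "b"], None):
    parsed = dataset.split(",") if isinstance(dataset, str) else dataset
    print(parsed)  # ['a', 'b'] / ['a', 'b'] / None
```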