Support VLM calibration with image-text data #755
base: main
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##             main     #755      +/-   ##
==========================================
- Coverage   74.66%   73.13%    -1.53%
==========================================
  Files         192      193        +1
  Lines       18975    19555      +580
==========================================
+ Hits        14167    14302      +135
- Misses       4808     5253      +445
```
So we only support image-based calibration for just Nemotron VL? If yes, why?
```diff
  # limitations under the License.

- """Utility functions for getting samples and forward loop function for different vlm datasets."""
+ """Utility functions for getting samples and dataloader for different VLM calibration datasets.
```
@ajrasane could you review this change?
@Edwardf0t1 do you have experiments evaluating the accuracy impact of using the new dataset?
At this time, only Nemotron VL has been tested. We can extend the logic to support other VLMs later. Note that different VLMs may have different forward functions; for example, the way the vision encoder interacts with the language decoder can vary across models. Do you have a preferred VL model you'd like us to support next? For instance, Qwen3-VL?
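For context, a minimal sketch of why the forward call can differ per VLM; the kwarg names and model types below are illustrative assumptions, not the PR's actual code:

```python
import torch


def _vlm_calib_forward(model, batch, model_type: str):
    """Illustrative only: different VLMs wire the vision encoder differently."""
    with torch.no_grad():
        if model_type == "nemotron_vl":
            # Vision features are passed in alongside the token IDs.
            return model(
                input_ids=batch["input_ids"],
                pixel_values=batch["pixel_values"],
                attention_mask=batch.get("attention_mask"),
            )
        # Other VLMs may precompute image embeddings or expect different kwargs.
        raise NotImplementedError(f"No calibration forward defined for {model_type}")
```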
Tested on two benchmarks, DocVQA and InfoVQA, for Nemotron Nano VL v2 with the vLLM backend:
Image-text calibration is only marginally better in these cases, but the calibration flow in this PR should be ready; follow-up experiments can come later.
```sh
    --qformat nvfp4 \
    --export_path <quantized_ckpt_path> \
    --trust_remote_code \
    --calib_with_images \
```
qq: Can the user choose which VLM dataset to use, or do we just provide one option?
When `--calib_with_images` is used, the calibration dataset is hardcoded to `nemotron_vlm_dataset_v2`. It's a very large dataset, and we can choose a few subsets from it.
Could you document the dataset name in the above description?
```python
# prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# inputs = processor(text=[prompt], images=[pil_image], ...)


def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dict[str, Any]:
```
Why do we need to introduce these while the original code does not?
Previously we didn't use image-text data for calibration, and standard DataLoader collation doesn't work for VLMs, for a few reasons:
- The dataset has inconsistent image formats.
- We need to convert the conversational format to the model input format.
- The processor must process images and text together so they align properly.
Should we create a class for this collate function?

```python
class VLMCollator:
    def __init__(self, processor, dataset_name, require_image, max_length, device):
        self.processor = processor
        self.repo_id = (
            SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"]["path"]
            if dataset_name == "nemotron_vlm_dataset_v2"
            else None
        )
        self.image_root = getattr(processor, "_modelopt_vlm_image_root", None)
        self.require_image = require_image
        self.max_length = max_length
        self.device = device

    def __call__(self, examples):
        # ... the collate logic
```

This would make it more readable and easier to test.
jingyu-ml left a comment:
LGTM. I only reviewed the dataset processing part, which behaves as expected, loading the dataset on demand rather than downloading the entire dataset.
```python
        use_media_shards=True,
        max_shards=1,
    )
elif model_type == "mllama":
```
Can this new dataset be used for mllama too? If yes, maybe we can remove this branch.
```python
        device,
        trust_remote_code=args.trust_remote_code,
    )
elif is_nemotron_vl_model and args.calib_with_images:
```
Does `calib_with_images` only work with `is_nemotron_vl_model`? Can it not be used for other VLMs?
```python
)
calibrate_loop = None
if use_calibration:
    base_forward_loop = create_forward_loop(dataloader=calib_dataloader)
```
nit: you can combine lines 514 and 520.
```python
if not (isinstance(part, dict) and part.get("type") == "image"):
    continue
if "image" in part:
    return part["image"]
# fallback
for key in ("images", "path", "image_url", "url", "value", "data"):
    if key in part:
        return part[key]
```
Can be simplified to:

```python
if isinstance(part, dict) and part.get("type") == "image":
    for key in ("image", "images", "path", "image_url", "url", "value", "data"):
        if key in part:
            return part[key]
```

```python
for shard in shard_list:
    if yielded_total >= self.num_samples or not needed:
        break
    local_tar = hf_hub_download(
```
We are downloading the shards twice, once here and once on line 145. Is there a way we can cache the results downloaded on line 145?
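For what it's worth, `hf_hub_download` already caches to the local HF hub cache, so a second call with the same `repo_id`/`filename` should return the cached path rather than re-downloading. If we wanted an explicit in-process cache anyway, a sketch could look like this (assuming the shards come from a dataset repo):

```python
from functools import lru_cache

from huggingface_hub import hf_hub_download


@lru_cache(maxsize=None)
def _cached_shard_download(repo_id: str, filename: str) -> str:
    """Download a shard at most once per process; later calls reuse the local path."""
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
```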
```python
if img is None:
    img = ex.get("images", None)
if img is None and messages is not None:
    img = _extract_first_image_from_messages(messages)
img = _maybe_load_image(img, repo_id=repo_id, image_root=image_root)
```
This logic is also used on line 291. Can we create a util:

```python
def _get_image_from_example(ex: dict) -> Any:
    """Extract image from an example, checking common field names."""
    img = ex.get("image") or ex.get("images")
    if img is None:
        img = _extract_first_image_from_messages(ex.get("messages"))
    return img
```

This will also simplify the lambda.
```python
# Match the model's preferred vision dtype (usually bf16).
vision_dtype = None
with contextlib.suppress(Exception):
```
Could you specify the Exceptions to be suppressed? Same for the other calls.
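For example, if the probe is just reading a parameter dtype, the suppression could plausibly be narrowed to something like this (the attribute names are assumptions about the surrounding code, not taken from this PR):

```python
import contextlib

# Suppress only the failures we actually expect when probing the vision
# tower's dtype: a missing attribute or an empty parameter iterator.
vision_dtype = None
with contextlib.suppress(AttributeError, StopIteration):
    vision_dtype = next(model.vision_model.parameters()).dtype
```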
```python
SUPPORTED_VLM_DATASET_CONFIG: dict[str, dict[str, Any]] = {
    "scienceqa": {"config": {"path": "derek-thomas/ScienceQA", "split": "train"}},
    # Large multi-subset dataset (use streaming to avoid downloading the entire dataset)
    "nemotron_vlm_dataset_v2": {
        "config": {"path": "nvidia/Nemotron-VLM-Dataset-v2", "split": "train", "streaming": True},
        # Provide a sane default that (a) includes in-repo media shards and (b) is document-centric.
        # Subsets like docvqa_cot/chartqa_cot are JSONL-only in the dataset repo and require --vlm_image_root.
        "default_subsets": ["sparsetables", "plotqa_cot", "wiki_en"],
    },
}
```
Should we create a dataclass for this? Something like:

```python
@dataclass
class VLMDatasetConfig:
    path: str
    split: str = "train"
    streaming: bool = False
    default_subsets: list[str] = field(default_factory=list)


SUPPORTED_VLM_DATASETS = {
    "scienceqa": VLMDatasetConfig(path="derek-thomas/ScienceQA"),
    "nemotron_vlm_dataset_v2": VLMDatasetConfig(
        path="nvidia/Nemotron-VLM-Dataset-v2",
        streaming=True,
        default_subsets=["sparsetables", "plotqa_cot", "wiki_en"],
    ),
}
```

```python
cfg = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"].copy()
streaming = bool(cfg.pop("streaming", False))

if dataset_name == "nemotron_vlm_dataset_v2":
```
Should we move this logic to a different function, like `_get_nemotron_dataset()`?
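Something like this sketch, perhaps; the use of `interleave_datasets` and the exact config handling are assumptions, not the PR's implementation:

```python
from typing import Any

from datasets import interleave_datasets, load_dataset


def _get_nemotron_dataset(cfg: dict[str, Any], subsets: list[str]):
    """Hypothetical helper: stream and interleave the requested Nemotron subsets."""
    streams = [
        load_dataset(cfg["path"], name=subset, split=cfg["split"], streaming=True)
        for subset in subsets
    ]
    return interleave_datasets(streams) if len(streams) > 1 else streams[0]
```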
What does this PR do?
Type of change: New feature
Overview:
The primary goal of this PR is to allow the Model Optimizer to use image-text pair data during the calibration phase of quantization, which is likely to help improve the accuracy of quantized VLMs like Nemotron VL, particularly on visual understanding tasks, compared to text-only calibration data.
- Adds calibration support using image-text data from Nemotron-VLM-Dataset-v2.
- Keeps the main quantization script (hf_ptq.py) clean.
- Tested the Nemotron-Nano-VL-12B-V2 model with image data.
- This PR complements #347, and we will consolidate the llm_ptq and vlm_ptq examples in follow-up PRs.
Usage
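A representative invocation, reconstructed from the flags visible in the diff above (the script name and checkpoint flag are assumptions based on the existing hf_ptq.py example, not confirmed by this PR):

```sh
python hf_ptq.py \
    --pyt_ckpt_path <nemotron_vl_ckpt_path> \
    --qformat nvfp4 \
    --export_path <quantized_ckpt_path> \
    --trust_remote_code \
    --calib_with_images
```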
Testing
Before your PR is "Ready for review"
Additional Information