156 changes: 123 additions & 33 deletions examples/llm_ptq/example_utils.py
@@ -28,6 +28,7 @@
from accelerate.utils import get_max_memory
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
@@ -64,27 +65,39 @@ def run_nemotron_vl_preview(
"""
from vlm_utils import run_text_only_generation, run_vl_preview_generation

print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...")
question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
generation_config = {
"max_new_tokens": 100,
"do_sample": False,
"eos_token_id": tokenizer.eos_token_id,
}

# Try text-only generation
text_response = run_text_only_generation(
full_model, tokenizer, question, generation_config, pyt_ckpt_path
)
# Check if this is Nemotron-Parse (encoder-decoder model that requires images)
config = full_model.config
architectures = getattr(config, "architectures", [])
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)

generated_ids = None

if not is_nemotron_parse:
# Only try text-only generation for models that support it (not Nemotron-Parse)
print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...")
question = tokenizer.decode(input_ids[0], skip_special_tokens=True)
generation_config = {
"max_new_tokens": 100,
"do_sample": False,
"eos_token_id": tokenizer.eos_token_id,
}

# Try text-only generation
text_response = run_text_only_generation(
full_model, tokenizer, question, generation_config, pyt_ckpt_path
)

if text_response is not None:
print(f"✅ Text-only generation successful: {text_response[:100]}...")
generated_ids = text_response
elif allow_fallback:
print("Text-only generation failed, falling back to standard generate...")
generated_ids = full_model.generate(input_ids, max_new_tokens=100)
if text_response is not None:
print(f"✅ Text-only generation successful: {text_response[:100]}...")
generated_ids = text_response
elif allow_fallback:
print("Text-only generation failed, falling back to standard generate...")
generated_ids = full_model.generate(input_ids, max_new_tokens=100)
else:
generated_ids = None
print(
f"Skipping text-only generation for Nemotron-Parse ({stage_name}) - "
"this encoder-decoder model requires images for all operations."
)

# Run additional VL test with images
print(f"Running additional VL test with images ({stage_name})...")
@@ -95,6 +108,10 @@

def _is_multimodal_config(config):
"""Check if a config indicates a multimodal model (config-only version of is_multimodal_model)."""
# Check for Nemotron-Parse encoder-decoder architecture
architectures = getattr(config, "architectures", [])
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)

return (
hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL)
or getattr(config, "model_type", "") == "phi4mm" # Phi-4 multimodal
@@ -103,6 +120,7 @@ def _is_multimodal_config(config):
or (
hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
) # Image embedding layers
or is_nemotron_parse # Nemotron-Parse conditional generation model
)
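
The `architectures`-based check above is repeated in several places in this file (`run_nemotron_vl_preview`, `_is_multimodal_config`, and `get_model`). A small helper along these lines could factor it out; the name `_is_nemotron_parse_config` is illustrative and not part of this change:

def _is_nemotron_parse_config(config) -> bool:
    """Return True if the HF config lists a Nemotron-Parse architecture."""
    architectures = getattr(config, "architectures", None) or []
    return any("nemotronparse" in arch.lower() for arch in architectures)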


@@ -203,9 +221,33 @@ def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs) -> PreTrainedTok
if "vila" in ckpt_path.lower():
ckpt_path += "/llm"

tokenizer = AutoTokenizer.from_pretrained(
ckpt_path, trust_remote_code=trust_remote_code, **kwargs
)
# Suppress verbose tokenizer output (e.g., printing all special tokens)
import contextlib
import io
import logging
import os

# Save current settings
old_verbosity = os.environ.get("TOKENIZERS_PARALLELISM", None)
transformers_log_level = logging.getLogger("transformers").level

# Suppress output
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.getLogger("transformers").setLevel(logging.ERROR)

# Also capture stdout to suppress verbose tokenizer printing
with contextlib.redirect_stdout(io.StringIO()):
try:
tokenizer = AutoTokenizer.from_pretrained(
ckpt_path, trust_remote_code=trust_remote_code, **kwargs
)
finally:
# Restore original settings
if old_verbosity is not None:
os.environ["TOKENIZERS_PARALLELISM"] = old_verbosity
else:
os.environ.pop("TOKENIZERS_PARALLELISM", None)
logging.getLogger("transformers").setLevel(transformers_log_level)

# can't set attribute 'pad_token' for "<unk>"
# We skip this step for Nemo models
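
The save/suppress/restore sequence above also appears in `get_processor` below. One way to avoid repeating it is a small context manager; a minimal sketch, where the name `quiet_transformers` is an assumption rather than code from this change:

import contextlib
import io
import logging
import os

@contextlib.contextmanager
def quiet_transformers():
    """Temporarily silence transformers logging, tokenizer parallelism chatter, and stdout."""
    logger = logging.getLogger("transformers")
    old_level = logger.level
    old_parallelism = os.environ.get("TOKENIZERS_PARALLELISM")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    logger.setLevel(logging.ERROR)
    try:
        with contextlib.redirect_stdout(io.StringIO()):
            yield
    finally:
        # Restore the logging level and environment variable even if loading raises.
        logger.setLevel(old_level)
        if old_parallelism is not None:
            os.environ["TOKENIZERS_PARALLELISM"] = old_parallelism
        else:
            os.environ.pop("TOKENIZERS_PARALLELISM", None)

Used as `with quiet_transformers(): tokenizer = AutoTokenizer.from_pretrained(...)`, it guarantees the settings are restored on both the success and failure paths.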
@@ -257,8 +299,32 @@ def get_processor(
)

return MllamaImageProcessor(processor, device)
else:
# Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse)
# This will only work if the model has a processor config
try:
import contextlib
import io
import logging

# Suppress verbose output from processor/tokenizer loading
transformers_log_level = logging.getLogger("transformers").level
logging.getLogger("transformers").setLevel(logging.ERROR)

with contextlib.redirect_stdout(io.StringIO()):
processor = AutoProcessor.from_pretrained(
ckpt_path,
**model_kwargs,
)

return None
# Restore logging
logging.getLogger("transformers").setLevel(transformers_log_level)

print(f"Loaded AutoProcessor for model type: {model_type}")
return processor
except Exception as e:
print(f"Could not load processor for {model_type}: {e}")
return None
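
One subtlety in the block above: if `AutoProcessor.from_pretrained` raises, control jumps to the `except` branch and the transformers logging level is left at `ERROR`. A variant that always restores it, reusing the names from this hunk (`ckpt_path`, `model_kwargs`, `model_type` come from the enclosing `get_processor`), would look roughly like this:

import contextlib
import io
import logging

transformers_logger = logging.getLogger("transformers")
old_level = transformers_logger.level
transformers_logger.setLevel(logging.ERROR)
try:
    with contextlib.redirect_stdout(io.StringIO()):
        processor = AutoProcessor.from_pretrained(ckpt_path, **model_kwargs)
    print(f"Loaded AutoProcessor for model type: {model_type}")
    return processor
except Exception as e:
    print(f"Could not load processor for {model_type}: {e}")
    return None
finally:
    # Runs after either return path, so the level is never left at ERROR.
    transformers_logger.setLevel(old_level)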


def get_dtype(dtype):
@@ -301,12 +367,23 @@ def get_model(
# Load config once and handle VL model detection
try:
hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)

# Check specifically for Nemotron-Parse
architectures = getattr(hf_config, "architectures", [])
is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures)

if is_nemotron_vl(hf_config):
print(
"Detected Nemotron VL model from config. "
"Disabling automatic device mapping for compatibility."
)
device_map = None
if is_nemotron_parse:
# Nemotron-Parse works fine with device_map="auto"
# Keep device_map="auto" to ensure proper device placement
print("Detected Nemotron-Parse model from config. Using automatic device mapping.")
else:
# For other Nemotron VL models, disable device_map for compatibility
print(
"Detected Nemotron VL model from config. "
"Disabling automatic device mapping for compatibility."
)
device_map = None
except Exception as e:
print(f"Error: Could not load config from {ckpt_path}: {e}")
raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e
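
The config handling above boils down to a small decision about `device_map`; a sketch of that decision in isolation, reusing the `is_nemotron_vl` helper already used in `get_model` (the function name `choose_device_map` is illustrative):

def choose_device_map(hf_config, default="auto"):
    """Pick the device_map for calibration-time loading.

    Nemotron-Parse keeps the default automatic mapping, while other
    Nemotron VL models need device_map disabled for compatibility.
    """
    architectures = getattr(hf_config, "architectures", None) or []
    is_parse = any("nemotronparse" in arch.lower() for arch in architectures)
    if is_nemotron_vl(hf_config) and not is_parse:
        return None
    return default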
@@ -320,8 +397,6 @@
model_kwargs.setdefault("torch_dtype", "auto")

if "vila" in ckpt_path.lower():
from transformers import AutoModel

hf_vila = AutoModel.from_pretrained(
ckpt_path,
device_map=device_map,
@@ -353,13 +428,13 @@ def get_model(
if not hasattr(transformers, architecture):
warnings.warn(
f"Architecture {architecture} not found in transformers: {transformers.__version__}. "
"Falling back to AutoModelForCausalLM."
"Falling back to AutoModel."
)
assert trust_remote_code, (
"Please set trust_remote_code to True if you want to use this architecture"
)

auto_model_module = AutoModelForCausalLM
auto_model_module = AutoModel
from_config = auto_model_module.from_config
else:
auto_model_module = getattr(transformers, architecture)
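
The effect of the fallback change is that architectures not shipped by the installed transformers version now go through `AutoModel` rather than `AutoModelForCausalLM`, which matters for encoder-decoder checkpoints such as Nemotron-Parse that are not causal LMs. A condensed sketch of the resolution logic, with `resolve_model_class` as an illustrative name:

import transformers
from transformers import AutoModel

def resolve_model_class(hf_config):
    """Return the class used to instantiate the model for calibration."""
    architecture = (getattr(hf_config, "architectures", None) or ["AutoModel"])[0]
    # Prefer the concrete class when this transformers version ships it;
    # otherwise fall back to AutoModel, which relies on trust_remote_code
    # for custom architectures.
    return getattr(transformers, architecture, AutoModel)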
@@ -370,7 +445,7 @@ def get_model(
# unless specified by the hf_config.
torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
model_kwargs2 = model_kwargs.copy()
if auto_model_module != AutoModelForCausalLM:
if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
model_kwargs2.pop("trust_remote_code", None)
model_kwargs2["torch_dtype"] = torch_dtype
model_kwargs2.pop("max_memory", None)
@@ -406,6 +481,21 @@
print(f"Moving model to {device} device...")
model = model.to(device)

# For Nemotron-Parse, ensure the encoder (including RADIO) is fully on device
# The RADIO encoder has buffers that might not be properly moved even with device_map="auto"
# This is because custom RADIO modules might not fully support accelerate's device_map
if device != "cpu" and hasattr(model, "encoder"):
# Check if encoder has any buffers on CPU
cpu_buffers = []
for name, buffer in model.encoder.named_buffers():
if buffer.device.type == "cpu":
cpu_buffers.append(name)

if cpu_buffers:
print(f"Found {len(cpu_buffers)} encoder buffers on CPU. Moving encoder to {device}...")
model.encoder = model.encoder.to(device)
print(f"Encoder moved to {device}")

if device == "cuda" and not is_model_on_gpu(model):
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
