Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modelopt/onnx/autocast/precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def convert(
onnx.ModelProto: The converted mixed precision model.
"""
try:
self.model = onnx_utils.check_model(self.model)
onnx_utils.check_model(self.model)
except onnx.checker.ValidationError as e:
logger.error(f"Internal error: onnx.checker failed on input model {e}")
raise Exception(
Expand Down
61 changes: 56 additions & 5 deletions modelopt/onnx/autocast/referencerunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
import copy
import io
import sys
import tempfile
from collections import OrderedDict

import numpy as np
import onnx

from modelopt.onnx import utils as onnx_utils
from modelopt.onnx.autocast.logging_config import configure_logging, logger
from modelopt.onnx.quantization.ort_utils import _prepare_ep_list

Expand Down Expand Up @@ -118,13 +120,62 @@ def _load_inputs(self, inputs):

return data_loader

def _get_ort_runner(self, model):
import onnxruntime as ort
from polygraphy.backend.onnx import BytesFromOnnx
Copy link

@pranavm-nvidia pranavm-nvidia Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from polygraphy.backend.onnx import BytesFromOnnx
from polygraphy.backend.onnx import BytesFromOnnx, save_onnx

from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx

# Check if model has external data by checking:
# 1. If any initializer has data_location set to EXTERNAL (even if data is loaded)
# 2. If model size would exceed 2GB (indicating need for external data)
has_external_data = onnx_utils.check_model_uses_external_data(self.model)

# Also check if model would be too large (>2GB) for SerializeToString
# This handles cases where model was loaded with external data already loaded
if not has_external_data:
try:
# Try to estimate size by serializing the model
# If it fails or exceeds 2GB, we need file-based approach
model_size = len(self.model.SerializeToString())

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
model_size = len(self.model.SerializeToString())
model_size = model.ByteSize()

if model_size > 2 * (1024**3): # 2GB threshold
has_external_data = True
logger.debug(
f"Model size ({model_size / (1024**3):.2f} GB) exceeds 2GB, using file-based approach"
)
except (ValueError, AttributeError) as e:
# SerializeToString failed (likely >2GB limit), use file-based approach
if "exceeds maximum protobuf size" in str(e) or "2GB" in str(e):
has_external_data = True
logger.debug("Model exceeds protobuf 2GB limit, using file-based approach")

if has_external_data:
logger.debug("Model has external data, using file-based approach")
# Get the actual ONNX ModelProto from ModifyOutputs wrapper
modified_model = model()

# Use a persistent temp file to handle external data files properly
tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
tmp_file.close()
tmp_file_path = tmp_file.name
onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
logger.debug(f"Model with all outputs saved to {tmp_file_path}")
session = ort.InferenceSession(tmp_file_path, providers=self.providers)
Comment on lines +157 to +162

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI Polygraphy's SaveOnnx can handle models with external data. Also SessionFromOnnx can accept paths.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @pranavm-nvidia. I'll take a look and see if I can refactor.
If it's a quick fix for you, feel free to push a commit to this PR, and I'll review it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a suggestion with the change. One thing I'm not sure about - does your onnx_utils.save_onnx do anything special besides saving the model? On quick inspection, it seems like it's also setting a custom IR version. If that's still required, you'll probably need to add a line like:

modified_model.ir_version = 10

prior to calling Polygraphy's save_onnx.

runners = [OnnxrtRunner(lambda: session)]

else:
# For models without external data, use the original BytesFromOnnx approach (no tmp files)
logger.debug("Model has no external data, using BytesFromOnnx approach")
serialize_onnx = BytesFromOnnx(model)
build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
runners = [OnnxrtRunner(build_onnxrt_session)]
Comment on lines +151 to +170

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if has_external_data:
logger.debug("Model has external data, using file-based approach")
# Get the actual ONNX ModelProto from ModifyOutputs wrapper
modified_model = model()
# Use a persistent temp file to handle external data files properly
tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
tmp_file.close()
tmp_file_path = tmp_file.name
onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True)
logger.debug(f"Model with all outputs saved to {tmp_file_path}")
session = ort.InferenceSession(tmp_file_path, providers=self.providers)
runners = [OnnxrtRunner(lambda: session)]
else:
# For models without external data, use the original BytesFromOnnx approach (no tmp files)
logger.debug("Model has no external data, using BytesFromOnnx approach")
serialize_onnx = BytesFromOnnx(model)
build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
runners = [OnnxrtRunner(build_onnxrt_session)]
if has_external_data:
logger.debug("Model has external data, using file-based approach")
# Get the actual ONNX ModelProto from ModifyOutputs wrapper
modified_model = model()
# Use a persistent temp file to handle external data files properly
outdir = tempfile.TemporaryDirectory()
tmp_file_path = os.path.join(outdir.name, "tmp_model.onnx")
save_onnx(modified_model, tmp_file_path, external_data_path="ext.data")
logger.debug(f"Model with all outputs saved to {tmp_file_path}")
build_onnxrt_session = SessionFromOnnx(tmp_file_path, providers=self.providers)
else:
# For models without external data, use the original BytesFromOnnx approach (no tmp files)
logger.debug("Model has no external data, using BytesFromOnnx approach")
serialize_onnx = BytesFromOnnx(model)
build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
runners = [OnnxrtRunner(build_onnxrt_session)]


return runners

def run(self, inputs=None):
"""Run FP32 inference with provided or random inputs."""
import onnxruntime as ort
from polygraphy import constants
from polygraphy.backend.onnx import BytesFromOnnx
from polygraphy.backend.onnx import ModifyOutputs as ModifyOnnxOutputs
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
from polygraphy.comparator import Comparator

logger.info("Running ONNX Runtime to obtain reference outputs (this may take a while)...")
Expand All @@ -133,9 +184,9 @@ def run(self, inputs=None):

model_copy = copy.deepcopy(self.model)
modify_outputs = ModifyOnnxOutputs(model_copy, outputs=constants.MARK_ALL)
serialize_onnx = BytesFromOnnx(modify_outputs)
build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers)
runners = [OnnxrtRunner(build_onnxrt_session)]

# Load the modified model and create an inference session
runners = self._get_ort_runner(modify_outputs)

# Comparator is used despite the fact that we are using ONNXRuntime
# because it provides the ability to generate random inputs using DataLoader
Expand Down
25 changes: 20 additions & 5 deletions modelopt/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Utility functions related to onnx."""

import copy
import io
import os
import tempfile
Expand Down Expand Up @@ -552,7 +553,7 @@ def _get_unique_name(old_name):
return onnx_model, is_modified


def check_model(model: onnx.ModelProto) -> onnx.ModelProto:
def check_model(model: onnx.ModelProto) -> None:
"""Checks if the given model is valid."""
if model.ByteSize() > (2 * (1024**3)): # 2GB limit
with tempfile.TemporaryDirectory() as temp_dir:
Expand All @@ -561,10 +562,8 @@ def check_model(model: onnx.ModelProto) -> onnx.ModelProto:
onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
save_onnx(model, onnx_tmp_path, save_as_external_data=True)
onnx.checker.check_model(onnx_tmp_path)
return onnx.load(onnx_tmp_path)
else:
onnx.checker.check_model(model)
return model


def find_lowest_common_ancestor(node1: Node, node2: Node) -> tuple[str | None, int, int]:
Expand Down Expand Up @@ -658,15 +657,16 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo

# Set ir_version to 10, remove it once ORT supports ir_version 11
model.ir_version = 10

if save_as_external_data:
external_data_path = os.path.basename(onnx_path) + "_data"
if os.path.exists(external_data_path):
logger.warning(f"Removing existing external data file: {external_data_path}")
os.remove(external_data_path)

# Copy so the onnx.ModelProto object will not be modified
model_copy = copy.deepcopy(model)
onnx.save_model(
model,
model_copy,
onnx_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
Expand Down Expand Up @@ -696,6 +696,21 @@ def get_opset_version(model: onnx.ModelProto) -> int:
return ai_onnx_domain[0].version


def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
    """Report whether the model stores any initializer as external data.

    Only the entries of ``model.graph.initializer`` are inspected. An entry
    counts as external when its ``data_location`` field is present and equal
    to ``onnx.TensorProto.EXTERNAL``.

    Args:
        model: Loaded in-memory onnx ModelProto.

    Returns:
        True if at least one initializer tensor has data_location set to
        EXTERNAL, False otherwise (including when there are no initializers).
    """
    for init in model.graph.initializer:
        # The field must be explicitly present before its value is compared.
        if not init.HasField("data_location"):
            continue
        if init.data_location == onnx.TensorProto.EXTERNAL:
            return True
    return False


def bfloat16_to_float32(bf16_array):
"""Converts a bfloat16 array (as raw data) to a float32 array."""
uint32_array = bf16_array.astype(np.uint32) << 16
Expand Down
11 changes: 0 additions & 11 deletions modelopt/torch/_deploy/utils/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,3 @@ def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> list[str]:
if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
]
return model_tensors_ext


def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
    """Checks if the model uses external data.

    Args:
        model: The onnx ModelProto whose initializer tensors, as returned by
            ``_get_initializer_tensors``, are examined.

    Returns:
        True if any of those tensors has its ``data_location`` field present
        and set to ``onnx.TensorProto.EXTERNAL``, False otherwise.
    """
    for tensor in _get_initializer_tensors(model):
        # The field must be explicitly present before its value is compared.
        if not tensor.HasField("data_location"):
            continue
        if tensor.data_location == onnx.TensorProto.EXTERNAL:
            return True
    return False
2 changes: 1 addition & 1 deletion modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero
from modelopt.onnx.utils import (
check_model_uses_external_data,
get_input_names,
get_input_shapes,
get_node_names,
Expand All @@ -55,7 +56,6 @@
from modelopt.torch.utils._pytree import TreeSpec

from ..utils.onnx_optimizer import Optimizer
from .onnx_utils import check_model_uses_external_data

ModelMetadata = dict[str, Any]
ModelType = Any
Expand Down