-
Notifications
You must be signed in to change notification settings - Fork 243
Support multiple-batch input for autocast calibration. #760
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -19,12 +19,16 @@ | |||||||||||||||||
| implementation. It supports both random input generation and user-provided inputs through | ||||||||||||||||||
| NPZ or Polygraphy JSON files. The runner is used to analyze model behavior and validate | ||||||||||||||||||
| outputs during precision conversion. | ||||||||||||||||||
|
|
||||||||||||||||||
| When multiple batches of calibration data are provided, the runner aggregates statistics | ||||||||||||||||||
| across all batches to provide more robust range information for precision conversion decisions. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| import copy | ||||||||||||||||||
| import io | ||||||||||||||||||
| import sys | ||||||||||||||||||
| from collections import OrderedDict | ||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||
|
|
||||||||||||||||||
| import numpy as np | ||||||||||||||||||
| import onnx | ||||||||||||||||||
|
|
@@ -35,6 +39,35 @@ | |||||||||||||||||
| configure_logging() | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @dataclass | ||||||||||||||||||
| class TensorStats: | ||||||||||||||||||
| """Statistics for a tensor aggregated across multiple batches. | ||||||||||||||||||
|
|
||||||||||||||||||
| Attributes: | ||||||||||||||||||
| absmax: Maximum absolute value across all batches. | ||||||||||||||||||
| min_val: Minimum value across all batches. | ||||||||||||||||||
| max_val: Maximum value across all batches. | ||||||||||||||||||
| shape: Shape of the tensor (from first batch). | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| absmax: float | ||||||||||||||||||
| min_val: float | ||||||||||||||||||
| max_val: float | ||||||||||||||||||
| shape: tuple | ||||||||||||||||||
|
|
||||||||||||||||||
| def __abs__(self): | ||||||||||||||||||
| """Return the maximum absolute value (for compatibility with np.abs).""" | ||||||||||||||||||
| return self.absmax | ||||||||||||||||||
|
|
||||||||||||||||||
| @property | ||||||||||||||||||
| def size(self): | ||||||||||||||||||
| """Return total number of elements.""" | ||||||||||||||||||
| result = 1 | ||||||||||||||||||
| for dim in self.shape: | ||||||||||||||||||
| result *= dim | ||||||||||||||||||
| return result | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class ReferenceRunner: | ||||||||||||||||||
| """A class to run ONNX models with ONNXRuntime for reference inference.""" | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -69,8 +102,29 @@ def _load_inputs_from_json(self, input_data_path): | |||||||||||||||||
| return load_json(input_data_path, description="input data") | ||||||||||||||||||
|
|
||||||||||||||||||
| def _load_inputs_from_npz(self, input_data_path): | ||||||||||||||||||
| """Load inputs from NPZ format.""" | ||||||||||||||||||
| return [np.load(input_data_path)] | ||||||||||||||||||
| """Load inputs from NPZ format. | ||||||||||||||||||
|
|
||||||||||||||||||
| Supports both single NPZ file and directory containing multiple NPZ files for multi-batch calibration. | ||||||||||||||||||
|
|
||||||||||||||||||
| Args: | ||||||||||||||||||
| input_data_path: Path to NPZ file or directory containing NPZ files. | ||||||||||||||||||
|
|
||||||||||||||||||
| Returns: | ||||||||||||||||||
| List of input dictionaries, one per batch. | ||||||||||||||||||
| """ | ||||||||||||||||||
| import os | ||||||||||||||||||
|
|
||||||||||||||||||
| if os.path.isdir(input_data_path): | ||||||||||||||||||
| # Load all NPZ files in the directory as multiple batches | ||||||||||||||||||
| npz_files = sorted( | ||||||||||||||||||
| [f for f in os.listdir(input_data_path) if f.endswith(".npz")] | ||||||||||||||||||
| ) | ||||||||||||||||||
| if not npz_files: | ||||||||||||||||||
| raise ValueError(f"No NPZ files found in directory: {input_data_path}") | ||||||||||||||||||
| logger.info(f"Loading {len(npz_files)} NPZ files from directory for multi-batch calibration") | ||||||||||||||||||
| return [np.load(os.path.join(input_data_path, f)) for f in npz_files] | ||||||||||||||||||
| else: | ||||||||||||||||||
| return [np.load(input_data_path)] | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Edit: This additional fix will be added in another PR. Suggestion to support bug 5676209 (single NPZ file containing multiple batches):
Suggested change
The
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This also requires a minor change in def __init__(
self,
# Before #1
onnx_path: str,
# After #1
onnx_path: str | onnx.ModelProto,
):
...
# Before #2
onnx_model = onnx.load(onnx_path)
# After #2
onnx_model = onnx.load(onnx_path) if isinstance(onnx_path, str) else onnx_path
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please also update the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll add this fix in a different PR, so please ignore this for now. Thanks. |
||||||||||||||||||
|
|
||||||||||||||||||
| def _validate_inputs(self, data_loader): | ||||||||||||||||||
| """Validate that input names and shapes match the model.""" | ||||||||||||||||||
|
|
@@ -96,16 +150,18 @@ def _load_inputs(self, inputs): | |||||||||||||||||
| # If no inputs are provided, use random inputs | ||||||||||||||||||
| data_loader = DataLoader(val_range={"": (-1, 1)}) | ||||||||||||||||||
|
|
||||||||||||||||||
| import os | ||||||||||||||||||
|
|
||||||||||||||||||
| if inputs is not None: | ||||||||||||||||||
| if isinstance(inputs, str): | ||||||||||||||||||
| if inputs.endswith(".json"): | ||||||||||||||||||
| data_loader = self._load_inputs_from_json(inputs) | ||||||||||||||||||
| elif inputs.endswith(".npz"): | ||||||||||||||||||
| elif inputs.endswith(".npz") or os.path.isdir(inputs): | ||||||||||||||||||
| data_loader = self._load_inputs_from_npz(inputs) | ||||||||||||||||||
| else: | ||||||||||||||||||
| raise ValueError( | ||||||||||||||||||
| f"Invalid input file: {inputs}. Supported input file types: .json (Polygraphy JSON format), " | ||||||||||||||||||
| ".npz (Numpy)" | ||||||||||||||||||
| f"Invalid input file: {inputs}. Supported input types: .json (Polygraphy JSON format), " | ||||||||||||||||||
| ".npz (Numpy), or a directory containing .npz files" | ||||||||||||||||||
| ) | ||||||||||||||||||
| elif isinstance(inputs, (dict, OrderedDict)): | ||||||||||||||||||
| data_loader = [inputs] | ||||||||||||||||||
|
|
@@ -118,8 +174,71 @@ def _load_inputs(self, inputs): | |||||||||||||||||
|
|
||||||||||||||||||
| return data_loader | ||||||||||||||||||
|
|
||||||||||||||||||
| def _aggregate_tensor_stats(self, all_batch_data: list[OrderedDict]) -> OrderedDict: | ||||||||||||||||||
| """Aggregate tensor statistics across multiple batches. | ||||||||||||||||||
|
|
||||||||||||||||||
| Args: | ||||||||||||||||||
| all_batch_data: List of dictionaries containing tensor data for each batch. | ||||||||||||||||||
|
|
||||||||||||||||||
| Returns: | ||||||||||||||||||
| OrderedDict mapping tensor names to TensorStats objects. | ||||||||||||||||||
| """ | ||||||||||||||||||
| if len(all_batch_data) == 1: | ||||||||||||||||||
| # Single batch - return raw data for backward compatibility | ||||||||||||||||||
| return all_batch_data[0] | ||||||||||||||||||
|
|
||||||||||||||||||
| logger.info(f"Aggregating statistics across {len(all_batch_data)} batches...") | ||||||||||||||||||
|
|
||||||||||||||||||
| aggregated = OrderedDict() | ||||||||||||||||||
| tensor_names = all_batch_data[0].keys() | ||||||||||||||||||
|
|
||||||||||||||||||
| for name in tensor_names: | ||||||||||||||||||
| absmax = -np.inf | ||||||||||||||||||
| min_val = np.inf | ||||||||||||||||||
| max_val = -np.inf | ||||||||||||||||||
| shape = None | ||||||||||||||||||
|
|
||||||||||||||||||
| for batch_data in all_batch_data: | ||||||||||||||||||
| if name not in batch_data: | ||||||||||||||||||
| continue | ||||||||||||||||||
| data = batch_data[name] | ||||||||||||||||||
| if shape is None: | ||||||||||||||||||
| shape = data.shape | ||||||||||||||||||
|
|
||||||||||||||||||
| batch_absmax = np.max(np.abs(data)) if data.size > 0 else 0 | ||||||||||||||||||
| batch_min = np.min(data) if data.size > 0 else 0 | ||||||||||||||||||
| batch_max = np.max(data) if data.size > 0 else 0 | ||||||||||||||||||
|
|
||||||||||||||||||
| absmax = max(absmax, batch_absmax) | ||||||||||||||||||
| min_val = min(min_val, batch_min) | ||||||||||||||||||
| max_val = max(max_val, batch_max) | ||||||||||||||||||
|
|
||||||||||||||||||
| if shape is not None: | ||||||||||||||||||
| aggregated[name] = TensorStats( | ||||||||||||||||||
| absmax=absmax, | ||||||||||||||||||
| min_val=min_val, | ||||||||||||||||||
| max_val=max_val, | ||||||||||||||||||
| shape=shape, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| return aggregated | ||||||||||||||||||
|
|
||||||||||||||||||
| def run(self, inputs=None): | ||||||||||||||||||
| """Run FP32 inference with provided or random inputs.""" | ||||||||||||||||||
| """Run FP32 inference with provided or random inputs. | ||||||||||||||||||
|
|
||||||||||||||||||
| When multiple batches of input data are provided, inference is run for each batch | ||||||||||||||||||
| and statistics are aggregated across all batches for more robust range estimation. | ||||||||||||||||||
|
|
||||||||||||||||||
| Args: | ||||||||||||||||||
| inputs: Optional input data. Can be: | ||||||||||||||||||
| - None: Random inputs will be generated | ||||||||||||||||||
| - str: Path to JSON file, NPZ file, or directory containing NPZ files | ||||||||||||||||||
| - dict/OrderedDict: Single batch of input data | ||||||||||||||||||
|
|
||||||||||||||||||
| Returns: | ||||||||||||||||||
| OrderedDict: Combined input and output data. For single batch, returns raw arrays. | ||||||||||||||||||
| For multiple batches, returns TensorStats objects with aggregated statistics. | ||||||||||||||||||
| """ | ||||||||||||||||||
| import onnxruntime as ort | ||||||||||||||||||
| from polygraphy import constants | ||||||||||||||||||
| from polygraphy.backend.onnx import BytesFromOnnx | ||||||||||||||||||
|
|
@@ -156,15 +275,30 @@ def run(self, inputs=None): | |||||||||||||||||
| logger.error(f"ONNXRuntime execution failed with output:\n{captured_output}") | ||||||||||||||||||
| raise Exception("ONNXRuntime failed to run, see logs for details") | ||||||||||||||||||
|
|
||||||||||||||||||
| # Get the output results | ||||||||||||||||||
| output_dict = OrderedDict(results[0][1][0]) | ||||||||||||||||||
| # Collect all batch data (inputs + outputs) | ||||||||||||||||||
| all_batch_data = [] | ||||||||||||||||||
| runner_results = results[0][1] # Get all iteration results for the first runner | ||||||||||||||||||
| data_loader_iter = iter(data_loader) | ||||||||||||||||||
|
|
||||||||||||||||||
| for iter_idx, iter_result in enumerate(runner_results): | ||||||||||||||||||
| output_dict = OrderedDict(iter_result) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Get corresponding input data | ||||||||||||||||||
| try: | ||||||||||||||||||
| input_data = next(data_loader_iter) | ||||||||||||||||||
| except StopIteration: | ||||||||||||||||||
| # If data_loader is exhausted, it might be a DataLoader that generates random data | ||||||||||||||||||
| input_data = {} | ||||||||||||||||||
|
|
||||||||||||||||||
| # Include input data for completeness | ||||||||||||||||||
| input_data = next(iter(data_loader)) | ||||||||||||||||||
| # Combine inputs and outputs for this batch | ||||||||||||||||||
| batch_dict = OrderedDict() | ||||||||||||||||||
| batch_dict.update(input_data) | ||||||||||||||||||
| batch_dict.update(output_dict) | ||||||||||||||||||
| all_batch_data.append(batch_dict) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Combine inputs and outputs in the returned dictionary | ||||||||||||||||||
| combined_dict = OrderedDict() | ||||||||||||||||||
| combined_dict.update(input_data) | ||||||||||||||||||
| combined_dict.update(output_dict) | ||||||||||||||||||
| num_batches = len(all_batch_data) | ||||||||||||||||||
| if num_batches > 1: | ||||||||||||||||||
| logger.info(f"Processed {num_batches} batches of calibration data") | ||||||||||||||||||
|
|
||||||||||||||||||
| return combined_dict | ||||||||||||||||||
| # Aggregate statistics across all batches | ||||||||||||||||||
| return self._aggregate_tensor_stats(all_batch_data) | ||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add a test for this? Thanks!