feat: add MobiusModelBuilder Olive pass for mobius-backed ONNX export #2406
base: main
New file — fp32 CPU workflow config:

```json
{
  "comment": "Build google/gemma-4-E2B-it as a float32 ONNX model using mobius, targeting CPU execution. E2B and E4B are Any-to-Any (vision + audio + text). For Image-Text-to-Text only models (no audio encoder), use google/gemma-4-26B-A4B-it or google/gemma-4-31B-it.",
  "input_model": { "type": "HfModel", "model_path": "google/gemma-4-E2B-it", "task": "text-generation" },
  "systems": {
    "local_system": {
      "type": "LocalSystem",
      "accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ]
    }
  },
  "passes": { "mobius_build": { "type": "MobiusModelBuilder", "precision": "fp32" } },
  "engine": { "target": "local_system", "output_dir": "models/gemma4-e2b-fp32-cpu", "log_severity_level": 1 }
}
```
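As a rough sanity check (not part of Olive itself), the structure of a config like the one above can be verified with the stdlib `json` module before handing it to Olive. The snippet below inlines a trimmed copy of the file, keeping only the structurally interesting keys.

```python
import json

# Trimmed copy of the fp32 CPU workflow config above, inlined for illustration.
config_text = """
{
  "input_model": { "type": "HfModel", "model_path": "google/gemma-4-E2B-it", "task": "text-generation" },
  "passes": { "mobius_build": { "type": "MobiusModelBuilder", "precision": "fp32" } },
  "engine": { "target": "local_system", "output_dir": "models/gemma4-e2b-fp32-cpu" }
}
"""

config = json.loads(config_text)

# Minimal structural checks: an Olive workflow needs an input model,
# at least one pass, and an engine section naming the target system.
assert config["input_model"]["type"] == "HfModel"
assert "mobius_build" in config["passes"]
assert config["passes"]["mobius_build"]["type"] == "MobiusModelBuilder"
print("config OK")
```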
New file — fp16 + INT4 CUDA workflow config:

```json
{
  "comment": "Build google/gemma-4-E2B-it as a float16 ONNX model using mobius, then apply INT4 weight-only quantization for compact deployment. E2B and E4B are Any-to-Any (vision + audio + text). For Image-Text-to-Text only models (no audio encoder), use google/gemma-4-26B-A4B-it or google/gemma-4-31B-it.",
  "input_model": { "type": "HfModel", "model_path": "google/gemma-4-E2B-it", "task": "text-generation" },
  "systems": {
    "local_system": {
      "type": "LocalSystem",
      "accelerators": [ { "device": "gpu", "execution_providers": [ "CUDAExecutionProvider" ] } ]
    }
  },
  "passes": {
    "mobius_build": { "type": "MobiusModelBuilder", "precision": "fp16" },
    "int4_quantize": { "type": "GptqQuantizer", "bits": 4, "group_size": 128, "sym": true }
  },
  "engine": { "target": "local_system", "output_dir": "models/gemma4-e2b-int4-cuda", "log_severity_level": 1 }
}
```

Review thread on the `int4_quantize` pass:

**Reviewer:** The quantization pass works on a PyTorch model and should be run before mobius.

**Author:** Is there a pass I can use to quantize after the model is built? I would like to use that as an example for now.

**Reviewer:** You can use the rtn pass:

```json
{
  "type": "rtn",
  "bits": 4,
  "sym": false,
  "group_size": 32,
  "embeds": true,
  "lm_head": true
}
```

**Reviewer:** Oh, I misread the comment. You can use the blockwise quantizer pass:

```json
{
  "type": "OnnxBlockWiseRtnQuantization",
  "block_size": 128,
  "is_symmetric": true,
  "accuracy_level": 4,
  "save_as_external_data": true
}
```

**Author:** Thanks. And will it process all components together?

**Reviewer:** If the model is composite, it should run the quantizer on each component and return a new composite model.
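Following the reviewer's suggestion, a revised `passes` section would chain `MobiusModelBuilder` with `OnnxBlockWiseRtnQuantization` instead of `GptqQuantizer` (which expects a PyTorch model as input). A sketch of that hypothetical fragment, built as a Python dict and dumped to JSON:

```python
import json

# Hypothetical revision of the "passes" section per the review thread:
# quantize the ONNX output of mobius with the blockwise RTN pass rather
# than GptqQuantizer, which operates on PyTorch models.
passes = {
    "mobius_build": {"type": "MobiusModelBuilder", "precision": "fp16"},
    "int4_quantize": {
        "type": "OnnxBlockWiseRtnQuantization",
        "block_size": 128,
        "is_symmetric": True,
        "accuracy_level": 4,
        "save_as_external_data": True,
    },
}

print(json.dumps({"passes": passes}, indent=2))
```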
Changed file — removal of a `ruff` suppression:

```diff
@@ -3,7 +3,6 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------

-# ruff: noqa: T201

 from argparse import ArgumentParser, Namespace
 from collections import OrderedDict
```

**Reviewer:** Nit: maybe we could do this removal in a different PR? These changes make the PR seem bigger than it actually is.

**Author:** Sorry, I ran lintrunner in the wrong repo 😅 Looks like Olive's ruff version can be updated.
New file — the `MobiusModelBuilder` pass (204 lines):

```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Build ONNX models from HuggingFace model IDs using the mobius package."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar

from olive.constants import Precision
from olive.hardware.constants import ExecutionProvider
from olive.model import HfModelHandler, ONNXModelHandler
from olive.model.handler.composite import CompositeModelHandler
from olive.passes import Pass
from olive.passes.olive_pass import PassConfigParam

if TYPE_CHECKING:
    from olive.hardware.accelerator import AcceleratorSpec
    from olive.passes.pass_config import BasePassConfig

logger = logging.getLogger(__name__)

# Maps Olive Precision values to mobius dtype strings.
# "f32" = 32-bit float (torch.float32), standard full precision.
# "f16" = 16-bit float (torch.float16), half precision — good for GPU inference.
# "bf16" = bfloat16 (torch.bfloat16), brain float — preferred over f16 on newer hardware.
# For INT4/INT8 quantization, use a downstream Olive quantization pass (e.g. OnnxMatMulNBits)
# after this pass rather than setting precision here.
_PRECISION_TO_DTYPE: dict[str, str] = {
    Precision.FP32: "f32",
    Precision.FP16: "f16",
    Precision.BF16: "bf16",
}
```
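The mapping's `.get(..., "f32")` fallback (used later in `_run_for_config`) means any precision outside the table silently becomes full precision. A small standalone illustration, with plain strings standing in for the `Precision` enum:

```python
# Plain-string stand-ins for olive.constants.Precision, for illustration only.
_PRECISION_TO_DTYPE = {"fp32": "f32", "fp16": "f16", "bf16": "bf16"}

def resolve_dtype(precision: str) -> str:
    # Mirrors the pass's lookup: unknown precisions fall back to f32.
    return _PRECISION_TO_DTYPE.get(precision, "f32")

assert resolve_dtype("fp16") == "f16"
assert resolve_dtype("int4") == "f32"  # not handled here; use a quantization pass
```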
```python
class MobiusModelBuilder(Pass):
    """Olive pass that uses mobius to build ONNX models from HuggingFace model IDs.

    Supports all model architectures registered in mobius (LLMs, VLMs, speech
    models, diffusion models). For multi-component models (e.g. vision-language
    models that produce ``model``, ``vision``, and ``embedding`` sub-graphs) the
    pass returns a :class:`~olive.model.handler.composite.CompositeModelHandler`
    whose components are individual :class:`~olive.model.ONNXModelHandler` objects.
    Single-component models return a plain :class:`~olive.model.ONNXModelHandler`.

    Requires ``mobius-ai`` to be installed::

        pip install mobius-ai

    See https://github.com/microsoft/mobius
    """

    # Maps Olive ExecutionProvider enum values to mobius EP names.
    EP_MAP: ClassVar[dict[ExecutionProvider, str]] = {
        ExecutionProvider.CPUExecutionProvider: "cpu",
        ExecutionProvider.CUDAExecutionProvider: "cuda",
        ExecutionProvider.DmlExecutionProvider: "dml",
        ExecutionProvider.WebGpuExecutionProvider: "webgpu",
    }

    @classmethod
    def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
        # EP selection determines which fused ops are emitted, so this pass is
        # EP-specific.
        return False
```
```python
    @classmethod
    def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]:
        return {
            "precision": PassConfigParam(
                type_=Precision,
                required=False,
                default_value=Precision.FP32,
                description=(
                    "Model weight / compute precision. One of: fp32, fp16, bf16. "
                    "Defaults to fp32. For INT4 quantization, run an Olive "
                    "quantization pass (e.g. OnnxMatMulNBits) after this pass."
                ),
            ),
            "execution_provider": PassConfigParam(
                type_=str,
                required=False,
                default_value=None,
                description=(
                    "Override the mobius execution provider (cpu, cuda, dml, webgpu). "
                    "When None (default), the EP is auto-detected from the Olive "
                    "accelerator spec."
                ),
            ),
        }
```

**Reviewer** (on `execution_provider`): We could create an enum of the supported EPs for automatic validation, like in `olive/passes/pytorch/autoawq.py` (line 27 at 8b1957e), unless you think the options might keep growing and it would be hard to keep them in sync across versions.
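The reviewer's enum suggestion could look roughly like the following. This is a sketch, not Olive API; the enum name and the idea of using a `str`-mixin for backward compatibility are assumptions.

```python
from enum import Enum

class MobiusExecutionProvider(str, Enum):
    """Hypothetical enum of the EPs mobius accepts, enabling automatic
    validation of the execution_provider config field."""
    CPU = "cpu"
    CUDA = "cuda"
    DML = "dml"
    WEBGPU = "webgpu"

# A str-valued enum still compares equal to its raw string, so existing
# call sites passing "cuda" would keep working.
assert MobiusExecutionProvider("cuda") is MobiusExecutionProvider.CUDA
assert MobiusExecutionProvider.CPU == "cpu"
```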
```python
    def _run_for_config(
        self,
        model: HfModelHandler,
        config: type[BasePassConfig],
        output_model_path: str,
    ) -> ONNXModelHandler | CompositeModelHandler:
        try:
            from mobius import build
        except ImportError as exc:
            raise ImportError(
                "mobius-ai is required to run MobiusModelBuilder. Install with: pip install mobius-ai"
            ) from exc

        if not isinstance(model, HfModelHandler):
            raise ValueError(f"MobiusModelBuilder requires an HfModelHandler input, got {type(model).__name__}.")

        # Resolve EP: explicit config override > accelerator spec > fallback to cpu.
        ep_str: str = config.execution_provider or self.EP_MAP.get(self.accelerator_spec.execution_provider, "cpu")
```
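The resolution order above (explicit override, then accelerator spec, then cpu) can be seen in isolation with a plain-dict stand-in for `EP_MAP`:

```python
from __future__ import annotations

# Plain-dict stand-in for MobiusModelBuilder.EP_MAP, for illustration.
EP_MAP = {"CPUExecutionProvider": "cpu", "CUDAExecutionProvider": "cuda"}

def resolve_ep(override: str | None, spec_ep: str) -> str:
    # Explicit config override wins; otherwise map the accelerator spec's
    # EP, falling back to cpu for unmapped providers.
    return override or EP_MAP.get(spec_ep, "cpu")

assert resolve_ep("webgpu", "CUDAExecutionProvider") == "webgpu"  # override wins
assert resolve_ep(None, "CUDAExecutionProvider") == "cuda"
assert resolve_ep(None, "QNNExecutionProvider") == "cpu"          # unmapped -> cpu
```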
```python
        dtype_str: str = _PRECISION_TO_DTYPE.get(config.precision, "f32")
        model_id: str = model.model_name_or_path

        # Read trust_remote_code from the model's HuggingFace load kwargs.
        trust_remote_code: bool = model.get_load_kwargs().get("trust_remote_code", False)

        logger.info(
            "MobiusModelBuilder: building '%s' (ep=%s, dtype=%s)",
            model_id,
            ep_str,
            dtype_str,
        )

        if trust_remote_code:
            logger.warning("MobiusModelBuilder: trust_remote_code=True — only use with trusted model sources.")

        output_dir = Path(output_model_path)
        output_dir.mkdir(parents=True, exist_ok=True)

        pkg = build(
            model_id,
            dtype=dtype_str,
            execution_provider=ep_str,
            load_weights=True,
            trust_remote_code=trust_remote_code,
        )

        # ModelPackage.save() handles both single and multi-component layouts:
        #   single component -> <output_dir>/model.onnx
        #   multi-component  -> <output_dir>/<name>/model.onnx for each key
        pkg.save(str(output_dir))

        package_keys = list(pkg.keys())
        logger.info("MobiusModelBuilder: saved components %s to '%s'", package_keys, output_dir)
```
```python
        if len(package_keys) == 1:
            # Single-component model (most LLMs): return a plain ONNXModelHandler.
            onnx_path = output_dir / "model.onnx"
            if not onnx_path.exists():
                raise RuntimeError(
                    f"MobiusModelBuilder: expected output file not found: {onnx_path}. "
                    "mobius.build() may have failed silently or saved to an unexpected path."
                )
            additional_files = sorted(
                {str(fp) for fp in output_dir.iterdir()} - {str(onnx_path), str(onnx_path) + ".data"}
            )
            return ONNXModelHandler(
                model_path=str(output_dir),
                onnx_file_name="model.onnx",
                model_attributes={
                    "mobius_package_keys": package_keys,
                    "additional_files": additional_files,
                    **(model.model_attributes or {}),
                },
            )

        # Multi-component model (VLMs, encoder-decoders, diffusion pipelines):
        # mobius saves each component to <output_dir>/<key>/model.onnx.
        components = []
        for key in package_keys:
            component_dir = output_dir / key
            onnx_path = component_dir / "model.onnx"
            if not onnx_path.exists():
                raise RuntimeError(
                    f"MobiusModelBuilder: expected output file not found: {onnx_path}. "
                    f"mobius.build() may have failed silently for component '{key}'."
                )
            additional_files = sorted(
                {str(fp) for fp in component_dir.iterdir()} - {str(onnx_path), str(onnx_path) + ".data"}
            )
            components.append(
                ONNXModelHandler(
                    model_path=str(component_dir),
                    onnx_file_name="model.onnx",
                    model_attributes={
                        "mobius_component": key,
                        "additional_files": additional_files,
                        **(model.model_attributes or {}),
                    },
                )
            )

        return CompositeModelHandler(
            model_components=components,
            model_component_names=package_keys,
            model_path=str(output_dir),
            model_attributes={
                "mobius_package_keys": package_keys,
                **(model.model_attributes or {}),
            },
        )
```
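The `additional_files` computation is a set difference over the output directory, keeping everything except the model file and its external-data sidecar. A self-contained illustration with a temporary directory (the file names are arbitrary examples, not mobius's actual output names):

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    # Simulate an output layout: the model, its external data, and
    # auxiliary artifacts such as tokenizer/config files.
    for name in ["model.onnx", "model.onnx.data", "genai_config.json", "tokenizer.json"]:
        (out / name).touch()

    onnx_path = out / "model.onnx"
    # Same expression as in the pass: everything except the model and
    # its .data sidecar counts as an additional file.
    additional_files = sorted(
        {str(fp) for fp in out.iterdir()} - {str(onnx_path), str(onnx_path) + ".data"}
    )
    assert [Path(p).name for p in additional_files] == ["genai_config.json", "tokenizer.json"]
```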
**Reviewer:** We don't keep examples in this repo anymore. Can you create an accompanying PR in microsoft/olive-recipes?

**Author:** I am just using this PR to iterate on the files. Could you comment on whether there are errors or changes needed? I will move the files over once stable.

**Reviewer:** Some comments: