Conversation

jingyu-ml (Contributor) commented Jan 23, 2026

What does this PR do?

Type of change: New feature

Overview:

This PR adds HuggingFace checkpoint export support for LTX-2 by treating TI2VidTwoStagesPipeline as a diffusion-like pipeline, exporting only the stage-1 transformer (with QKV-fusion-enabled dummy inputs), and falling back to writing model.safetensors when save_pretrained isn't available. It also preserves the original forward in DynamicModule patching (_forward_pre_dm) so downstream callers can still invoke the pre-patched forward implementation.
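
A minimal sketch of the import-guarded, duck-typed detection described above; the helper name _is_ltx2_pipeline is an illustrative assumption rather than the actual code:

try:
    from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
except ImportError:  # keep LTX-2 an optional, soft dependency
    TI2VidTwoStagesPipeline = None

def _is_ltx2_pipeline(model) -> bool:
    # Route LTX-2's two-stage pipeline through the diffusion export path.
    return TI2VidTwoStagesPipeline is not None and isinstance(model, TI2VidTwoStagesPipeline)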

Changes

  1. Added calibration and quantization support for LTX-2, including FP8 precision.
  2. Preserve the original forward before DynamicModule patching: when patching forward, we now stash the pre-patched implementation in self._forward_pre_dm (once) so downstream code can still call the original forward, then re-bind forward to the class implementation. This is needed for LTX-2 FP8 calibration; see the first sketch after this list.
  3. Added the LTX-2 HF export path: export_hf_checkpoint() now also treats ltx_pipelines.ti2vid_two_stages.TI2VidTwoStagesPipeline as a "diffusion-like" object and routes it through _export_diffusers_checkpoint() (import guarded; no hard dependency; see the sketch above).
  4. Generalized component discovery: introduced get_diffusion_components() (aliasing the old get_diffusers_components()) to support non-diffusers pipelines; for LTX-2 it returns only stage_1_transformer.
  5. Enabled QKV fusion for the LTX-2 backbone: added a model-aware dummy forward generator (generate_diffusion_dummy_forward_fn) that builds minimal LTX Modality inputs (including correct timesteps broadcasting) so shared-input hooks can run and fuse QKV when applicable.
  6. Export fallback for modules without save_pretrained: when a component lacks save_pretrained (e.g. the LTX-2 transformer), export now writes model.safetensors plus a minimal config.json instead of pytorch_model.bin; see the second sketch after this list.
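
For change 2, a minimal sketch of the forward-preservation pattern, assuming forward was monkey-patched as an instance attribute (the real DynamicModule.convert() in modelopt/torch/opt/dynamic.py is more involved):

import types
import torch.nn as nn

def preserve_forward(module: nn.Module) -> None:
    # Stash the pre-DM (monkey-patched) forward once, if one exists.
    if "forward" in vars(module) and not hasattr(module, "_forward_pre_dm"):
        module._forward_pre_dm = module.forward
        # Re-bind forward to the class implementation.
        module.forward = types.MethodType(type(module).forward, module)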
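
For change 6, a hedged sketch of the save_pretrained fallback; the config fields written here are illustrative assumptions, and real code may need to de-duplicate shared tensors before calling save_file:

import json
import os
import torch
from safetensors.torch import save_file

def export_component(component: torch.nn.Module, out_dir: str) -> None:
    if hasattr(component, "save_pretrained"):
        component.save_pretrained(out_dir)  # diffusers-style components
        return
    os.makedirs(out_dir, exist_ok=True)
    # safetensors expects contiguous tensors.
    state = {k: v.contiguous() for k, v in component.state_dict().items()}
    save_file(state, os.path.join(out_dir, "model.safetensors"))
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump({"model_type": type(component).__name__}, f, indent=2)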

Plans

  • [1/4] Add the basic functionality to support a limited set of image models with NVFP4 + FP8, with some refactoring of the previous LLM code and the diffusers example. PIC: @jingyu-ml
  • [2/4] Add support for more video-generation models. PIC: @jingyu-ml
  • [3/4] Add test cases and refactor the docs and all related READMEs. PIC: @jingyu-ml
  • [4/4] Add the final support for ComfyUI. PIC: @jingyu-ml

Usage

python quantize.py --model ltx-2 --format fp4 --batch-size 64 --calib-size 1 --n-steps 40 \
  --extra-param checkpoint_path=/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-dev-fp8.safetensors \
  --extra-param distilled_lora_path=/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-distilled-lora-384.safetensors \
  --extra-param spatial_upsampler_path=/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-spatial-upscaler-x2-1.0.safetensors \
  --extra-param gemma_root=/home/scratch.omniml_data_2/jingyux/models/LTX-2/gemma-3-12b-it-qat-q4_0-unquantized \
  --extra-param fp8transformer=true \
  --hf-ckpt-dir ./ltx2-nvfp4

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow the Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update the Changelog?: No

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added LTX-2 video model support with complete quantization and export pipeline integration
    • Introduced --extra-param CLI option for flexible model configuration and parameter passing
    • Enhanced export capabilities with broader diffusion model compatibility
  • Chores

    • Changed default model data type from Half to BFloat16 for improved numerical stability


Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested review from a team as code owners January 23, 2026 07:43
@jingyu-ml jingyu-ml requested a review from a team as a code owner January 23, 2026 07:43
@jingyu-ml jingyu-ml marked this pull request as draft January 23, 2026 07:43
coderabbitai bot (Contributor) commented Jan 23, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This PR adds comprehensive support for quantizing LTX-2 video models within a diffusion-based quantization framework. Changes include new model registration and configuration, multi-stage pipeline creation with dynamic class lookup, LTX-2-specific quantization handlers, duck-typed pipeline export paths, and infrastructure updates that support both traditional diffusers and LTX-2 pipelines, along with module-level forward caching to preserve pre-quantization forward methods.

Changes

  • LTX-2 Model Registration & Configuration (examples/diffusers/quantization/models_utils.py):
    Added the ModelType.LTX2 = "ltx-2" enum variant, with registry entries mapping to the Lightricks/LTX-2 model ID and a default configuration (backbone, dataset, and inference parameters including height, width, num_frames, frame_rate, and cfg guidance scale).
  • Core Pipeline & Quantization Management (examples/diffusers/quantization/quantize.py):
    Introduced an extra_params field in ModelConfig with CLI parsing support (see the first sketch following these changes); added _create_ltx2_pipeline() and _ensure_ltx2_transformer_cached() to PipelineManager for dynamic LTX-2 pipeline creation; extended Calibrator with a _run_ltx2_calibration() dispatch; changed Quantizer.quantize_model()'s return type from None to torch.nn.Module; added the --extra-param CLI option and changed the default --model-dtype to BFloat16.
  • Quantization Filter Functions (examples/diffusers/quantization/utils.py):
    Updated the AttentionModuleMixin import path from modelopt.torch.quantization.plugins.diffusers to modelopt.torch.quantization.plugins.diffusion.diffusers; expanded the filter_func_ltx_video pattern to include patchify_proj and adaln_single.
  • Generalized Diffusion Export Support (modelopt/torch/export/diffusers_utils.py):
    Added generate_diffusion_dummy_forward_fn() to support both LTX-2 duck-typed and traditional diffusers models; renamed get_diffusers_components() to get_diffusion_components() with broadened type acceptance (Any instead of a union), added a backward-compatible alias, and extended it to expose the LTX-2 stage-1 transformer via stage_1_model_ledger.
  • Unified Export Infrastructure (modelopt/torch/export/unified_export_hf.py):
    Replaced generate_diffusion_dummy_inputs with generate_diffusion_dummy_forward_fn() and updated get_diffusers_components calls to get_diffusion_components(); introduced an optional TI2VidTwoStagesPipeline import; tightened QKV fusion to require exactly 3 modules; added a dummy_forward_fn parameter to _fuse_qkv_linears_diffusion() with graceful fallback on fusion failure; expanded _export_diffusers_checkpoint() and export_hf_checkpoint() with broader Any type hints and duck-typed pipeline detection.
  • Module Forward Method Preservation (modelopt/torch/opt/dynamic.py):
    Added caching of monkey-patched forward methods in the _forward_pre_dm attribute within DynamicModule.convert() to preserve the pre-DM forward for downstream consumers.
  • Quantization Forward Hook Support (modelopt/torch/quantization/nn/modules/quant_module.py):
    Modified QuantInputBase.forward() to conditionally invoke the cached _forward_pre_dm pre-forward hook if present on the instance.
  • Plugin Path Restructuring (modelopt/torch/quantization/plugins/__init__.py):
    Updated import paths for the diffusers and fastvideo plugins from .diffusers and .fastvideo to .diffusion.diffusers and .diffusion.fastvideo.
  • Diffusion Plugin Imports (modelopt/torch/quantization/plugins/diffusion/diffusers.py, modelopt/torch/quantization/plugins/diffusion/fastvideo.py):
    Adjusted relative import paths from two-dot (..) to three-dot (...) to reflect the deeper package hierarchy for export_onnx, nn, and module registry imports.
  • LTX-2 Linear Quantization (modelopt/torch/quantization/plugins/diffusion/ltx2.py):
    Introduced a _QuantLTX2Linear class extending _QuantLinear that upcasts FP8 weights to bfloat16 before quantization (see the second sketch following these changes); added a register_ltx2_quant_linear() function to register the custom linear layer handling.
  • Quantized Weight Detection (modelopt/torch/quantization/utils.py):
    Relaxed the gate on the standard weight attribute in quantized weight detection to yield "weight" whenever a weight quantizer is present, regardless of whether weight is an nn.Parameter.
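
As a first sketch for the --extra-param plumbing noted in the quantize.py entry above, a minimal key=value parser; the actual parse_extra_params() in quantize.py may differ:

import argparse

def parse_extra_params(pairs: list[str]) -> dict[str, str]:
    # Split each "key=value" pair; reject malformed entries.
    params: dict[str, str] = {}
    for pair in pairs:
        key, sep, value = pair.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {pair!r}")
        params[key] = value
    return params

parser = argparse.ArgumentParser()
parser.add_argument("--extra-param", action="append", default=[])
args = parser.parse_args(["--extra-param", "fp8transformer=true"])
print(parse_extra_params(args.extra_param))  # {'fp8transformer': 'true'}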
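
And a second sketch for the FP8 upcast in the ltx2.py entry, reduced to its core idea (the real _QuantLTX2Linear hooks into modelopt quantizer internals that are not shown here):

import torch

def upcast_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
    # LTX-2 checkpoints can ship FP8 (e4m3) weights; calibration and
    # fake-quantization arithmetic needs at least bfloat16 precision.
    if weight.dtype == torch.float8_e4m3fn:
        return weight.to(torch.bfloat16)
    return weight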

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Main as main()
    participant PM as PipelineManager
    participant Cal as Calibrator
    participant Quan as Quantizer
    participant EM as ExportManager
    
    User->>Main: invoke with LTX-2 config + extra_params
    Main->>Main: parse_extra_params() from CLI
    Main->>Main: create ModelConfig with extra_params
    
    Main->>PM: __init__(config)
    PM->>PM: create_pipeline() with extra_params
    PM->>PM: _create_ltx2_pipeline() for LTX-2
    PM->>PM: _ensure_ltx2_transformer_cached()
    PM-->>Main: pipeline ready
    
    Main->>PM: get_backbone()
    PM-->>Main: cached LTX-2 transformer
    
    Main->>Cal: run_calibration()
    Cal->>Cal: _run_ltx2_calibration() dispatch
    Cal->>PM: forward pipeline with LTX-2 prompts
    Cal-->>Main: calibration complete
    
    Main->>Quan: quantize_model(backbone, quant_config)
    Quan->>Quan: register_ltx2_quant_linear()
    Quan->>Quan: apply quantization with FP8 upcast
    Quan-->>Main: quantized backbone
    
    Main->>EM: export_hf_ckpt(pipeline)
    EM->>EM: generate_diffusion_dummy_forward_fn()
    EM->>EM: get_diffusion_components() with duck-typing
    EM->>EM: _export_diffusers_checkpoint() with Any type
    EM-->>Main: exported checkpoint
sequenceDiagram
    participant PM as PipelineManager
    participant DM as DynamicModule
    participant QM as QuantInputBase
    participant DL as _QuantLTX2Linear
    
    PM->>PM: create_pipeline() calls DiffusionPipeline
    PM->>DM: convert() wraps forward methods
    DM->>DM: bind_forward_method_if_needed()
    DM->>DM: cache original in _forward_pre_dm
    DM-->>PM: pipeline with cached forwards
    
    PM->>QM: forward() during inference
    QM->>QM: check _forward_pre_dm exists
    QM->>QM: invoke _forward_pre_dm() if cached
    QM->>DL: _get_quantized_weight() override
    DL->>DL: upcast FP8 to bfloat16 if needed
    DL-->>QM: quantized weight
    QM->>QM: apply output quantization
    QM-->>PM: quantized output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 66.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title '[2/4] Diffusion Quantized ckpt export' directly describes the PR's main change: adding HuggingFace checkpoint export support for diffusion models (specifically LTX-2) with quantization capabilities. The title is specific and clearly summarizes the primary contribution.




copy-pr-bot bot commented Jan 23, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@jingyu-ml jingyu-ml self-assigned this Jan 23, 2026
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml changed the title from "Jingyux/2 3 diffusion export" to "[2/4] Diffusion Quantized ckpt export" Jan 23, 2026
@jingyu-ml jingyu-ml marked this pull request as ready for review January 23, 2026 22:36
@jingyu-ml jingyu-ml requested a review from a team as a code owner January 23, 2026 22:36
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
codecov bot commented Jan 24, 2026

Codecov Report

❌ Patch coverage is 90.00000% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.19%. Comparing base (4f4558a) to head (ac5fcd0).
⚠️ Report is 2 commits behind head on main.

Files with missing lines | Patch % | Lines
...lopt/torch/quantization/nn/modules/quant_module.py | 84.61% | 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #810      +/-   ##
==========================================
+ Coverage   74.17%   74.19%   +0.01%     
==========================================
  Files         192      192              
  Lines       19246    19264      +18     
==========================================
+ Hits        14276    14292      +16     
- Misses       4970     4972       +2     

☔ View full report in Codecov by Sentry.