32 commits
a33cf13
Your commit message describing all changes
jingyu-ml Jan 14, 2026
dff152b
Merge the diffusion and llms layer fusion code
jingyu-ml Jan 14, 2026
9e94843
Create a diffusers utils function, moved some functions to it
jingyu-ml Jan 14, 2026
db61c20
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 14, 2026
8a81723
Fixed some bugs in the CI/CD
jingyu-ml Jan 14, 2026
16a2bbf
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 14, 2026
68d5665
Move one function to diffusers utils
jingyu-ml Jan 14, 2026
ace5773
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 15, 2026
95dfb52
removed the DiffusionPipeline import
jingyu-ml Jan 15, 2026
302e2f4
Update the example
jingyu-ml Jan 15, 2026
8eed21b
Fixed the CI/CD
jingyu-ml Jan 16, 2026
01d31d7
Update the CI/CD
jingyu-ml Jan 16, 2026
ca3fdaa
Update the Flux example & address Chenjie's comments
jingyu-ml Jan 16, 2026
44345f8
use single line of code
jingyu-ml Jan 16, 2026
78f12cc
Update the test case
jingyu-ml Jan 16, 2026
3911a3d
Add the support for the WAN video
jingyu-ml Jan 16, 2026
4cf9e76
Moved the has_quantized_modules to quant utils
jingyu-ml Jan 20, 2026
1da2b46
moving model specific configs to separate files
jingyu-ml Jan 20, 2026
eafedde
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 20, 2026
3fb8320
Fixed the CI/CD
jingyu-ml Jan 20, 2026
372c6f7
Fixed the cicd
jingyu-ml Jan 20, 2026
e67bf85
reduce the repeated code
jingyu-ml Jan 21, 2026
9b5cf13
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 21, 2026
e931fbc
Update the lint
jingyu-ml Jan 21, 2026
8b29228
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 21, 2026
b8b5eaf
Merge branch 'main' into jingyux/2-3-diffusion-export
jingyu-ml Jan 22, 2026
b717bae
Add the LTX2 FP8/BF16 support + Some core code changes
jingyu-ml Jan 23, 2026
0d93e1a
Merge branch 'main' into jingyux/2-3-diffusion-export
jingyu-ml Jan 23, 2026
c2aadca
Update
jingyu-ml Jan 23, 2026
109c010
Merge branch 'main' into jingyux/2-3-diffusion-export
jingyu-ml Jan 23, 2026
d7aef93
Fixed the CICD
jingyu-ml Jan 23, 2026
ac5fcd0
Fixed more CICD
jingyu-ml Jan 24, 2026
18 changes: 17 additions & 1 deletion examples/diffusers/quantization/models_utils.py
@@ -42,6 +42,7 @@ class ModelType(str, Enum):
FLUX_DEV = "flux-dev"
FLUX_SCHNELL = "flux-schnell"
LTX_VIDEO_DEV = "ltx-video-dev"
LTX2 = "ltx-2"
WAN22_T2V_14b = "wan2.2-t2v-14b"
WAN22_T2V_5b = "wan2.2-t2v-5b"

@@ -64,6 +65,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.SD3_MEDIUM: filter_func_default,
ModelType.SD35_MEDIUM: filter_func_default,
ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
ModelType.LTX2: filter_func_ltx_video,
ModelType.WAN22_T2V_14b: filter_func_wan_video,
ModelType.WAN22_T2V_5b: filter_func_wan_video,
}
@@ -80,18 +82,20 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
ModelType.LTX2: "Lightricks/LTX-2",
ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
}

MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline]] = {
MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = {
ModelType.SDXL_BASE: DiffusionPipeline,
ModelType.SDXL_TURBO: DiffusionPipeline,
ModelType.SD3_MEDIUM: StableDiffusion3Pipeline,
ModelType.SD35_MEDIUM: StableDiffusion3Pipeline,
ModelType.FLUX_DEV: FluxPipeline,
ModelType.FLUX_SCHNELL: FluxPipeline,
ModelType.LTX_VIDEO_DEV: LTXConditionPipeline,
ModelType.LTX2: None,
ModelType.WAN22_T2V_14b: WanPipeline,
ModelType.WAN22_T2V_5b: WanPipeline,
}
@@ -154,6 +158,18 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
},
},
ModelType.LTX2: {
"backbone": "transformer",
"dataset": _SD_PROMPTS_DATASET,
"inference_extra_args": {
"height": 1024,
"width": 1536,
"num_frames": 121,
"frame_rate": 24.0,
"cfg_guidance_scale": 4.0,
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
},
},
ModelType.WAN22_T2V_14b: {
**_WAN_BASE_CONFIG,
"from_pretrained_extra_args": {
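
For context, a minimal sketch of how these registry entries are consumed now that MODEL_PIPELINE can hold a None sentinel for LTX-2 (the import path assumes this example directory; the helper name is hypothetical):

from models_utils import MODEL_PIPELINE, ModelType

def resolve_pipeline_cls(model_type: ModelType):
    # None marks models (LTX-2) that are built by their own loader,
    # not via a diffusers pipeline class.
    pipeline_cls = MODEL_PIPELINE[model_type]
    if pipeline_cls is None:
        raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.")
    return pipeline_cls

# resolve_pipeline_cls(ModelType.FLUX_DEV) -> FluxPipeline
# resolve_pipeline_cls(ModelType.LTX2)     -> raises ValueError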
168 changes: 160 additions & 8 deletions examples/diffusers/quantization/quantize.py
@@ -164,6 +164,7 @@ class ModelConfig:
override_model_path: Path | None = None
cpu_offloading: bool = False
ltx_skip_upsampler: bool = False # Skip upsampler for LTX-Video (faster calibration)
extra_params: dict[str, Any] = field(default_factory=dict)

@property
def model_path(self) -> str:
@@ -232,6 +233,51 @@ def setup_logging(verbose: bool = False) -> logging.Logger:
return logger


def _coerce_extra_param_value(value: str) -> Any:
lowered = value.lower()
if lowered in {"true", "false"}:
return lowered == "true"
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
return value


def parse_extra_params(
kv_args: list[str], unknown_args: list[str], logger: logging.Logger
) -> dict[str, Any]:
extra_params: dict[str, Any] = {}
for item in kv_args:
if "=" not in item:
raise ValueError(f"Invalid --extra-param value: '{item}'. Expected KEY=VALUE.")
key, value = item.split("=", 1)
extra_params[key] = _coerce_extra_param_value(value)

i = 0
while i < len(unknown_args):
token = unknown_args[i]
if token.startswith("--extra_param."):
key = token[len("--extra_param.") :]
value = "true"
if i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith("--"):
value = unknown_args[i + 1]
i += 1
extra_params[key] = _coerce_extra_param_value(value)
elif token.startswith("--extra_param"):
raise ValueError(
"Use --extra_param.KEY VALUE or --extra-param KEY=VALUE for extra parameters."
)
else:
logger.warning("Ignoring unknown argument: %s", token)
i += 1

return extra_params
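
A quick sanity check of the two accepted syntaxes (a minimal sketch; the keys and values below are illustrative, and _coerce_extra_param_value tries bool, then int, then float, then falls back to str):

import logging

extra = parse_extra_params(
    kv_args=["height=704", "fp8transformer=true"],
    unknown_args=["--extra_param.frame_rate", "24.0", "--extra_param.cpu_offload"],
    logger=logging.getLogger("quantize"),
)
# {'height': 704, 'fp8transformer': True, 'frame_rate': 24.0, 'cpu_offload': True}
# A bare dotted flag with no following value defaults to True.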


class PipelineManager:
"""Manages diffusion pipeline creation and configuration."""

@@ -245,8 +291,9 @@ def __init__(self, config: ModelConfig, logger: logging.Logger):
"""
self.config = config
self.logger = logger
self.pipe: DiffusionPipeline | None = None
self.pipe: Any | None = None
self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling
self._transformer: torch.nn.Module | None = None

@staticmethod
def create_pipeline_from(
@@ -264,10 +311,13 @@ def create_pipeline_from(
ValueError: If model type is unsupported
"""
try:
pipeline_cls = MODEL_PIPELINE[model_type]
if pipeline_cls is None:
raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.")
model_id = (
MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
)
pipe = MODEL_PIPELINE[model_type].from_pretrained(
pipe = pipeline_cls.from_pretrained(
model_id,
torch_dtype=torch_dtype,
use_safetensors=True,
Expand All @@ -278,7 +328,7 @@ def create_pipeline_from(
except Exception as e:
raise e

def create_pipeline(self) -> DiffusionPipeline:
def create_pipeline(self) -> Any:
"""
Create and return an appropriate pipeline based on configuration.

@@ -293,6 +343,14 @@ def create_pipeline(self) -> DiffusionPipeline:
self.logger.info(f"Data type: {self.config.model_dtype}")

try:
if self.config.model_type == ModelType.LTX2:
from modelopt.torch.quantization.plugins.diffusion import ltx2 as ltx2_plugin

ltx2_plugin.register_ltx2_quant_linear()
self.pipe = self._create_ltx2_pipeline()
self.logger.info("LTX-2 pipeline created successfully")
return self.pipe

self.pipe = MODEL_PIPELINE[self.config.model_type].from_pretrained(
self.config.model_path,
torch_dtype=self.config.model_dtype,
@@ -325,6 +383,10 @@ def setup_device(self) -> None:
if not self.pipe:
raise RuntimeError("Pipeline not created. Call create_pipeline() first.")

if self.config.model_type == ModelType.LTX2:
self.logger.info("Skipping device setup for LTX-2 pipeline (handled internally)")
return

if self.config.cpu_offloading:
self.logger.info("Enabling CPU offloading for memory efficiency")
self.pipe.enable_model_cpu_offload()
@@ -352,8 +414,58 @@ def get_backbone(self) -> torch.nn.Module:
if not self.pipe:
raise RuntimeError("Pipeline not created. Call create_pipeline() first.")

if self.config.model_type == ModelType.LTX2:
self._ensure_ltx2_transformer_cached()
return self._transformer
return getattr(self.pipe, self.config.backbone)

def _ensure_ltx2_transformer_cached(self) -> None:
if not self.pipe:
raise RuntimeError("Pipeline not created. Call create_pipeline() first.")
if self._transformer is None:
transformer = self.pipe.stage_1_model_ledger.transformer()
self.pipe.stage_1_model_ledger.transformer = lambda: transformer
self._transformer = transformer
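
Note the memoization trick here: stage_1_model_ledger.transformer() appears to build the stage-1 transformer lazily, so the first call materializes it and the ledger's factory is then swapped for a closure returning that same instance. That way the module handed to mtq.quantize is the exact object the pipeline runs during calibration, rather than a fresh unquantized copy loaded just before inference.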

def _create_ltx2_pipeline(self) -> Any:
params = dict(self.config.extra_params)
checkpoint_path = params.pop("checkpoint_path", None)
distilled_lora_path = params.pop("distilled_lora_path", None)
distilled_lora_strength = params.pop("distilled_lora_strength", 0.8)
spatial_upsampler_path = params.pop("spatial_upsampler_path", None)
gemma_root = params.pop("gemma_root", None)
fp8transformer = params.pop("fp8transformer", False)

if not checkpoint_path:
raise ValueError("Missing required extra_param: checkpoint_path.")
if not distilled_lora_path:
raise ValueError("Missing required extra_param: distilled_lora_path.")
if not spatial_upsampler_path:
raise ValueError("Missing required extra_param: spatial_upsampler_path.")
if not gemma_root:
raise ValueError("Missing required extra_param: gemma_root.")

from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps
from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline

distilled_lora = [
LoraPathStrengthAndSDOps(
str(distilled_lora_path),
float(distilled_lora_strength),
LTXV_LORA_COMFY_RENAMING_MAP,
)
]
pipeline_kwargs = {
"checkpoint_path": str(checkpoint_path),
"distilled_lora": distilled_lora,
"spatial_upsampler_path": str(spatial_upsampler_path),
"gemma_root": str(gemma_root),
"loras": [],
"fp8transformer": bool(fp8transformer),
}
pipeline_kwargs.update(params)
return TI2VidTwoStagesPipeline(**pipeline_kwargs)
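
Putting it together, a hypothetical LTX-2 invocation routes the loader inputs through the new extra-param mechanism, e.g. python quantize.py --model ltx-2 --extra-param checkpoint_path=... --extra-param distilled_lora_path=... --extra-param spatial_upsampler_path=... --extra-param gemma_root=... (paths elided; remaining flags as usual). Any optional key left out falls back to the defaults popped above, such as distilled_lora_strength=0.8, and unrecognized keys pass straight through to TI2VidTwoStagesPipeline via pipeline_kwargs.update(params).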


class Calibrator:
"""Handles model calibration for quantization."""
@@ -417,7 +529,9 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None:
if i >= self.config.num_batches:
break

if self.model_type == ModelType.LTX_VIDEO_DEV:
if self.model_type == ModelType.LTX2:
self._run_ltx2_calibration(prompt_batch, extra_args)
elif self.model_type == ModelType.LTX_VIDEO_DEV:
# Special handling for LTX-Video
self._run_ltx_video_calibration(prompt_batch, extra_args)
elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]:
@@ -448,6 +562,29 @@ def _run_wan_video_calibration(

self.pipe(prompt=prompt_batch, **kwargs).frames # type: ignore[misc]

def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, Any]) -> None:
from ltx_core.model.video_vae import TilingConfig

prompt = prompt_batch[0]
extra_params = self.pipeline_manager.config.extra_params
kwargs = {
"negative_prompt": extra_args.get(
"negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted"
),
"seed": extra_params.get("seed", 0),
"height": extra_params.get("height", extra_args.get("height", 1024)),
"width": extra_params.get("width", extra_args.get("width", 1536)),
"num_frames": extra_params.get("num_frames", extra_args.get("num_frames", 121)),
"frame_rate": extra_params.get("frame_rate", extra_args.get("frame_rate", 24.0)),
"num_inference_steps": self.config.n_steps,
"cfg_guidance_scale": extra_params.get(
"cfg_guidance_scale", extra_args.get("cfg_guidance_scale", 4.0)
),
"images": extra_params.get("images", []),
"tiling_config": extra_params.get("tiling_config", TilingConfig.default()),
}
self.pipe(prompt=prompt, **kwargs) # type: ignore[misc]
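
Note the precedence encoded in these kwargs: a value supplied through --extra-param wins over the per-model default from MODEL_CONFIG's inference_extra_args, which in turn falls back to a hard-coded constant, so for example --extra_param.num_frames 33 shortens calibration clips without editing models_utils.py.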

def _run_ltx_video_calibration(
self, prompt_batch: list[str], extra_args: dict[str, Any]
) -> None:
@@ -568,7 +705,7 @@ def quantize_model(
backbone: torch.nn.Module,
quant_config: Any,
forward_loop: callable, # type: ignore[valid-type]
) -> None:
) -> torch.nn.Module:
"""
Apply quantization to the model.

@@ -590,6 +727,7 @@ def quantize_model(
mtq.disable_quantizer(backbone, model_filter_func)

self.logger.info("Quantization completed successfully")
return backbone
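
Returning the module makes the call site robust if mtq.quantize ever swaps rather than mutates the backbone in place; the intended pattern at the caller would then be (a sketch, assuming the signature above):

backbone = quantizer.quantize_model(backbone, backbone_quant_config, forward_loop)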


class ExportManager:
@@ -691,7 +829,8 @@ def restore_checkpoint(self, backbone: nn.Module) -> None:
mto.restore(backbone, str(self.config.restore_from))
self.logger.info("Model restored successfully")

def export_hf_ckpt(self, pipe: DiffusionPipeline) -> None:
# TODO: avoid typing pipe as Any; use a proper pipeline type/protocol
def export_hf_ckpt(self, pipe: Any) -> None:
"""
Export quantized model to HuggingFace checkpoint format.

@@ -754,7 +893,7 @@ def create_argument_parser() -> argparse.ArgumentParser:
model_group.add_argument(
"--model-dtype",
type=str,
default="Half",
default="BFloat16",
choices=[d.value for d in DataType],
help="Precision for loading the pipeline. If you want different dtypes for separate components, "
"please specify using --component-dtype",
@@ -778,6 +917,16 @@ def create_argument_parser() -> argparse.ArgumentParser:
action="store_true",
help="Skip upsampler pipeline for LTX-Video (faster calibration, only quantizes main transformer)",
)
model_group.add_argument(
"--extra-param",
action="append",
default=[],
metavar="KEY=VALUE",
help=(
"Extra model-specific parameters in KEY=VALUE form. Can be provided multiple times. "
"These override model-specific CLI arguments when present."
),
)
quant_group = parser.add_argument_group("Quantization Configuration")
quant_group.add_argument(
"--format",
@@ -859,7 +1008,7 @@ def create_argument_parser() -> argparse.ArgumentParser:

def main() -> None:
parser = create_argument_parser()
args = parser.parse_args()
args, unknown_args = parser.parse_known_args()

model_type = ModelType(args.model)
if args.backbone is None:
@@ -875,6 +1024,7 @@ def main() -> None:
logger.info("Starting Enhanced Diffusion Model Quantization")

try:
extra_params = parse_extra_params(args.extra_param, unknown_args, logger)
model_config = ModelConfig(
model_type=model_type,
model_dtype=model_dtype,
Expand All @@ -885,6 +1035,7 @@ def main() -> None:
else None,
cpu_offloading=args.cpu_offloading,
ltx_skip_upsampler=args.ltx_skip_upsampler,
extra_params=extra_params,
)

quant_config = QuantizationConfig(
@@ -950,6 +1101,7 @@ def main() -> None:
quantizer = Quantizer(quant_config, model_config, logger)
backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone)

# The pipe loads the checkpoint just before inference.
def forward_loop(mod):
calibrator.run_calibration(batched_prompts)

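
For orientation, a condensed sketch of main()'s control flow after these changes (names are taken from this diff; Calibrator construction is elided since its signature is not shown here):

manager = PipelineManager(model_config, logger)
pipe = manager.create_pipeline()   # registers the LTX-2 quant linear plugin when needed
manager.setup_device()             # no-op for LTX-2; devices are handled internally
backbone = manager.get_backbone()  # cached stage-1 transformer for LTX-2
quantizer = Quantizer(quant_config, model_config, logger)
backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone)
backbone = quantizer.quantize_model(
    backbone,
    backbone_quant_config,
    lambda mod: calibrator.run_calibration(batched_prompts),
)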
6 changes: 4 additions & 2 deletions examples/diffusers/quantization/utils.py
@@ -25,7 +25,7 @@
from diffusers.utils import load_image

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.plugins.diffusers import AttentionModuleMixin
from modelopt.torch.quantization.plugins.diffusion.diffusers import AttentionModuleMixin

USE_PEFT = True
try:
@@ -69,7 +69,9 @@ def check_conv_and_mha(backbone, if_fp4, quantize_mha):

def filter_func_ltx_video(name: str) -> bool:
"""Filter function specifically for LTX-Video models."""
pattern = re.compile(r".*(proj_in|time_embed|caption_projection|proj_out).*")
pattern = re.compile(
r".*(proj_in|time_embed|caption_projection|proj_out|patchify_proj|adaln_single).*"
)
return pattern.match(name) is not None
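
A quick check of the widened filter (module names are illustrative); names it matches are handed to mtq.disable_quantizer, i.e. kept in high precision:

assert filter_func_ltx_video("patchify_proj") is True        # newly excluded
assert filter_func_ltx_video("adaln_single.linear") is True  # newly excluded
assert filter_func_ltx_video("blocks.0.attn.to_q") is False  # still quantized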

