
Add LTX-2.X IC LoRA and HDR Pipelines#13572

Open
dg845 wants to merge 18 commits into main from ltx2-hdr-ic-lora-pipeline

Conversation


@dg845 dg845 commented Apr 28, 2026

What does this PR do?

This PR adds two new LTX-2.X pipelines: LTX2InContextPipeline, which supports in-context (IC) conditioning (used, for example, by some IC LoRAs), and LTX2HDRPipeline, which supports the newly released high-dynamic-range (HDR) pipeline (and HDR IC LoRA) introduced in the LTX LumiVid paper.

This PR also updates LTX2ConditionPipeline to follow the LTX-2 repo's current image conditioning strategy, which overwrites the noisy latents only for first-frame (I2V) conditions and treats non-first-frame conditions as keyframes, which are appended to the noisy latents. (Previously, the pipeline overwrote the noisy latents at all latent indices.)
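The distinction above can be sketched as follows. This is an illustrative toy model of the overwrite-vs-append behavior, not the actual diffusers implementation; latent frames are modeled as plain Python values purely for illustration:

```python
# Toy sketch of the updated image conditioning strategy: a first-frame
# (I2V) condition overwrites the noisy latent at index 0, while any
# non-first-frame condition is collected as a keyframe latent that would
# be appended to the noisy latents instead of overwriting in place.

def apply_image_conditions(noisy_latents, conditions):
    """noisy_latents: list of per-frame latent values.
    conditions: list of (frame_index, value) pairs."""
    latents = list(noisy_latents)
    keyframes = []
    for frame_index, value in conditions:
        if frame_index == 0:
            # First-frame (I2V) condition: overwrite the noisy latent.
            latents[0] = value
        else:
            # Non-first-frame condition: treat as a keyframe to append,
            # rather than overwriting at its latent index.
            keyframes.append((frame_index, value))
    return latents, keyframes
```

Under the previous behavior, both conditions would have overwritten `latents` at their respective indices.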

Here is an example of using LTX2InContextPipeline with the Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In IC LoRA, which biases the generated video toward a dolly-in camera move toward the subject:

LTX-2.3 IC LoRA Example Script
```python
import torch

from diffusers import LTX2InContextPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT


pipe = LTX2InContextPipeline.from_pretrained("dg845/LTX-2.3-Diffusers", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload(device="cuda:0")
pipe.load_lora_weights(
    "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In",
    adapter_name="ic_lora",
    weight_name="ltx-2-19b-lora-camera-control-dolly-in.safetensors",
)
pipe.set_adapters("ic_lora", 1.0)

# If the IC LoRA uses reference conditions, you can specify them as follows:
# reference_video = load_video("reference.mp4")
# ref_cond = LTX2ReferenceCondition(frames=reference_video, strength=1.0)

prompt = "A flowing river in a forest"
frame_rate = 24.0
video, audio = pipe(
    prompt=prompt,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    # reference_conditions=[ref_cond],
    width=768,
    height=512,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=30,
    guidance_scale=3.0,
    output_type="np",
    return_dict=False,
)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_ic_lora_output.mp4",
)
```
ltx2_ic_lora_output.mp4

And here is an example of using LTX2HDRPipeline with the Lightricks/LTX-2.3-22b-IC-LoRA-HDR HDR IC LoRA, using the video above as the reference:

LTX-2.3 HDR IC LoRA Example Script
```python
import torch
from safetensors import safe_open

from diffusers import LTX2HDRPipeline
from diffusers.pipelines.ltx2.export_utils import encode_hdr_tensor_to_mp4
from diffusers.pipelines.ltx2.pipeline_ltx2_hdr_lora import LTX2HDRReferenceCondition
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES
from diffusers.utils import load_video


pipe = LTX2HDRPipeline.from_pretrained("dg845/LTX-2.3-Distilled-Diffusers", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload(device="cuda")
pipe.load_lora_weights(
    "Lightricks/LTX-2.3-22b-IC-LoRA-HDR",
    adapter_name="hdr_lora",
    weight_name="ltx-2.3-22b-ic-lora-hdr-0.9.safetensors",
)
pipe.set_adapters("hdr_lora", 1.0)

reference_video = load_video("ltx2_ic_lora_output.mp4")
ref_cond = LTX2HDRReferenceCondition(frames=reference_video, strength=1.0)

# Load pre-computed HDR LoRA connector embeddings.
with safe_open(
    "/path/to/ltx-2.3-22b-ic-lora-hdr-scene-emb.safetensors", framework="pt", device="cuda"
) as f:
    connector_video_embeds = f.get_tensor("video_context")
    connector_audio_embeds = f.get_tensor("audio_context")

# `hdr_video` is a linear HDR tensor of shape (batch, frames, H, W, C).
hdr_video = pipe(
    reference_conditions=[ref_cond],
    connector_video_embeds=connector_video_embeds,
    connector_audio_embeds=connector_audio_embeds,
    width=768,
    height=512,
    num_frames=121,
    frame_rate=24.0,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    output_type="pt",
    return_dict=False,
)[0]

# Convert the HDR video to an SDR sRGB-tonemapped `.mp4` video.
# You can also save the output to EXR using `save_hdr_video_frames_as_exr`.
# A custom tone-mapper can be specified via the `tone_mapping_fn` argument.
encode_hdr_tensor_to_mp4(
    hdr_video[0],
    output_mp4="ltx2_hdr_lora_output.mp4",
    frame_rate=24.0,
)
```
ltx2_hdr_lora_output.mp4

This uses the default tone-mapper that simply clips HDR values to [0, 1], which is also used by the original LTX-2 code.
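If hard clipping is too aggressive for a scene with bright highlights, a custom tone-mapper can be plugged in via the `tone_mapping_fn` argument mentioned in the script above. Below is a minimal sketch using the global Reinhard operator x / (1 + x); note that the exact callable signature `encode_hdr_tensor_to_mp4` expects is an assumption here:

```python
import numpy as np

# A hypothetical custom tone-mapper: the global Reinhard operator,
# which compresses linear HDR values smoothly into [0, 1) instead of
# hard-clipping them to [0, 1].

def reinhard_tone_map(hdr: np.ndarray) -> np.ndarray:
    """Map linear HDR values in [0, inf) into [0, 1) via x / (1 + x)."""
    hdr = np.clip(hdr, 0.0, None)  # guard against negative values
    return hdr / (1.0 + hdr)
```

One could then pass `tone_mapping_fn=reinhard_tone_map` to `encode_hdr_tensor_to_mp4` (assuming it accepts a per-array callable), trading the default's hard highlight clipping for a smoother roll-off.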

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu
@sayakpaul
@linoytsaban

@dg845 dg845 requested review from sayakpaul and yiyixuxu April 28, 2026 05:20
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


# Adapted from ltx_pipelines.utils.media_io.save_exr_tensor
# https://github.com/Lightricks/LTX-2/blob/41d924371612b692c0fd1e4d9d94c3dfb3c02cb3/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L609
def save_exr_tensor(
Collaborator


I think with this PR, encode_hdr_tensor_to_mp4 is enough, no?
can we remove all the other functions here?

    num_frames: int = 121,
    frame_rate: float = 24.0,
-   num_inference_steps: int = 40,
+   num_inference_steps: int = 30,
Collaborator


did they change the default since the release? 🤯

Collaborator Author


I ended up changing the LTX2ConditionPipeline.__call__ defaults to the LTX-2.3 defaults (since the implementation now follows the newest behavior rather than the original LTX-2.0 behavior); num_inference_steps=30 is the default for LTX-2.3, while num_inference_steps=40 is the default for LTX-2.0.
