
add ltx2 vae in sana-video; #13229

Open
lawrence-cj wants to merge 4 commits into huggingface:main from lawrence-cj:sana-video-ltx2vae

Conversation

@lawrence-cj (Contributor) commented Mar 9, 2026

This PR adds LTX2 VAE support to SANA-Video.

Cc: @dg845 @sayakpaul

GPU memory needed: 47 GB for the LTX refiner.

SANA-Video with LTX2-Refiner:

"""Sana Video + LTX2 Refiner: Stage 1 generate latent → Stage 2 refine (3 steps)."""

import gc
import torch
from diffusers import SanaVideoPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda"
dtype = torch.bfloat16
prompt = "A cat walking on the grass, facing the camera."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_score = 30
height, width, frames, frame_rate = 704, 1280, 81, 16.0
seed = 42

# ── Load all models ──
sana_pipe = SanaVideoPipeline.from_pretrained(
    "Sana_video/safetensors/sana_ltxvae_sft", torch_dtype=dtype,
)
sana_pipe.text_encoder.to(dtype)
sana_pipe.enable_model_cpu_offload()

ltx_pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=dtype)
ltx_pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled",
    weight_name="ltx-2-19b-distilled-lora-384.safetensors",
)
ltx_pipe.set_adapters("stage_2_distilled", 1.0)
ltx_pipe.vae.enable_tiling()
ltx_pipe.enable_model_cpu_offload()

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=dtype,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=ltx_pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)

# ── Stage 1: Sana Video ──
video_latent = sana_pipe(
    prompt=prompt + f" motion score: {motion_score}.", negative_prompt=negative_prompt,
    height=height, width=width, frames=frames,
    guidance_scale=6.0, num_inference_steps=50,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="latent", return_dict=True,
).frames

del sana_pipe; gc.collect(); torch.cuda.empty_cache()

# ── Stage 1.5: Latent Upsample (2x spatial) ──
video_latent = upsample_pipe(
    latents=video_latent.to(device=device, dtype=dtype),
    latents_normalized=True,
    height=height, width=width, num_frames=frames,
    output_type="latent", return_dict=False,
)[0]
# Normalize the upsampled latents into the scaled latent space the LTX2 transformer expects
latents_mean = ltx_pipe.vae.latents_mean.view(1, -1, 1, 1, 1).to(video_latent.device, video_latent.dtype)
latents_std = ltx_pipe.vae.latents_std.view(1, -1, 1, 1, 1).to(video_latent.device, video_latent.dtype)
video_latent = (video_latent - latents_mean) * ltx_pipe.vae.config.scaling_factor / latents_std

# ── Stage 2: LTX2 Refine ──
# Pack the video latents into the patchified sequence layout the LTX2 transformer expects,
# then recover the pixel-space height, width, and frame count from the latent shape.
packed = LTX2Pipeline._pack_latents(
    video_latent.to(device=device, dtype=dtype),
    patch_size=ltx_pipe.transformer_spatial_patch_size,
    patch_size_t=ltx_pipe.transformer_temporal_patch_size,
)
_, _, lF, lH, lW = video_latent.shape
pH = lH * ltx_pipe.vae_spatial_compression_ratio
pW = lW * ltx_pipe.vae_spatial_compression_ratio
pT = (lF - 1) * ltx_pipe.vae_temporal_compression_ratio + 1

# Placeholder audio latents (the audio VAE latent mean) shaped to match the clip duration;
# the refined audio output is discarded below.
dur = pT / frame_rate
audio_frames = round(dur * ltx_pipe.audio_sampling_rate / ltx_pipe.audio_hop_length / ltx_pipe.audio_vae_temporal_compression_ratio)
nch = ltx_pipe.audio_vae.config.latent_channels
mel = ltx_pipe.audio_vae.config.mel_bins // ltx_pipe.audio_vae_mel_compression_ratio
audio_latent = (
    ltx_pipe.audio_vae.latents_mean.unsqueeze(0).unsqueeze(0)
    .expand(1, audio_frames, nch * mel).to(dtype=dtype, device=device).contiguous()
    .unflatten(2, (nch, mel)).permute(0, 2, 1, 3).contiguous()
)

del video_latent; gc.collect(); torch.cuda.empty_cache()

video, _ = ltx_pipe(
    latents=packed, audio_latents=audio_latent,
    prompt=prompt, negative_prompt=negative_prompt,
    height=pH, width=pW, num_frames=pT,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0, frame_rate=frame_rate,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="np", return_dict=False,
)

video = torch.from_numpy((video * 255).round().astype("uint8"))
encode_video(video[0], fps=frame_rate, audio=None, audio_sample_rate=None, output_path="sana_ltx2_refined.mp4")

Result

sana_ltx_refined.mp4

@sayakpaul (Member) commented:

@lawrence-cj thanks for the PR! Could you also provide some sample outputs?

Comment on lines +226 to +244
if getattr(self, "vae", None):
    if hasattr(self.vae.config, "scale_factor_temporal"):
        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
    elif hasattr(self.vae.config, "temporal_compression_ratio"):
        # LTX2 VAE uses temporal_compression_ratio
        self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
    else:
        self.vae_scale_factor_temporal = getattr(self.vae, "temporal_compression_ratio", 4)

    if hasattr(self.vae.config, "scale_factor_spatial"):
        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
    elif hasattr(self.vae.config, "spatial_compression_ratio"):
        # LTX2 VAE uses spatial_compression_ratio
        self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
    else:
        self.vae_scale_factor_spatial = getattr(self.vae, "spatial_compression_ratio", 8)
else:
    self.vae_scale_factor_temporal = 4
    self.vae_scale_factor_spatial = 8
Hmm, should this be conditioned on the class type of the VAE being used?
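For illustration only, a minimal sketch of what class-based dispatch could look like; the class name AutoencoderKLLTX2Video and its import are assumptions for the example, not something defined by this PR:

# Sketch: dispatch on the VAE class instead of duck-typing on config attributes.
from diffusers import AutoencoderKLLTX2Video  # assumed class name / import path

if getattr(self, "vae", None) is None:
    self.vae_scale_factor_temporal = 4
    self.vae_scale_factor_spatial = 8
elif isinstance(self.vae, AutoencoderKLLTX2Video):
    # The LTX2 VAE exposes temporal/spatial_compression_ratio in its config
    self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
    self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
else:
    # The default SANA-Video VAE exposes scale_factor_temporal / scale_factor_spatial
    self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
    self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial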

@sayakpaul (Member) left a comment:

Thanks, I just left one comment. But it looks good to me.

@sayakpaul requested a review from dg845 on March 10, 2026 at 03:19.
@lawrence-cj (Contributor, Author) commented Mar 10, 2026

Could you also provide some sample outputs?

Updated code and result.

@sayakpaul @dg845

@dg845 (Collaborator) left a comment:

Thanks for the PR! The code looks good to me. However, running the example script doesn't work for me because I don't have access to the Sana_video/safetensors/sana_ltxvae_sft checkpoint. Would it be possible to provide a checkpoint for testing?

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
