
Conversation

@sayakpaul (Member) opened this pull request.

@sayakpaul requested a review from DN6 on January 28, 2026 at 09:21
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

```diff
 # Upcast the QR orthogonalization operation to FP32
 original_motion_dtype = motion_feat.dtype
-motion_feat = motion_feat.to(weight.dtype)
+motion_feat = motion_feat.to(torch.float32)
```
@sayakpaul (Member Author)

Probably because of #12691. Cc: @dg845
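
For illustration only, here is a minimal sketch of the upcast-then-restore pattern the hunk above applies. The `orthogonalize` helper, its arguments, and the projection step are hypothetical stand-ins rather than the actual Wan Animate module; the point is simply that `torch.linalg.qr` generally isn't supported in half precision (e.g. bfloat16), so the QR step runs in FP32 and the result is cast back to the original activation dtype afterwards:

```python
import torch

def orthogonalize(motion_feat: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper; variable names mirror the hunk above, not the real module.
    original_motion_dtype = motion_feat.dtype
    # Upcast: QR decomposition is typically unavailable for half/bfloat16 tensors.
    motion_feat = motion_feat.to(torch.float32)
    q, _ = torch.linalg.qr(weight.to(torch.float32))  # orthonormal basis of `weight`'s columns
    motion_feat = motion_feat @ q                      # project features onto that basis
    # Restore the dtype the surrounding (possibly bf16) pipeline expects.
    return motion_feat.to(original_motion_dtype)

feat = torch.randn(2, 16, dtype=torch.bfloat16)
w = torch.randn(16, 8, dtype=torch.bfloat16)
print(orthogonalize(feat, w).dtype)  # torch.bfloat16
```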

@sayakpaul requested a review from dg845 on January 29, 2026 at 02:58
@dg845 (Collaborator) left a comment:

Thanks! As far as I can tell, Wan Animate single file / GGUF support doesn't depend on this change (the generated samples look normal), so the change should be fine.

Wan Animate Single File Test Script

```python
import numpy as np
import torch

from diffusers import AutoencoderKLWan, GGUFQuantizationConfig
from diffusers import WanAnimatePipeline, WanAnimateTransformer3DModel
from diffusers.utils import export_to_video, load_image, load_video


LoRA = True
device_gpu = torch.device("cuda:0")

original_model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
single_file_url = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q8_0.gguf"

lora_model_id = "Kijai/WanVideo_comfy"
lora_model_path = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank64_bf16.safetensors"

print("Loading transformer ....")
transformer = WanAnimateTransformer3DModel.from_single_file(
    single_file_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config=original_model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    offload_device="cpu",
    device=device_gpu
)
print("Transformer loaded successfully ....")

print("Loading pipeline ....")
pipe = WanAnimatePipeline.from_pretrained(
    original_model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

if LoRA:
    pipe.load_lora_weights(
        lora_model_id,
        weight_name=lora_model_path,
        adapter_name="lightning",
        offload_device="cpu",
        device=device_gpu
    )

pipe.enable_model_cpu_offload()
print("Pipeline loaded successfully ....")

# Load the character image
image = load_image(
    "Wan2.2/examples/wan_animate/animate/image.jpeg",
)

# Load pose and face videos (preprocessed from reference video)
# Note: Videos should be preprocessed to extract pose keypoints and face features
# Refer to the Wan-Animate preprocessing documentation for details
pose_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_pose.mp4")
face_video = load_video("Wan2.2/examples/wan_animate/animate/process_results/src_face.mp4")

# Calculate optimal dimensions based on VAE constraints
max_area = 1280 * 720
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = "People in the video are doing actions."

# Animation mode: Animate the character with the motion from pose/face videos
print("Generating animation ....")
if LoRA:
    output = pipe(
        image=image,
        pose_video=pose_video,
        face_video=face_video,
        prompt=prompt,
        #  negative_prompt=negative_prompt,
        height=height,
        width=width,
        segment_frame_length=77,
        guidance_scale=1.0,
        prev_segment_conditioning_frames=1,  # refert_num in original code
        num_inference_steps=4,
        mode="animate",
    ).frames[0]
else:
    output = pipe(
        image=image,
        pose_video=pose_video,
        face_video=face_video,
        prompt=prompt,
        #  negative_prompt=negative_prompt,
        height=height,
        width=width,
        segment_frame_length=77,
        guidance_scale=1.0,
        prev_segment_conditioning_frames=1,  # refert_num in original code
        num_inference_steps=20,
        mode="animate",
    ).frames[0]

print("Exporting animation ....")
export_to_video(output, "wan_animate_gguf_lora.mp4", fps=30)
```

Sample with GGUF + LoRA:

wan_animate_gguf_lora.mp4 (video attachment)

@sayakpaul (Member Author)

Failing tests are unrelated.

@sayakpaul merged commit e7de7d8 into main on Jan 29, 2026 (9 of 12 checks passed)
@sayakpaul deleted the wan-layerwise-upcasting-tests branch on January 29, 2026 at 09:07