13 changes: 9 additions & 4 deletions src/diffusers/models/transformers/transformer_ltx2.py
@@ -894,9 +894,14 @@ def prepare_video_coords(
grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches

# 2. Get the patch boundaries with respect to the latent video grid
-patch_size = (self.patch_size_t, self.patch_size, self.patch_size)
-patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device)
-patch_ends = grid + patch_size_delta.view(3, 1, 1, 1)
+patch_size_delta = torch.stack(
+    [
+        grid.new_ones(1) * self.patch_size_t,
+        grid.new_ones(1) * self.patch_size,
+        grid.new_ones(1) * self.patch_size,
+    ]
+).reshape(3, 1, 1, 1)
+patch_ends = grid + patch_size_delta
Comment on lines -897 to +904
Member: This refactor seems unnecessary.

Author: This replaces host-side tensor construction with device-native ops to eliminate an implicit cudaStreamSynchronize (~60ms).

Collaborator: This refactor does eliminate a CPU-to-GPU sync (the patch_size tuple lives on the CPU host, so torch.tensor has to copy it to the GPU, which the refactor avoids by building the tensor on the GPU). However, because the tuple is tiny, the corresponding cudaStreamSynchronize block (and that of the similar scale_tensor refactor below) only takes about 2ms. I think most of the removed non-scheduler sync time comes from the connectors.py refactor.
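For readers skimming the thread, here is a minimal standalone sketch of the pattern being discussed (not the diffusers code itself; the shapes and values are made up). Building a small constant tensor from a Python tuple with torch.tensor(..., device=...) stages the data on the host and copies it to the GPU, which can stall the host, whereas constructing it from tensors that already live on the GPU only enqueues kernels on the stream.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
grid = torch.zeros(3, 4, 8, 8, device=device)  # hypothetical latent grid
patch_size = (1, 2, 2)                          # hypothetical (t, h, w) patch size

# Host-side construction: the tuple lives in pageable CPU memory, so PyTorch has to
# copy it to the device, and the host may wait on that transfer.
delta_host = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device).view(3, 1, 1, 1)

# Device-native construction: built entirely from tensors already on the device, so
# the work is just queued on the current stream and no host-side copy is needed.
delta_device = torch.stack([grid.new_ones(1) * s for s in patch_size]).reshape(3, 1, 1, 1)

patch_ends = grid + delta_device  # numerically the same result either way
```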


# Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension
latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2]
@@ -905,7 +910,7 @@ def prepare_video_coords
latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1)

# 3. Calculate the pixel space patch boundaries from the latent boundaries.
-scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device)
+scale_tensor = torch.stack([latent_coords.new_ones(1) * factor for factor in self.scale_factors])
Comment on lines -908 to +913
Member: Same.

Author: Also avoids the implicit cudaStreamSynchronize (~60ms) by replacing torch.tensor(...) with device-native tensor construction.

# Broadcast the VAE scale factors such that they are compatible with latent_coords's shape
broadcast_shape = [1] * latent_coords.ndim
broadcast_shape[1] = -1 # This is the (frame, height, width) dim
Expand Down
20 changes: 9 additions & 11 deletions src/diffusers/pipelines/ltx2/connectors.py
@@ -2,7 +2,6 @@

import torch
import torch.nn as nn
-import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
@@ -295,22 +294,21 @@ def forward(
)

num_register_repeats = seq_len // self.num_learnable_registers
-registers = torch.tile(self.learnable_registers, (num_register_repeats, 1))  # [seq_len, inner_dim]
+registers = (
+    self.learnable_registers.unsqueeze(0).expand(num_register_repeats, -1, -1).reshape(seq_len, -1)
+)  # [seq_len, inner_dim]

binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
if binary_attn_mask.ndim == 4:
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]

-hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
-valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
-pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
-padded_hidden_states = [
-    F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
-]
-padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0)  # [B, L, D]
+# Replace padding positions with learned registers using vectorized masking
+mask = binary_attn_mask.unsqueeze(-1)  # [B, L, 1]
+registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, D]
+hidden_states = mask * hidden_states + (1 - mask) * registers_expanded

-flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1)  # [B, L, 1]
-hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
+# Flip sequence: embeddings move to front, registers to back (from left padding layout)
+hidden_states = torch.flip(hidden_states, dims=[1])

# Overwrite attention_mask with an all-zeros mask if using registers.
attention_mask = torch.zeros_like(attention_mask)
5 changes: 5 additions & 0 deletions src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -1189,6 +1189,11 @@ def __call__(
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

+# Set the begin index to skip the nonzero().item() call in scheduler initialization, which triggers a GPU sync
+if hasattr(self.scheduler, "set_begin_index"):
+    self.scheduler.set_begin_index(0)
+    audio_scheduler.set_begin_index(0)
Comment on lines +1194 to +1195
Member: Move it out of the set_begin_index hasattr check.
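One way to read this suggestion is sketched below (not necessarily the exact change the PR will adopt): guard each scheduler's set_begin_index separately, so the audio scheduler is not gated on an attribute of the video scheduler. The point of calling set_begin_index(0) at all is that the scheduler can then initialize its step index directly instead of locating the start timestep via a comparison plus nonzero().item(), which would force a GPU sync.

```python
# Sketch of the suggested restructuring; assumes either scheduler may or may not
# expose set_begin_index.
if hasattr(self.scheduler, "set_begin_index"):
    self.scheduler.set_begin_index(0)
if hasattr(audio_scheduler, "set_begin_index"):
    audio_scheduler.set_begin_index(0)
```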


# 6. Prepare micro-conditions
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
video_coords = self.transformer.rope.prepare_video_coords(
@@ -363,11 +363,13 @@ def set_timesteps(
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)

# 5. Convert sigmas and timesteps to tensors and move to specified device
-sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+sigmas = torch.from_numpy(sigmas).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True)
Member: This is a hard no. We cannot be pinning memory and running it async within the scheduler.
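For context, a minimal sketch of what the changed line does in isolation (array contents are made up): staging the numpy data in page-locked (pinned) host memory lets the host-to-device copy be enqueued asynchronously on the current stream instead of blocking the host, which is what removes the sync, but it also pins host pages on every set_timesteps call and pushes an async-transfer pattern into library code, which is the concern raised above.

```python
import numpy as np
import torch

sigmas_np = np.linspace(1.0, 1e-3, 30, dtype=np.float32)  # placeholder schedule

# Plain version: a blocking host-to-device copy (the sync under discussion).
device = "cuda" if torch.cuda.is_available() else "cpu"
sigmas_sync = torch.from_numpy(sigmas_np).to(dtype=torch.float32, device=device)

if torch.cuda.is_available():
    # Pinned + non_blocking version: the copy is queued asynchronously, so the host
    # does not wait, but pinning memory on every call has its own cost.
    sigmas_async = (
        torch.from_numpy(sigmas_np)
        .pin_memory()
        .to(dtype=torch.float32, device="cuda", non_blocking=True)
    )
```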

Author (@ViktoriiaRomanova, May 4, 2026): @sayakpaul Should I provide an alternative implementation to avoid the cudaStreamSynchronize, or is this sync considered acceptable in this case?

Collaborator: My current thinking is that an alternative implementation that avoids the cudaStreamSynchronize would be good, but accepting the sync would be preferable to the pin_memory implementation. (As far as I can tell, this particular sync point overlaps with GPU kernels, so it's probably not in the critical path.)
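One way to check whether this particular sync actually sits on the critical path is to capture a short profile and see where cudaStreamSynchronize falls relative to the GPU kernels in the timeline. A sketch (here `pipe` and its arguments are placeholders for an actual pipeline call):

```python
import torch
from torch.profiler import ProfilerActivity, profile

# `pipe` is a placeholder for a loaded pipeline; a few steps is enough to see whether
# the scheduler's host-to-device copy overlaps with other GPU work.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    _ = pipe(prompt="a placeholder prompt", num_inference_steps=4)

prof.export_chrome_trace("trace.json")  # inspect cudaStreamSynchronize placement in the timeline
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))
```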

if not is_timesteps_provided:
    timesteps = sigmas * self.config.num_train_timesteps
else:
-    timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
+    timesteps = (
+        torch.from_numpy(timesteps).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True)
+    )

# 6. Append the terminal sigma value.
# If a model requires inverted sigma schedule for denoising but timesteps without inversion, the