Skip to content

Eliminate GPU sync overhead and CPU→GPU transfers across LTX2 pipeline#13564

Open
ViktoriiaRomanova wants to merge 4 commits intohuggingface:mainfrom
ViktoriiaRomanova:ltx2pipelinespeedup
Open

Eliminate GPU sync overhead and CPU→GPU transfers across LTX2 pipeline#13564
ViktoriiaRomanova wants to merge 4 commits intohuggingface:mainfrom
ViktoriiaRomanova:ltx2pipelinespeedup

Conversation

@ViktoriiaRomanova
Copy link
Copy Markdown

Fixes performance issues identified by profiling LTX2 with torch.profiler as part of #13401.

Optimises LTX2 by removing unnecessary GPU synchronisation points and replacing CPU tensor creation with on-device tensor operations across the decoding pipeline, transformer RoPE computations, scheduler, and connector padding logic.

Pipeline Denoising Optimisation

  1. Added explicit set_begin_index(0) calls to both video and audio schedulers. This avoids the DtoH sync in _init_step_index. Uses the same pattern as the issue fixed in PR Avoid DtoH sync from access of nonzero() item in scheduler #11696.
    Before (eager mode):
image After (eager mode, no sync gap): image

Before (compile mode):
image
After (compile mode, no sync gap):
image

  1. Replaced torch.tensor(..., device=device) with on-device torch.stack([torch.ones(...)*s for s in decode_noise_scale]). Avoids CPU tensor allocation and GPU transfers for decode noise scaling.

Transformer Model Optimisation

Replaced CPU tensor creation for patch sizes with on-device tensor construction.
Eliminates unnecessary CPU-to-GPU memcpy operations during RoPE coordinate preparation.

Connector Refactoring

Replaced list-comprehension-based padding logic with vectorised masking. This simplifies left-padding layout logic and eliminates unnecessary cudaStreamSynchronize calls.

Performance Results

Metric Before After
cudaStreamSynchronize calls (total) 18 6
Scheduler sync (eager mode) 233ms eliminated
Scheduler sync (compiled mode) 573ms eliminated
Other syncs total (eager mode) 88ms 25ms
Other syncs total (compiled mode) 93ms 25ms

Profiler trace

https://drive.google.com/drive/folders/1cZn1xw-8Eon22mA2zP1uoF1nE4YCC3Wo?usp=drive_link

Before submitting

Who can review?

@sayakpaul @dg845

…or creation

across the LTX2 pipeline, transformer, scheduler, and connector logic.

- Add set_begin_index(0) to schedulers to eliminate DtoH sync in _init_step_index
- Replace torch.tensor(..., device=...) with on-device tensor construction for decode scaling
- Move RoPE-related tensor creation to GPU to avoid memcpy overhead
- Refactor connector padding logic using vectorized masking instead of list-based ops
Copy link
Copy Markdown
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. Please provide comments inline to the changes explaining how they eliminate the syncs.

Comment on lines -897 to +904
patch_size = (self.patch_size_t, self.patch_size, self.patch_size)
patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device)
patch_ends = grid + patch_size_delta.view(3, 1, 1, 1)
patch_size_delta = torch.stack(
[
grid.new_ones(1) * self.patch_size_t,
grid.new_ones(1) * self.patch_size,
grid.new_ones(1) * self.patch_size,
]
).reshape(3, 1, 1, 1)
patch_ends = grid + patch_size_delta
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This refactor seems unnecessary.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This replaces host-side tensor construction with device-native ops to eliminate an implicit cudaStreamSynchronize (~60ms).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This refactor eliminates a CPU --> GPU sync (since the patch_size tuple lives on the CPU host, torch.tensor needs to copy it to GPU, which the refactor avoids by doing the operations on the GPU), but because the tuple is really small, it looks like the corresponding cudaStreamSynchronize block (and that of the similar scale_tensor refactor below) takes about 2ms. I think most of the removed non-scheduler sync time is in the connectors.py refactor.

Comment on lines -908 to +913
scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device)
scale_tensor = torch.stack([latent_coords.new_ones(1) * factor for factor in self.scale_factors])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also avoids implicit cudaStreamSynchronize (~60ms) by replacing torch.tensor(...) with device-native tensor construction.


# 5. Convert sigmas and timesteps to tensors and move to specified device
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = torch.from_numpy(sigmas).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a hard no. We cannot be pinning memory and running it async within the scheduler.

Copy link
Copy Markdown
Author

@ViktoriiaRomanova ViktoriiaRomanova May 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sayakpaul Should I provide an alternative implementation to avoid the cudaStreamSynchronize, or is this sync considered acceptable in this case?

Comment on lines +1194 to +1195
self.scheduler.set_begin_index(0)
audio_scheduler.set_begin_index(0)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move it out of the set_begin_index hasattr check.

@github-actions github-actions Bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels May 1, 2026
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions github-actions Bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels May 5, 2026
@dg845
Copy link
Copy Markdown
Collaborator

dg845 commented May 6, 2026

@bot /style

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 6, 2026

Style bot fixed some files and pushed the changes.

@github-actions github-actions Bot added size/S PR with diff < 50 LOC and removed size/S PR with diff < 50 LOC labels May 6, 2026
Copy link
Copy Markdown
Collaborator

@dg845 dg845 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I agree with #13564 (review) that having comments for the changes would be helpful.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants