Eliminate GPU sync overhead and CPU→GPU transfers across LTX2 pipeline#13564
ViktoriiaRomanova wants to merge 4 commits into huggingface:main from
Conversation
…or creation across the LTX2 pipeline, transformer, scheduler, and connector logic.
- Add `set_begin_index(0)` to schedulers to eliminate the DtoH sync in `_init_step_index`
- Replace `torch.tensor(..., device=...)` with on-device tensor construction for decode scaling
- Move RoPE-related tensor creation to the GPU to avoid memcpy overhead
- Refactor connector padding logic using vectorised masking instead of list-based ops
sayakpaul
left a comment
Thanks for the PR. Please add inline comments on the changes explaining how they eliminate the syncs.
```diff
-patch_size = (self.patch_size_t, self.patch_size, self.patch_size)
-patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device)
-patch_ends = grid + patch_size_delta.view(3, 1, 1, 1)
+patch_size_delta = torch.stack(
+    [
+        grid.new_ones(1) * self.patch_size_t,
+        grid.new_ones(1) * self.patch_size,
+        grid.new_ones(1) * self.patch_size,
+    ]
+).reshape(3, 1, 1, 1)
+patch_ends = grid + patch_size_delta
```
This refactor seems unnecessary.
This replaces host-side tensor construction with device-native ops to eliminate an implicit `cudaStreamSynchronize` (~60 ms).
This refactor eliminates a CPU→GPU sync: since the `patch_size` tuple lives on the CPU host, `torch.tensor(...)` must copy it to the GPU, which the refactor avoids by doing the construction on-device. Because the tuple is very small, though, the corresponding `cudaStreamSynchronize` block (and that of the similar `scale_tensor` refactor below) only takes about 2 ms. I think most of the removed non-scheduler sync time is in the connectors.py refactor.
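To make the pattern concrete, here is a minimal, device-agnostic sketch (runnable on CPU; the names `patch_size` and `grid` stand in for the pipeline's tensors) contrasting host-side tuple construction with on-device construction:

```python
import torch

patch_size = (1, 2, 2)  # hypothetical (patch_size_t, patch_size, patch_size)
grid = torch.zeros(3, 2, 2, 2)  # stand-in for the RoPE coordinate grid

# Host-side construction: torch.tensor(...) builds the tensor from a Python
# tuple on the host; on a CUDA device this implies a blocking HtoD copy.
delta_host = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device).view(3, 1, 1, 1)

# Device-native construction: new_ones allocates directly on grid's device and
# the multiplies stay on-device, so no host-side staging or sync is needed.
delta_device = torch.stack([grid.new_ones(1) * s for s in patch_size]).reshape(3, 1, 1, 1)

assert torch.equal(delta_host, delta_device)
```

Both variants produce the same values; only where the construction happens differs.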
```diff
-scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device)
+scale_tensor = torch.stack([latent_coords.new_ones(1) * factor for factor in self.scale_factors])
```
Also avoids an implicit `cudaStreamSynchronize` (~60 ms) by replacing `torch.tensor(...)` with device-native tensor construction.
```diff
 # 5. Convert sigmas and timesteps to tensors and move to specified device
-sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+sigmas = torch.from_numpy(sigmas).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True)
```
This is a hard no. We cannot pin memory and run the copy asynchronously inside the scheduler.
@sayakpaul Should I provide an alternative implementation to avoid the cudaStreamSynchronize, or is this sync considered acceptable in this case?
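For readers following the thread, a minimal sketch of the trade-off under discussion (the CUDA branch is guarded so the snippet also runs on CPU; variable names are illustrative, not the scheduler's):

```python
import numpy as np
import torch

sigmas_np = np.linspace(1.0, 0.0, 8, dtype=np.float32)

if torch.cuda.is_available():
    # The PR's variant: pinning the host staging buffer makes the HtoD copy
    # eligible to run asynchronously (non_blocking=True). The catch is that
    # pinned (page-locked) allocations are expensive and affect host memory
    # globally, which is why doing this inside a scheduler was rejected.
    sigmas = torch.from_numpy(sigmas_np).pin_memory().to(
        dtype=torch.float32, device="cuda", non_blocking=True
    )
else:
    # Plain path: torch.from_numpy shares memory with the NumPy array; the
    # .to(...) on a GPU target would be a synchronous HtoD copy.
    sigmas = torch.from_numpy(sigmas_np).to(dtype=torch.float32)

print(sigmas.shape, sigmas.dtype)
```

Note that `non_blocking=True` only actually overlaps the copy when the source tensor is in pinned memory; otherwise PyTorch falls back to a synchronous transfer.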
```diff
+self.scheduler.set_begin_index(0)
+audio_scheduler.set_begin_index(0)
```
Move it out of the `set_begin_index` `hasattr` check.
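For context on why `set_begin_index(0)` removes a device-to-host sync, here is a minimal sketch (a toy class, not the diffusers scheduler API) of the two code paths in step-index initialization:

```python
import torch

class TinyScheduler:
    """Toy illustration: with no begin index, finding the step index requires
    comparing against the timesteps tensor and calling .item(), which forces
    a device-to-host sync; with a preset begin index it is pure host-side
    bookkeeping."""

    def __init__(self, timesteps):
        self.timesteps = timesteps
        self._begin_index = None
        self._step_index = None

    def set_begin_index(self, begin_index=0):
        self._begin_index = begin_index

    def _init_step_index(self, timestep):
        if self._begin_index is None:
            # nonzero() + .item() materializes a GPU result on the host (DtoH sync).
            self._step_index = (self.timesteps == timestep).nonzero()[0].item()
        else:
            # No tensor ops at all: no sync.
            self._step_index = self._begin_index

sched = TinyScheduler(torch.tensor([999, 500, 0]))
sched.set_begin_index(0)
sched._init_step_index(torch.tensor(999))
assert sched._step_index == 0
```

The sync-free branch is why the PR calls `set_begin_index(0)` up front in the pipeline.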
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@bot /style

Style bot fixed some files and pushed the changes.
dg845
left a comment
Thanks for the PR! I agree with #13564 (review) that having comments for the changes would be helpful.
Fixes performance issues identified by profiling LTX2 with `torch.profiler` as part of #13401.
Optimises LTX2 by removing unnecessary GPU synchronisation points and replacing CPU tensor creation with on-device tensor operations across the decoding pipeline, transformer RoPE computations, scheduler, and connector padding logic.
Pipeline Denoising Optimisation
Before (eager mode): *(profiler trace screenshot)*
Before (compile mode): *(profiler trace screenshot)*
After (compile mode, no sync gap): *(profiler trace screenshot)*
Transformer Model Optimisation
Replaced CPU tensor creation for patch sizes with on-device tensor construction.
Eliminates unnecessary CPU-to-GPU memcpy operations during RoPE coordinate preparation.
Connector Refactoring
Replaced list-comprehension-based padding logic with vectorised masking. This simplifies left-padding layout logic and eliminates unnecessary cudaStreamSynchronize calls.
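A small sketch of the general idea (illustrative shapes and names, not the connector's actual code): left-pad a batch of variable-length sequences either per-sample with Python lists, or in one shot with boolean masks.

```python
import torch

# Hypothetical batch: 3 sequences of valid lengths 2, 3, 1, padded to length 4.
lengths = torch.tensor([2, 3, 1])
max_len, dim = 4, 5
tokens = torch.randn(3, max_len, dim)  # valid tokens are right-padded at the end

# List-based left padding: one cat per sample, and lengths.tolist() itself
# forces a device-to-host sync when lengths lives on the GPU.
left_padded_list = torch.stack(
    [
        torch.cat([tokens.new_zeros(max_len - n, dim), tokens[i, :n]])
        for i, n in enumerate(lengths.tolist())
    ]
)

# Vectorised alternative: build source/destination masks and move all valid
# tokens in a single masked assignment, with no per-sample Python loop.
positions = torch.arange(max_len)
src_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)                 # valid, right-padded
dst_mask = positions.unsqueeze(0) >= (max_len - lengths.unsqueeze(1))    # target, left-padded
left_padded_vec = tokens.new_zeros(3, max_len, dim)
left_padded_vec[dst_mask] = tokens[src_mask]

assert torch.equal(left_padded_list, left_padded_vec)
```

Masked assignment preserves per-row ordering, so the two layouts match exactly while the vectorised version stays entirely on-device.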
Performance Results
Profiler trace
https://drive.google.com/drive/folders/1cZn1xw-8Eon22mA2zP1uoF1nE4YCC3Wo?usp=drive_link
Who can review?
@sayakpaul @dg845