Eliminate GPU sync overhead and CPU→GPU transfers across LTX2 pipeline #13564
Changes from all commits: 5988744 · 9891781 · 64d19cb · f71dc9f · bc22a6b
```diff
@@ -894,9 +894,14 @@ def prepare_video_coords(
     grid = torch.stack(grid, dim=0)  # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches

     # 2. Get the patch boundaries with respect to the latent video grid
     patch_size = (self.patch_size_t, self.patch_size, self.patch_size)
-    patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device)
-    patch_ends = grid + patch_size_delta.view(3, 1, 1, 1)
+    patch_size_delta = torch.stack(
+        [
+            grid.new_ones(1) * self.patch_size_t,
+            grid.new_ones(1) * self.patch_size,
+            grid.new_ones(1) * self.patch_size,
+        ]
+    ).reshape(3, 1, 1, 1)
+    patch_ends = grid + patch_size_delta

     # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension
     latent_coords = torch.stack([grid, patch_ends], dim=-1)  # [3, N_F, N_H, N_W, 2]
```
```diff
@@ -905,7 +910,7 @@ def prepare_video_coords(
     latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1)

     # 3. Calculate the pixel space patch boundaries from the latent boundaries.
-    scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device)
+    scale_tensor = torch.stack([latent_coords.new_ones(1) * factor for factor in self.scale_factors])
```
**Comment on lines -908 to +913**

> **Member:** Same.

> **Author:** Also avoids an implicit `cudaStreamSynchronize` (~60 ms) by replacing `torch.tensor(...)` with device-native tensor construction.
```diff
     # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape
     broadcast_shape = [1] * latent_coords.ndim
     broadcast_shape[1] = -1  # This is the (frame, height, width) dim
```
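The core idea of the refactor above can be shown in isolation. The following is a standalone sketch (not the diffusers code; `grid` and the patch sizes are stand-in values): `torch.tensor(...)` builds the tensor from Python scalars on the host and then copies it to `grid`'s device, while `grid.new_ones(1) * value` allocates directly on that device, so on CUDA no host-to-device copy of the tuple is needed. Both constructions produce the same values:

```python
import torch

# Stand-ins for the real attributes (hypothetical values for illustration)
patch_size_t, patch_size = 2, 1
grid = torch.zeros(3, 4, 8, 8)  # stand-in for the latent coordinate grid

# Host-side construction: torch.tensor() materializes the Python tuple on the
# host first; with a CUDA grid this implies a host-to-device copy.
delta_host = torch.tensor(
    (patch_size_t, patch_size, patch_size), dtype=grid.dtype, device=grid.device
).view(3, 1, 1, 1)

# Device-native construction: new_ones() allocates on grid's device directly,
# avoiding the host-side staging of the tuple.
delta_device = torch.stack(
    [
        grid.new_ones(1) * patch_size_t,
        grid.new_ones(1) * patch_size,
        grid.new_ones(1) * patch_size,
    ]
).reshape(3, 1, 1, 1)

assert torch.equal(delta_host, delta_device)
```

The numerical result is identical; the only difference is where the intermediate values live before the final tensor exists.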
```diff
@@ -1189,6 +1189,11 @@ def __call__(
     num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
     self._num_timesteps = len(timesteps)

+    # Set begin index to skip nonzero().item() call in scheduler initialization, which triggers GPU sync
+    if hasattr(self.scheduler, "set_begin_index"):
+        self.scheduler.set_begin_index(0)
+        audio_scheduler.set_begin_index(0)
```
**Comment on lines +1194 to +1195**

> **Member:** Move it out of the …
```diff
     # 6. Prepare micro-conditions
     # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
     video_coords = self.transformer.rope.prepare_video_coords(
```
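Why does presetting the begin index avoid a sync? Without it, a scheduler must locate the current timestep by searching its timestep tensor, and extracting the found index with `.item()` forces a host-device synchronization when the tensor lives on the GPU. A minimal sketch of that mechanism, using a hypothetical `TinyScheduler` (not the diffusers implementation):

```python
import torch

class TinyScheduler:
    """Toy scheduler illustrating the begin-index shortcut."""

    def __init__(self, timesteps):
        self.timesteps = timesteps
        self._begin_index = None

    def set_begin_index(self, begin_index=0):
        # Caller promises denoising starts at this step, so no search is needed.
        self._begin_index = begin_index

    def index_for_timestep(self, timestep):
        if self._begin_index is not None:
            return self._begin_index  # pure Python int: no search, no sync
        # Fallback: nonzero() + .item() reads a GPU value back to the host,
        # which blocks until the stream catches up (the sync the PR avoids).
        return (self.timesteps == timestep).nonzero()[0].item()

sched = TinyScheduler(torch.tensor([999.0, 500.0, 1.0]))
sched.set_begin_index(0)
assert sched.index_for_timestep(torch.tensor(999.0)) == 0
```

With the begin index set, the first lookup is a constant, and subsequent step indices can be tracked incrementally instead of searched for.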
```diff
@@ -363,11 +363,13 @@ def set_timesteps(
     sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)

     # 5. Convert sigmas and timesteps to tensors and move to specified device
-    sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+    sigmas = torch.from_numpy(sigmas).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True)
```
> **Member:** This is a hard no. We cannot be pinning memory and running it async within the scheduler.

> **Author:** @sayakpaul Should I provide an alternative implementation to avoid the `cudaStreamSynchronize`, or is this sync considered acceptable in this case?

> **Collaborator:** My current thoughts are that an alternative implementation which avoids the …
```diff
     if not is_timesteps_provided:
         timesteps = sigmas * self.config.num_train_timesteps
     else:
-        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
+        timesteps = (
+            torch.from_numpy(timesteps).pin_memory().to(dtype=torch.float32, device=device, non_blocking=True)
+        )

     # 6. Append the terminal sigma value.
     # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
```
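For context on the rejected change: a `.to(device)` copy from ordinary pageable host memory synchronizes the stream, whereas a copy from pinned memory with `non_blocking=True` can overlap with other work, at the cost of the caller having to know when the data is actually valid on the device (which is why the reviewers object to hiding it inside a scheduler). A hedged sketch of both paths, guarded so it also runs on CPU-only machines:

```python
import numpy as np
import torch

sigmas = np.linspace(1.0, 0.0, 50, dtype=np.float64)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pageable copy: on CUDA this blocks until the transfer completes.
sync_sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

if torch.cuda.is_available():
    # Pinned + non_blocking: the pattern proposed in the diff. The tensor may
    # not yet hold its final values when .to() returns; an explicit
    # synchronize (or a later dependent op on the same stream) is required
    # before the host can safely observe the result.
    async_sigmas = (
        torch.from_numpy(sigmas)
        .pin_memory()
        .to(dtype=torch.float32, device=device, non_blocking=True)
    )
    torch.cuda.synchronize()  # make the copy observable before comparing
    assert torch.equal(sync_sigmas, async_sigmas)
```

Both paths yield identical values; the disagreement in the review is purely about whether the asynchronous semantics belong inside library code that callers treat as synchronous.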
> This refactor seems unnecessary.

> This replaces host-side tensor construction with device-native ops to eliminate an implicit `cudaStreamSynchronize` (~60 ms).

> This refactor eliminates a CPU → GPU sync (since the `patch_size` tuple lives on the CPU host, `torch.tensor` needs to copy it to the GPU, which the refactor avoids by doing the operations on the GPU). But because the tuple is really small, the corresponding `cudaStreamSynchronize` block (and that of the similar `scale_tensor` refactor below) takes only about 2 ms. I think most of the removed non-scheduler sync time is in the `connectors.py` refactor.
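Measuring which ops actually trigger these implicit syncs, as the discussion above does, can be done without a profiler: PyTorch's CUDA sync debug mode reports every operation that forces a host-device synchronization. A small sketch, guarded so it is a no-op without a GPU:

```python
import torch

if torch.cuda.is_available():
    # "warn" logs a warning on every implicitly synchronizing op;
    # "error" raises instead, which is useful for failing fast in tests.
    torch.cuda.set_sync_debug_mode("warn")

    x = torch.ones(3, device="cuda")
    _ = x.sum().item()  # .item() reads back to the host -> emits a warning

    torch.cuda.set_sync_debug_mode("default")
```

Running a pipeline once under `"warn"` mode surfaces candidates like the `torch.tensor(...)` constructions and `nonzero().item()` calls addressed in this PR.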
patch_sizetuple lives on the CPU host,torch.tensorneeds to copy it to GPU, which the refactor avoids by doing the operations on the GPU), but because the tuple is really small, it looks like the correspondingcudaStreamSynchronizeblock (and that of the similarscale_tensorrefactor below) takes about 2ms. I think most of the removed non-scheduler sync time is in theconnectors.pyrefactor.