Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/diffusers/schedulers/scheduling_ddim_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def step(
# - pred_prev_sample -> "x_t-1"

# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
prev_timestep = self.previous_timestep(timestep)

# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
Expand Down Expand Up @@ -500,5 +500,24 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity

def previous_timestep(self, timestep: int) -> torch.Tensor:
"""
Find the previous timestep in the scheduler's timestep schedule.

Args:
timestep (`int`):
The current timestep.

Returns:
`torch.Tensor`:
The previous timestep. Returns -1 if the current timestep is the last one.
"""
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
return prev_t

def __len__(self) -> int:
return self.config.num_train_timesteps
21 changes: 20 additions & 1 deletion src/diffusers/schedulers/scheduling_dpm_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def step(
# - pred_prev_sample -> "x_t-1"

# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
prev_timestep = self.previous_timestep(timestep)

# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
Expand Down Expand Up @@ -599,5 +599,24 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity

def previous_timestep(self, timestep: int) -> torch.Tensor:
"""
Find the previous timestep in the scheduler's timestep schedule.

Args:
timestep (`int`):
The current timestep.

Returns:
`torch.Tensor`:
The previous timestep. Returns -1 if the current timestep is the last one.
"""
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
return prev_t

def __len__(self) -> int:
return self.config.num_train_timesteps