Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,11 @@ def set_timesteps(
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
if self.config.beta_schedule != "squaredcos_cap_v2":
timesteps = timesteps.round()
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
if self.config.beta_schedule != "squaredcos_cap_v2":
timesteps = timesteps.round()
elif self.config.use_exponential_sigmas:
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
Expand All @@ -476,6 +472,17 @@ def set_timesteps(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)

# When sigma-to-timestep mapping produces duplicates after rounding to integer
# (e.g. cosine schedule with karras/lu sigmas where many sigmas map to nearly the same timestep),
# deduplicate while preserving order to prevent index out-of-bounds errors in multistep solvers.
# This mirrors the same fix in DPMSolverMultistepInverseScheduler.
timesteps = np.round(timesteps).astype(np.int64)
_, unique_indices = np.unique(timesteps, return_index=True)
unique_indices = np.sort(unique_indices)
if len(unique_indices) < len(timesteps):
timesteps = timesteps[unique_indices]
sigmas = sigmas[unique_indices]

sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

self.sigmas = torch.from_numpy(sigmas)
Expand Down