Skip to content

Commit 1515e50

Browse files
committed
Fix duplicate timesteps in DPMSolverMultistepScheduler with sigma conversion methods
1 parent a1f36ee commit 1515e50

2 files changed

Lines changed: 75 additions & 9 deletions

File tree

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 47 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,8 @@
2424
from ..utils import deprecate, is_scipy_available
2525
from ..utils.torch_utils import randn_tensor
2626
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27-
27+
from ..utils import logging
28+
logger = logging.get_logger(__name__)
2829

2930
if is_scipy_available():
3031
import scipy.stats
@@ -411,29 +412,34 @@ def set_timesteps(
411412
if self.config.use_karras_sigmas:
412413
sigmas = np.flip(sigmas).copy()
413414
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
414-
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
415-
if self.config.beta_schedule != "squaredcos_cap_v2":
416-
timesteps = timesteps.round()
415+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
416+
timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
417+
417418
elif self.config.use_lu_lambdas:
418419
lambdas = np.flip(log_sigmas.copy())
419420
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
420421
sigmas = np.exp(lambdas)
421-
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
422-
if self.config.beta_schedule != "squaredcos_cap_v2":
423-
timesteps = timesteps.round()
422+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
423+
timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
424+
424425
elif self.config.use_exponential_sigmas:
425426
sigmas = np.flip(sigmas).copy()
426427
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
427-
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
428+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
429+
timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
430+
428431
elif self.config.use_beta_sigmas:
429432
sigmas = np.flip(sigmas).copy()
430433
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
431-
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
434+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
435+
timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps)
436+
432437
elif self.config.use_flow_sigmas:
433438
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
434439
sigmas = 1.0 - alphas
435440
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
436441
timesteps = (sigmas * self.config.num_train_timesteps).copy()
442+
437443
else:
438444
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
439445

@@ -544,6 +550,38 @@ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
544550
t = t.reshape(sigma.shape)
545551
return t
546552

553+
def _ensure_unique_timesteps(self, timesteps, sigmas, num_inference_steps):
554+
"""
555+
Ensure timesteps are unique and handle duplicates while preserving the correspondence with sigmas.
556+
557+
Args:
558+
timesteps (`np.ndarray`):
559+
The timestep values that may contain duplicates.
560+
sigmas (`np.ndarray`):
561+
The sigma values corresponding to the timesteps.
562+
num_inference_steps (`int`):
563+
The number of inference steps originally requested.
564+
565+
Returns:
566+
`Tuple[np.ndarray, np.ndarray]`:
567+
A tuple of (timesteps, sigmas) where timesteps are unique and sigmas are filtered accordingly.
568+
"""
569+
unique_timesteps, unique_indices = np.unique(timesteps, return_index=True)
570+
571+
if len(unique_timesteps) < len(timesteps):
572+
# Sort by original indices to maintain order
573+
unique_indices_sorted = np.sort(unique_indices)
574+
timesteps = timesteps[unique_indices_sorted]
575+
sigmas = sigmas[unique_indices_sorted]
576+
577+
if len(timesteps) < num_inference_steps:
578+
logger.warning(
579+
f"Due to the current scheduler configuration, only {len(timesteps)} unique timesteps "
580+
f"could be generated instead of the requested {num_inference_steps}."
581+
)
582+
583+
return timesteps, sigmas
584+
547585
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
548586
"""
549587
Convert sigma values to alpha_t and sigma_t values.

tests/schedulers/test_scheduler_dpm_multi.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -366,3 +366,31 @@ def test_beta_sigmas(self):
366366

367367
def test_exponential_sigmas(self):
368368
self.check_over_configs(use_exponential_sigmas=True)
369+
370+
def test_no_duplicate_timesteps_with_sigma_methods(self):
371+
sigma_configs = [
372+
{"use_karras_sigmas": True},
373+
{"use_lu_lambdas": True},
374+
{"use_exponential_sigmas": True},
375+
{"use_beta_sigmas": True},
376+
]
377+
378+
for config in sigma_configs:
379+
scheduler = DPMSolverMultistepScheduler(
380+
num_train_timesteps=1000,
381+
beta_schedule="squaredcos_cap_v2",
382+
**config,
383+
)
384+
scheduler.set_timesteps(20)
385+
386+
sample = torch.randn(4, 3, 32, 32)
387+
388+
try:
389+
for t in scheduler.timesteps:
390+
model_output = torch.randn_like(sample)
391+
output = scheduler.step(model_output, t, sample)
392+
sample = output.prev_sample
393+
except IndexError as e:
394+
self.fail(f"Index error occurred with config {config}: {e}")
395+
except Exception as e:
396+
self.fail(f"Unexpected error with config {config}: {e}")

0 commit comments

Comments (0)