|
24 | 24 | from ..utils import deprecate, is_scipy_available |
25 | 25 | from ..utils.torch_utils import randn_tensor |
26 | 26 | from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput |
27 | | - |
| 27 | +from ..utils import logging |
| 28 | +logger = logging.get_logger(__name__) |
28 | 29 |
|
29 | 30 | if is_scipy_available(): |
30 | 31 | import scipy.stats |
@@ -411,29 +412,34 @@ def set_timesteps( |
411 | 412 | if self.config.use_karras_sigmas: |
412 | 413 | sigmas = np.flip(sigmas).copy() |
413 | 414 | sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) |
414 | | - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) |
415 | | - if self.config.beta_schedule != "squaredcos_cap_v2": |
416 | | - timesteps = timesteps.round() |
| 415 | + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() |
| 416 | + timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps) |
| 417 | + |
417 | 418 | elif self.config.use_lu_lambdas: |
418 | 419 | lambdas = np.flip(log_sigmas.copy()) |
419 | 420 | lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) |
420 | 421 | sigmas = np.exp(lambdas) |
421 | | - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) |
422 | | - if self.config.beta_schedule != "squaredcos_cap_v2": |
423 | | - timesteps = timesteps.round() |
| 422 | + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() |
| 423 | + timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps) |
| 424 | + |
424 | 425 | elif self.config.use_exponential_sigmas: |
425 | 426 | sigmas = np.flip(sigmas).copy() |
426 | 427 | sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) |
427 | | - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) |
| 428 | + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() |
| 429 | + timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps) |
| 430 | + |
428 | 431 | elif self.config.use_beta_sigmas: |
429 | 432 | sigmas = np.flip(sigmas).copy() |
430 | 433 | sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) |
431 | | - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) |
| 434 | + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() |
| 435 | + timesteps, sigmas = self._ensure_unique_timesteps(timesteps, sigmas, num_inference_steps) |
| 436 | + |
432 | 437 | elif self.config.use_flow_sigmas: |
433 | 438 | alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) |
434 | 439 | sigmas = 1.0 - alphas |
435 | 440 | sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() |
436 | 441 | timesteps = (sigmas * self.config.num_train_timesteps).copy() |
| 442 | + |
437 | 443 | else: |
438 | 444 | sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) |
439 | 445 |
|
@@ -544,6 +550,38 @@ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: |
544 | 550 | t = t.reshape(sigma.shape) |
545 | 551 | return t |
546 | 552 |
|
| 553 | + def _ensure_unique_timesteps(self, timesteps, sigmas, num_inference_steps): |
| 554 | + """ |
| 555 | + Ensure timesteps are unique and handle duplicates while preserving the correspondence with sigmas. |
| 556 | +
|
| 557 | + Args: |
| 558 | + timesteps (`np.ndarray`): |
| 559 | + The timestep values that may contain duplicates. |
| 560 | + sigmas (`np.ndarray`): |
| 561 | + The sigma values corresponding to the timesteps. |
| 562 | + num_inference_steps (`int`): |
| 563 | + The number of inference steps originally requested. |
| 564 | +
|
| 565 | + Returns: |
| 566 | + `Tuple[np.ndarray, np.ndarray]`: |
| 567 | + A tuple of (timesteps, sigmas) where timesteps are unique and sigmas are filtered accordingly. |
| 568 | + """ |
| 569 | + unique_timesteps, unique_indices = np.unique(timesteps, return_index=True) |
| 570 | + |
| 571 | + if len(unique_timesteps) < len(timesteps): |
| 572 | + # Sort by original indices to maintain order |
| 573 | + unique_indices_sorted = np.sort(unique_indices) |
| 574 | + timesteps = timesteps[unique_indices_sorted] |
| 575 | + sigmas = sigmas[unique_indices_sorted] |
| 576 | + |
| 577 | + if len(timesteps) < num_inference_steps: |
| 578 | + logger.warning( |
| 579 | + f"Due to the current scheduler configuration, only {len(timesteps)} unique timesteps " |
| 580 | + f"could be generated instead of the requested {num_inference_steps}." |
| 581 | + ) |
| 582 | + |
| 583 | + return timesteps, sigmas |
| 584 | + |
547 | 585 | def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
548 | 586 | """ |
549 | 587 | Convert sigma values to alpha_t and sigma_t values. |
|
0 commit comments