76 changes: 63 additions & 13 deletions src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -14,7 +14,7 @@

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import torch
@@ -102,12 +102,21 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
time_shift_type: str = "exponential",
time_shift_type: Literal["exponential", "linear"] = "exponential",
stochastic_sampling: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
@@ -166,6 +175,13 @@ def set_begin_index(self, begin_index: int = 0):
self._begin_index = begin_index

def set_shift(self, shift: float):
"""
Sets the shift value for the scheduler.

Args:
shift (`float`):
The shift value to be set.
"""
self._shift = shift

def scale_noise(
@@ -218,10 +234,25 @@ def scale_noise(

return sample

def _sigma_to_t(self, sigma):
def _sigma_to_t(self, sigma) -> float:
return sigma * self.config.num_train_timesteps

def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
def time_shift(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
"""
Apply time shifting to the sigmas.

Args:
mu (`float`):
The mu parameter for the time shift.
sigma (`float`):
The sigma parameter for the time shift.
t (`torch.Tensor`):
The input timesteps.

Returns:
`torch.Tensor`:
The time-shifted timesteps.
"""
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
@@ -302,7 +333,9 @@ def set_timesteps(
if sigmas is None:
if timesteps is None:
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
self._sigma_to_t(self.sigma_max),
self._sigma_to_t(self.sigma_min),
num_inference_steps,
)
sigmas = timesteps / self.config.num_train_timesteps
else:
@@ -350,7 +383,24 @@ def set_timesteps(
self._step_index = None
self._begin_index = None

def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self,
timestep: Union[float, torch.FloatTensor],
schedule_timesteps: Optional[torch.FloatTensor] = None,
) -> int:
"""
Get the index for the given timestep.

Args:
timestep (`float` or `torch.FloatTensor`):
The timestep to find the index for.
schedule_timesteps (`torch.FloatTensor`, *optional*):
The schedule timesteps to validate against. If `None`, the scheduler's timesteps are used.

Returns:
`int`:
The index of the timestep.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

@@ -364,7 +414,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):

return indices[pos].item()
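
The middle of this method is folded behind the `@@ -364,7 +414,7 @@` hunk above. For context, a minimal sketch of the timestep-lookup pattern shared by diffusers schedulers (a reconstruction for illustration, not the verbatim folded code):

    indices = (schedule_timesteps == timestep).nonzero()
    # When a timestep appears more than once in the schedule (e.g. after
    # image-to-image strength truncation), start from the second occurrence
    # so no sigma is skipped; otherwise take the first (and only) match.
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()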

def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[float, torch.FloatTensor]) -> None:
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -405,7 +455,7 @@ def step(
A random number generator.
per_token_timesteps (`torch.Tensor`, *optional*):
The timesteps for each token in the sample.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.

@@ -474,7 +524,7 @@ def step(
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).
@@ -595,11 +645,11 @@ def _convert_to_beta(
)
return sigmas

def _time_shift_exponential(self, mu, sigma, t):
def _time_shift_exponential(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def _time_shift_linear(self, mu, sigma, t):
def _time_shift_linear(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return mu / (mu + (1 / t - 1) ** sigma)
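
Both shift functions map a timestep t in (0, 1] back into (0, 1], with `mu` controlling how strongly the schedule is redistributed. A self-contained sketch of the two formulas as written above, with illustrative (non-default) `mu` and `sigma` values:

    import math

    def time_shift_exponential(mu: float, sigma: float, t: float) -> float:
        # t -> exp(mu) / (exp(mu) + (1/t - 1) ** sigma)
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def time_shift_linear(mu: float, sigma: float, t: float) -> float:
        # t -> mu / (mu + (1/t - 1) ** sigma)
        return mu / (mu + (1 / t - 1) ** sigma)

    # With sigma = 1.0 and mu = 1.05, both shifts push every t upward,
    # i.e. the discretized schedule spends more steps at high noise levels.
    for t in (0.25, 0.5, 0.75):
        print(f"{t:.2f} -> exp: {time_shift_exponential(1.05, 1.0, t):.3f}, "
              f"lin: {time_shift_linear(1.05, 1.0, t):.3f}")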

def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps
15 changes: 15 additions & 0 deletions src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -482,6 +482,21 @@ def set_timesteps(

# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
"""
Apply time shifting to the sigmas.

Args:
mu (`float`):
The mu parameter for the time shift.
sigma (`float`):
The sigma parameter for the time shift.
t (`torch.Tensor`):
The input timesteps.

Returns:
`torch.Tensor`:
The time-shifted timesteps.
"""
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
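To see the `time_shift_type` option end to end, a hypothetical usage sketch: it assumes the scheduler is constructed with `use_dynamic_shifting=True`, which is what routes `set_timesteps` through `time_shift`, and the parameter values are illustrative only:

    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        use_dynamic_shifting=True,
        time_shift_type="linear",  # or the default "exponential"
    )
    # mu must be supplied when dynamic shifting is enabled.
    scheduler.set_timesteps(num_inference_steps=28, mu=1.5)
    print(scheduler.timesteps[:5])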