Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm

import math
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -51,13 +51,15 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
rho (`float`, *optional*, defaults to 7.0):
The rho parameter in the Karras sigma schedule. This was set to 7.0 in the EDM paper [1].
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
Expand Down Expand Up @@ -94,19 +96,19 @@ def __init__(
sigma_min: float = 0.002,
sigma_max: float = 80.0,
sigma_data: float = 0.5,
sigma_schedule: str = "karras",
sigma_schedule: Literal["karras", "exponential"] = "karras",
num_train_timesteps: int = 1000,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
rho: float = 7.0,
solver_order: int = 2,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
algorithm_type: Literal["dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++",
solver_type: Literal["midpoint", "heun"] = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", # "zero", "sigma_min"
):
# settings for DPM-Solver
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
Expand Down Expand Up @@ -145,19 +147,19 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
def init_noise_sigma(self) -> float:
    """
    The standard deviation of the initial noise distribution, computed as `(sigma_max**2 + 1) ** 0.5` from the
    configured `sigma_max`.
    """
    return (self.config.sigma_max**2 + 1) ** 0.5

@property
def step_index(self) -> int:
    """
    The index counter for the current timestep. It will increase by 1 after each scheduler step.
    """
    return self._step_index

@property
def begin_index(self):
def begin_index(self) -> int:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
Expand Down Expand Up @@ -274,7 +276,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
self.is_scale_input_called = True
return sample

def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Optional[Union[str, torch.device]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Expand Down Expand Up @@ -460,13 +466,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _sigma_to_alpha_sigma_t(
    self, sigma: Union[float, torch.Tensor]
) -> Tuple[torch.Tensor, Union[float, torch.Tensor]]:
    """
    Split a noise level into the `(alpha_t, sigma_t)` pair used by the solver updates.

    Args:
        sigma (`float` or `torch.Tensor`):
            The current noise level. It is returned unchanged as `sigma_t`.

    Returns:
        `Tuple`:
            `alpha_t`, a constant tensor holding `1`, and `sigma_t`, which is the input `sigma`.
    """
    alpha_t = torch.tensor(1)  # Inputs are pre-scaled before going into unet, so alpha_t = 1
    sigma_t = sigma
    return alpha_t, sigma_t

def convert_model_output(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
sample: torch.Tensor,
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
Expand Down Expand Up @@ -497,7 +502,7 @@ def convert_model_output(
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
sample: torch.Tensor,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Expand All @@ -508,6 +513,8 @@ def dpm_solver_first_order_update(
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor to add to the original samples.

Returns:
`torch.Tensor`:
Expand Down Expand Up @@ -538,7 +545,7 @@ def dpm_solver_first_order_update(
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
sample: torch.Tensor,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Expand All @@ -549,6 +556,8 @@ def multistep_dpm_solver_second_order_update(
The direct outputs from the learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor to add to the original samples.

Returns:
`torch.Tensor`:
Expand Down Expand Up @@ -609,7 +618,7 @@ def multistep_dpm_solver_second_order_update(
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
sample: torch.Tensor,
) -> torch.Tensor:
"""
One step for the third-order multistep DPMSolver.
Expand Down Expand Up @@ -698,7 +707,7 @@ def index_for_timestep(
return step_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.

Expand All @@ -719,7 +728,7 @@ def step(
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Expand Down Expand Up @@ -860,5 +869,5 @@ def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[flo
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in

def __len__(self) -> int:
    """
    Return the `num_train_timesteps` value stored in the scheduler config.
    """
    return self.config.num_train_timesteps
Loading