Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/diffusers/models/unets/unet_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,14 @@ def forward(
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

Expand Down
9 changes: 8 additions & 1 deletion src/diffusers/models/unets/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,14 @@ def forward(
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
is_mps = sample.device.type == "mps"
is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/models/unets/unet_2d_condition_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ def __call__(
"""
# 1. time
if not isinstance(timesteps, jnp.ndarray):
timesteps = jnp.array([timesteps], dtype=jnp.int32)
dtype = jnp.float32 if isinstance(timesteps, float) else jnp.int32
timesteps = jnp.array([timesteps], dtype=dtype)
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
timesteps = timesteps.astype(dtype=jnp.float32)
timesteps = jnp.expand_dims(timesteps, 0)
Expand Down
Loading