
fix(unets): preserve scalar float timestep dtype in UNet1D/2D + FlaxUNet2DCondition (#13654) #13669

Open
Anai-Guo wants to merge 3 commits into huggingface:main from Anai-Guo:fix-unet-1d-2d-flax-float-timestep

Conversation


Anai-Guo commented May 1, 2026

Summary

Fixes Issue 2 from the model_unets_shared review (#13654): scalar Python float timesteps are silently truncated to integer dtype in three unconditional / Flax UNet forward paths.

Problem

UNet1DModel, UNet2DModel, and FlaxUNet2DConditionModel all type their public timestep argument as torch.Tensor | float | int (or accept it via the __call__ signature in Flax), but in the no-tensor branch they wrap the value with an integer dtype:

# unet_1d.py / unet_2d.py
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)

# unet_2d_condition_flax.py
timesteps = jnp.array([timesteps], dtype=jnp.int32)

So unet(sample, 1e-4) becomes 0 before time_proj, breaking VE/NCSN-style small-sigma timesteps and making UNet2DModel produce non-finite outputs (the Fourier projection divides by the truncated value).
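The truncation is visible in isolation; a minimal sketch (not code from the PR) of what the integer-dtype wrap does to a small float timestep:

```python
import torch

# Wrapping a small float timestep with an integer dtype truncates toward
# zero, which is exactly what the no-tensor branch in unet_1d.py /
# unet_2d.py does today.
t = torch.tensor([1e-4], dtype=torch.long)
print(t.item())  # 0: the 1e-4 sigma is lost before time_proj ever sees it
```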

Reproduction (from #13654)

import torch
from diffusers import UNet2DModel

unet2d = UNet2DModel(
    sample_size=8, in_channels=3, out_channels=3,
    block_out_channels=(8,), layers_per_block=1,
    down_block_types=("DownBlock2D",), up_block_types=("UpBlock2D",),
    norm_num_groups=4, time_embedding_type="fourier",
).eval()

sample = torch.randn(1, 3, 8, 8)
print(torch.isfinite(unet2d(sample, 1e-4).sample).all().item())              # False (before)
print(torch.isfinite(unet2d(sample, torch.tensor([1e-4])).sample).all().item())  # True

After this fix the scalar-float path matches the tensor path.

Fix

For UNet1DModel and UNet2DModel, mirror the precedent already used in UNet2DConditionModel.get_time_embed (and UNet3DConditionModel.forward):

if not torch.is_tensor(timesteps):
    # mps/npu do not support 64-bit dtypes, so fall back to 32-bit there
    is_mps = sample.device.type == "mps"
    is_npu = sample.device.type == "npu"
    if isinstance(timesteps, float):
        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
    else:
        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
    timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)

For FlaxUNet2DConditionModel, pick jnp.float32 for Python floats and keep jnp.int32 otherwise:

dtype = jnp.float32 if isinstance(timesteps, float) else jnp.int32
timesteps = jnp.array([timesteps], dtype=dtype)

This is a behavior-preserving change for integer timesteps; the only path that changes is the previously broken float branch.
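The Flax branch can be checked the same way; a sketch of the fixed dtype pick as a standalone function (hypothetical name `wrap_timestep`, not code from the PR, assuming jax is installed):

```python
import jax.numpy as jnp

def wrap_timestep(value):
    # Hypothetical standalone version of the fixed Flax branch: a Python
    # float keeps a float dtype; everything else keeps the int32 path.
    dtype = jnp.float32 if isinstance(value, float) else jnp.int32
    return jnp.array([value], dtype=dtype)

assert wrap_timestep(1e-4).item() != 0          # no longer truncated to 0
assert wrap_timestep(10).dtype == jnp.int32     # integer path preserved
```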

Scope

Three files, ~16 LoC total. Scoped to Issue 2 only; Issues 1, 3, 4 from #13654 are left for follow-up PRs.

Refs

🤖 Generated with Claude Code

github-actions bot added the labels models and size/S (PR with diff < 50 LOC) on May 1, 2026