Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 66 additions & 3 deletions cosmodiff/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class RandomFlip(nn.Module):
"""
def __init__(self, dims=(-2, -1), p=0.5):
super().__init__()
self.dims = list(dims)
self.p = p
self.dims = tuple(dims)
self.p = float(p)

def __call__(self, x):
if x is None:
Expand All @@ -86,6 +86,70 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(dims={self.dims}, p={self.p})"


class RandomRot90(nn.Module):
    """Rotate the input by a randomly chosen number of quarter turns.

    With probability ``p`` the input is rotated by ``k * 90`` degrees in the
    plane spanned by ``dims``, where ``k`` is drawn uniformly from
    ``{0, 1, 2, 3}``; otherwise the input is returned untouched.

    Args:
        dims (tuple of int): The two dimensions spanning the rotation plane.
            Defaults to ``(-2, -1)``.
        p (float): Probability that a random rotation is applied.
            Defaults to ``1.0``.

    Raises:
        ValueError: If ``dims`` does not contain exactly two entries.
    """
    def __init__(self, dims=(-2, -1), p=1.0):
        super().__init__()
        plane = tuple(dims)
        self.dims = plane
        if len(plane) != 2:
            raise ValueError("RandomRot90 requires exactly two dims.")
        self.p = float(p)

    def __call__(self, x):
        # ``None`` inputs pass straight through.
        if x is None:
            return None
        prob = self.p
        if prob <= 0.0:
            return x
        # The gate is only sampled when the outcome is not already certain,
        # so p >= 1.0 consumes no random number for the decision itself.
        skip = prob < 1.0 and torch.rand((), device='cpu').item() >= prob
        if skip:
            return x
        quarter_turns = torch.randint(0, 4, (1,), device='cpu').item()
        return torch.rot90(x, quarter_turns, dims=self.dims)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(dims={self.dims}, p={self.p})"


class RandomDihedral2D(nn.Module):
    """Apply a random symmetry of the square to two dimensions of a tensor.

    A rotation by a uniformly drawn number of quarter turns is followed, with
    probability one half, by a flip along the first of the two dims. Together
    these cover the eight square symmetries, which suits isotropic 2D fields
    whose orientation carries no meaning.

    Args:
        dims (tuple of int): The two image dimensions to transform.
            Defaults to ``(-2, -1)``.
        p (float): Probability that the augmentation is applied at all.
            Defaults to ``1.0``.

    Raises:
        ValueError: If ``dims`` does not contain exactly two entries.
    """
    def __init__(self, dims=(-2, -1), p=1.0):
        super().__init__()
        plane = tuple(dims)
        self.dims = plane
        if len(plane) != 2:
            raise ValueError("RandomDihedral2D requires exactly two dims.")
        self.p = float(p)

    def __call__(self, x):
        # ``None`` inputs pass straight through.
        if x is None:
            return None
        prob = self.p
        if prob <= 0.0:
            return x
        # Sample the gate only when the decision is not already certain.
        skip = prob < 1.0 and torch.rand((), device='cpu').item() >= prob
        if skip:
            return x
        quarter_turns = torch.randint(0, 4, (1,), device='cpu').item()
        out = torch.rot90(x, quarter_turns, dims=self.dims)
        # Fair coin for the reflection; drawn after the rotation count.
        mirror = torch.rand((), device='cpu').item() < 0.5
        if mirror:
            out = torch.flip(out, [self.dims[0]])
        return out

    def __repr__(self) -> str:
        return f"{type(self).__name__}(dims={self.dims}, p={self.p})"


class RandomMove(nn.Module):
"""Randomly swap along specified dimensions.

Expand Down Expand Up @@ -122,4 +186,3 @@ def config_augmentations(augmentations):
pipeline.append(getattr(augment, name)(**kwargs))

return nn.Sequential(*pipeline)

9 changes: 7 additions & 2 deletions cosmodiff/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,8 @@ def generate(
)
if v is not None and k in step_supported
}
if generator is not None and 'generator' in step_supported:
step_kwargs['generator'] = generator

for t in tqdm(noise_scheduler.timesteps, desc="Sampling"):
# FM schedulers use float timesteps; DDPM-family use int — preserve dtype.
Expand Down Expand Up @@ -773,7 +775,7 @@ def generate(
pred = model(images_input, timesteps, return_dict=False)[0]

images = noise_scheduler.step(
pred, t, images, generator=generator, **step_kwargs,
pred, t, images, **step_kwargs,
).prev_sample

if renorm is not None:
Expand Down Expand Up @@ -921,7 +923,10 @@ def synthesize_ema_from_checkpoints(
tmp_path = Path(tmp_dir)
for ckpt_dir in ckpt_dirs:
for pt_file in (ckpt_dir / 'ema').glob('*.pt'):
(tmp_path / pt_file.name).symlink_to(pt_file.resolve())
link_path = tmp_path / pt_file.name
if link_path.exists():
continue
link_path.symlink_to(pt_file.resolve())

ema = PostHocEMA(
model,
Expand Down
74 changes: 74 additions & 0 deletions cosmodiff/tests/test_augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
import torch

from cosmodiff.augment import RandomDihedral2D, RandomFlip, RandomRot90, config_augmentations


def test_random_flip_uses_configured_dims():
    """RandomFlip(dims=(-1,), p=1.0) must match torch.flip along dim -1."""
    source = torch.arange(6).reshape(1, 2, 3)
    flipper = RandomFlip(dims=(-1,), p=1.0)

    flipped = flipper(source)

    expected = torch.flip(source, [-1])
    assert torch.equal(flipped, expected)


def test_random_flip_noop_when_no_dims_selected():
    """RandomFlip with p=0.0 must hand back a tensor equal to its input."""
    source = torch.arange(6).reshape(1, 2, 3)
    flipper = RandomFlip(dims=(-1, -2), p=0.0)

    result = flipper(source)

    assert torch.equal(result, source)


def test_random_rot90_noop_when_probability_is_zero():
    """RandomRot90 with p=0.0 must hand back a tensor equal to its input."""
    image = torch.arange(16).reshape(1, 4, 4)
    rotator = RandomRot90(dims=(-2, -1), p=0.0)

    result = rotator(image)

    assert torch.equal(result, image)


def test_random_rot90_preserves_square_image_values():
    """RandomRot90 with p=1.0 must return one of the four quarter-turn rotations.

    Comparing only the sorted values would also accept an arbitrary shuffle
    of the pixels, so the output is matched against every ``torch.rot90`` of
    the input in the configured plane. Each rotation is a permutation of the
    input, so value preservation is implied by the match.
    """
    x = torch.arange(16).reshape(1, 4, 4)

    out = RandomRot90(dims=(-2, -1), p=1.0)(x)

    rotations = [torch.rot90(x, k, dims=(-2, -1)) for k in range(4)]
    assert out.shape == x.shape
    assert any(torch.equal(out, candidate) for candidate in rotations)


def test_random_dihedral_noop_when_probability_is_zero():
    """RandomDihedral2D with p=0.0 must hand back a tensor equal to its input."""
    image = torch.arange(16).reshape(1, 4, 4)
    augmenter = RandomDihedral2D(dims=(-2, -1), p=0.0)

    result = augmenter(image)

    assert torch.equal(result, image)


def test_random_dihedral_preserves_square_image_values():
    """RandomDihedral2D with p=1.0 must return one of the eight square symmetries.

    Comparing only the sorted values would also accept an arbitrary shuffle
    of the pixels, so the output is matched against every quarter-turn
    rotation of the input and every rotation followed by a flip along the
    first transformed dim (-2).
    """
    x = torch.arange(16).reshape(1, 4, 4)

    out = RandomDihedral2D(dims=(-2, -1), p=1.0)(x)

    rotations = [torch.rot90(x, k, dims=(-2, -1)) for k in range(4)]
    symmetries = rotations + [torch.flip(r, [-2]) for r in rotations]
    assert out.shape == x.shape
    assert any(torch.equal(out, candidate) for candidate in symmetries)


def test_square_symmetry_augmentations_require_two_dims():
    """Both square-symmetry augmentations must reject a single-entry ``dims``."""
    # RandomRot90 is checked first, then RandomDihedral2D.
    for augmentation in (RandomRot90, RandomDihedral2D):
        with pytest.raises(ValueError, match="requires exactly two dims"):
            augmentation(dims=(-1,))


def test_config_augmentations_can_build_square_symmetry_pipeline():
    """A config naming both square-symmetry augmentations yields a usable pipeline.

    The pipeline output must keep the input shape and contain exactly the
    same values as the input.
    """
    image = torch.arange(16).reshape(1, 4, 4)
    settings = {
        "RandomRot90": {"dims": [-2, -1], "p": 1.0},
        "RandomDihedral2D": {"dims": [-2, -1], "p": 1.0},
    }

    augmented = config_augmentations(settings)(image)

    assert augmented.shape == image.shape
    sorted_out = augmented.flatten().sort().values
    sorted_in = image.flatten().sort().values
    assert torch.equal(sorted_out, sorted_in)
32 changes: 32 additions & 0 deletions cosmodiff/tests/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
UNet2DConditionModel,
DDPMScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
DiTTransformer2DModel,
PixArtTransformer2DModel,
FlowMatchEulerDiscreteScheduler,
Expand Down Expand Up @@ -244,6 +246,36 @@ def test_generate_num_steps():
assert torch.isfinite(images).all()


def test_generate_dpm_solver_drops_unsupported_generator_kwarg():
    """generate() runs with DPMSolverMultistepScheduler when given a generator.

    Checks the sampled batch has shape (2, 1, 8, 8) and is entirely finite.

    NOTE(review): whether this scheduler's step() accepts ``generator``
    depends on the installed diffusers release; confirm this test really
    exercises the "drop unsupported kwarg" path.
    """
    samples = generate(
        _make_unet(),
        DPMSolverMultistepScheduler(num_train_timesteps=10),
        batch_size=2,
        image_shape=(1, 8, 8),
        num_steps=5,
        generator=torch.Generator().manual_seed(0),
    )

    assert samples.shape == (2, 1, 8, 8)
    assert torch.isfinite(samples).all()


def test_generate_heun_drops_unsupported_generator_kwarg():
    """generate() runs with HeunDiscreteScheduler when given a seeded generator.

    Checks the sampled batch has shape (2, 1, 8, 8) and is entirely finite.
    """
    samples = generate(
        _make_unet(),
        HeunDiscreteScheduler(num_train_timesteps=10),
        batch_size=2,
        image_shape=(1, 8, 8),
        num_steps=5,
        generator=torch.Generator().manual_seed(0),
    )

    assert samples.shape == (2, 1, 8, 8)
    assert torch.isfinite(samples).all()


def test_generate_renorm():
"""renorm callable is applied to the output."""
model = _make_unet()
Expand Down
2 changes: 2 additions & 0 deletions scripts/cosmodiff_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ def cfg(key, fallback=None):
model = synthesize_ema_from_checkpoints(
model, checkpoints_dir, sigma_rel_target=ema_sigma_rel,
)
if hasattr(model, "ema_model"):
model = model.ema_model
model = model.to(device)
model.eval()

Expand Down
Loading