Merged
Commits
38 commits
2066969
update to normalization
nkern May 5, 2026
30ba844
updated data normalization
nkern May 5, 2026
8b112c4
modified: cosmodiff/utils.py
nkern May 5, 2026
93f0c95
modified: cosmodiff/tests/test_utils.py
nkern May 5, 2026
642632b
modified: cosmodiff/data/config.yaml
nkern May 5, 2026
ea59d75
modified: cosmodiff/data/config.yaml
nkern May 6, 2026
3db54c8
fixed training loss for v_prediction
nkern May 6, 2026
811014c
modified: pyproject.toml
nkern May 6, 2026
96e2a51
added CFG to training; added files in data
nkern May 6, 2026
0afeebb
test for continuous conditioning PixArt DIT model
nkern May 6, 2026
fb0f5dd
added EMA weighting to train / checkpoint
nkern May 6, 2026
053343d
modified: cosmodiff/data/config.yaml
nkern May 6, 2026
c9ae990
modified: cosmodiff/data/config.yaml
nkern May 6, 2026
9c33308
modified: scripts/cosmodiff_train.py
nkern May 6, 2026
12c9af8
modified: cosmodiff/tests/test_optim.py
nkern May 6, 2026
fad1d24
modified: cosmodiff/utils.py
nkern May 6, 2026
e92e5fc
fixed ema update
nkern May 7, 2026
1925d9b
modified: cosmodiff/optim.py
nkern May 7, 2026
883de50
default ema_update=1, load_ema_snapshot func
nkern May 7, 2026
991cbd2
added min-snr to training loss
nkern May 7, 2026
dc78fc3
added log-uniform time sampling
nkern May 7, 2026
f15c730
updated inference script with sampling additions
nkern May 7, 2026
2a97dd7
added ema burn in
nkern May 7, 2026
8560159
added transform.py, moved Normalization too
nkern May 8, 2026
70a2151
updated load_data to handle multiple paths
nkern May 8, 2026
55b75c7
added support for multiple data files, multiple normalizations
nkern May 9, 2026
ae03be4
modified: cosmodiff/utils.py
nkern May 9, 2026
3ac6e8b
modified: cosmodiff/utils.py
nkern May 10, 2026
3abf182
modified: cosmodiff/utils.py
nkern May 10, 2026
8109dfd
pyproject fixes
nkern May 10, 2026
9e24865
updated versioning system
nkern May 10, 2026
8508e3f
changed rfft2 to fft2
nkern May 11, 2026
87f0abc
added scheduler to generate, rm ddim_thin
nkern May 11, 2026
c2a50c8
added flowmatching
nkern May 11, 2026
8542f9f
added churn parameters for FM
nkern May 11, 2026
409bb98
scale model for VE inference
nkern May 11, 2026
e77d006
modified: transform.py
nkern May 11, 2026
392fb00
updated readme
nkern May 11, 2026
50 changes: 43 additions & 7 deletions README.md
@@ -1,19 +1,55 @@
 # cosmo_diffusion
+![banner](docs/_static/banner.png)

-Train 2D/3D diffusion models (UNet and DiT) for cosmological applications.
+Train 2D/3D diffusion (and flow-matching) models — UNet, UNet-Conditional, DiT,
+and PixArt — for cosmological applications.

 ## Install

 ```bash
 git clone https://github.com/nkern/cosmo_diffusion
 cd cosmo_diffusion
-pip install .
+pip install -e .
 ```

-## Running
-Look at `cosmodiff/configs/config.yaml` for a configuration file. Then just run:
+## Dependencies
+
+- `numpy`
+- `torch`
+- `diffusers`
+- `accelerate`
+- `tqdm`
+- `pyyaml`
+- `scipy`
+- `h5py`
+- `matplotlib`
+- `ema-pytorch`
+
+## Quick demo
+
+Configure a training run in `cosmodiff/data/config.yaml` (paths, model,
+scheduler, training kwargs), then launch:
+
 ```bash
-cosmodiff_train.py --config path_to_config
+cosmodiff_train.py --config path/to/config.yaml
 ```

-checkpointing and metrics are automatically stored in `output_dir` (defined in the config).
+Checkpoints and metrics are written automatically to the `output_dir` set in
+the config. To sample from a trained checkpoint:
+
+```bash
+cosmodiff_sample.py --output_dir path/to/run \
+    --n_samples 64 --output samples.npy
+```
+
+For fast inference, swap in a higher-order solver:
+
+```bash
+cosmodiff_sample.py --output_dir path/to/run \
+    --scheduler DPMSolverMultistepScheduler --num_steps 25 \
+    --n_samples 64 --output samples.npy
+```
+
+## Authors
+
+- Nicholas Kern
+- Jiaming Pan
5 changes: 4 additions & 1 deletion cosmodiff/__init__.py
@@ -1,6 +1,9 @@
+from importlib.metadata import version as _pkg_version
+
 from . import utils
 from . import augment
 from . import optim

 from .optim import train, generate
-from .version import __version__
+
+__version__ = _pkg_version("cosmodiff")
19 changes: 11 additions & 8 deletions cosmodiff/augment.py
@@ -9,16 +9,17 @@ class RandomRoll(nn.Module):
     Args:
         dims (tuple of int): Dimensions to roll along. Defaults to ``(-1,)``.
     """
-    def __init__(self, dims=(-1,)):
+    def __init__(self, size=128, dims=(-1,)):
         super().__init__()
+        self.size = size
         self.dims = tuple(dims)
         self.ndim = len(dims)

     def __call__(self, x):
         if x is None:
             return None
-        shift = (torch.rand(self.ndim, device='cpu') * torch.tensor(x.shape)[list(self.dims)]).round().to(torch.long)
-        return torch.roll(x, tuple(shift), self.dims)
+        shift = torch.randint(0, self.size, (self.ndim,), dtype=torch.long, device='cpu').tolist()
+        return torch.roll(x, shift, self.dims)

     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(dims={self.dims})"
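The new `__call__` draws each shift from `torch.randint(0, self.size, ...)` instead of scaling a uniform draw by the tensor's own shape, so `size` is presumably meant to correspond to the rolled axis length. Assuming the roll is circular (a shift larger than the axis wraps around), the pure-Python sketch below mimics it on a 1D list and counts how often each effective offset occurs. `roll` and `effective_shift_counts` are illustrative names, not part of `cosmodiff`.

```python
from collections import Counter


def roll(seq, shift):
    """Circularly shift a 1D sequence to the right by ``shift`` places."""
    seq = list(seq)
    n = len(seq)
    k = shift % n if n else 0  # shifts >= n wrap around
    return seq[-k:] + seq[:-k] if k else seq


def effective_shift_counts(size, axis_len):
    """Count each effective offset (shift mod axis_len) for shift in [0, size)."""
    return Counter(shift % axis_len for shift in range(size))


print(roll([0, 1, 2, 3], 1))  # [3, 0, 1, 2]
print(roll([0, 1, 2, 3], 5))  # [3, 0, 1, 2], same as a shift of 1

even = effective_shift_counts(size=128, axis_len=64)
uneven = effective_shift_counts(size=100, axis_len=64)
print(set(even.values()))     # {2}: every offset is equally likely
print(set(uneven.values()))   # {1, 2}: offsets 0..35 are twice as likely
```

Under that assumption, the default `size=128` still gives uniform offsets on a 64-pixel axis because 128 is a multiple of 64; a `size` that is not a multiple of the axis length would bias the augmentation toward small offsets.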
@@ -68,15 +69,18 @@ class RandomFlip(nn.Module):
     """
     def __init__(self, dims=(-2, -1), p=0.5):
         super().__init__()
-        self.dims = tuple(dims)
+        self.dims = list(dims)
         self.p = p

     def __call__(self, x):
         if x is None:
             return None
-        flip = torch.rand(len(self.dims), device='cpu')
-        flip = torch.where(flip < self.p)[0].tolist()
-        return torch.flip(x, flip)
+        draw = torch.rand(len(self.dims), device='cpu')
+        flip = torch.where(draw < self.p)[0].tolist()
+        if len(flip) == 0:
+            return x
+        else:
+            return torch.flip(x, [self.dims[f] for f in flip])

     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(dims={self.dims}, p={self.p})"
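The fix in this hunk is about which numbers reach `torch.flip`. The removed lines passed the positions returned by `torch.where` (0, 1, ...) directly as axes, so with `dims=(-2, -1)` a selected entry flipped axis 0 or 1 rather than -2 or -1. The added lines map each position back through `self.dims`. A pure-Python sketch of the two selection rules, using hypothetical function names:

```python
import random


def selected_positions(dims, p, rng=random):
    """Pre-fix rule: returns positions within ``dims``, not the axes themselves."""
    draws = [rng.random() for _ in dims]
    return [i for i, d in enumerate(draws) if d < p]


def selected_axes(dims, p, rng=random):
    """Post-fix rule: map each selected position back to its axis in ``dims``."""
    return [dims[i] for i in selected_positions(dims, p, rng)]


dims = [-2, -1]
print(selected_positions(dims, p=1.0))  # [0, 1]   -> would flip the leading axes
print(selected_axes(dims, p=1.0))       # [-2, -1] -> flips the intended axes
print(selected_axes(dims, p=0.0))       # []       -> nothing selected, input returned as-is
```

The empty case corresponds to the new `if len(flip) == 0: return x` branch, which skips the flip call entirely when no axis is drawn.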
@@ -119,4 +123,3 @@ def config_augmentations(augmentations):

     return nn.Sequential(*pipeline)

-
85 changes: 0 additions & 85 deletions cosmodiff/configs/config.yaml

This file was deleted.

Binary file added cosmodiff/data/IllustrisTNG_Mcdm.npy
Binary file not shown.
3 changes: 3 additions & 0 deletions cosmodiff/data/__init__.py
@@ -0,0 +1,3 @@
+from pathlib import Path
+
+DATA_PATH = Path(__file__).parent
119 changes: 119 additions & 0 deletions cosmodiff/data/config.yaml
@@ -0,0 +1,119 @@
+# cosmodiff training config
+
+# --- global ---
+global:
+  device: cpu
+  dtype: float32
+
+# --- io ---
+io:
+  output_dir: /path/to/output
+
+# --- data ---
+data:
+  img_path: /path/to/images.npy
+  img_read_fn: npy_read_fn # any function in cosmodiff.utils
+  label_path: /path/to/labels.npy # optional
+  label_read_fn: npy_read_fn # optional
+  reshape: 2d # '2d' (z-slice → channel), '3d' (volume), or null (no reshape)
+  zthin: 4 # thinning factor along z when reshape='2d'
+  n_samples: null
+  seed: null
+  keep_on_cpu: true
+  normalization: center-max # data norm and kwargs
+  norm_kwargs:
+    center: null
+    xmax: null
+    alpha: null
+    beta: null
+  transform:
+    [log]
+
+# --- augmentations ---
+augmentations:
+  RandomRoll:
+    size: 128
+    dims: [-1, -2]
+  RandomFlip:
+    dims: [-1, -2]
+    p: 0.5
+
+# --- model ---
+model:
+  class: UNet2DModel
+  kwargs:
+    sample_size: 64
+    in_channels: 1
+    out_channels: 1
+    layers_per_block: 2
+    block_out_channels: [64, 128, 256]
+    down_block_types:
+      - DownBlock2D
+      - DownBlock2D
+      - DownBlock2D
+    up_block_types:
+      - UpBlock2D
+      - UpBlock2D
+      - UpBlock2D
+    norm_num_groups: 32
+
+# --- noise scheduler ---
+noise_scheduler:
+  class: DDPMScheduler
+  kwargs:
+    num_train_timesteps: 1000
+    rescale_betas_zero_snr: true
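`rescale_betas_zero_snr: true` asks the scheduler to rescale its beta schedule so that the final timestep carries no signal (zero terminal SNR). The plain-Python sketch below follows the rescaling described by Lin et al. (2023): shift the square root of the cumulative alpha product so its last value is zero, then rescale so its first value is unchanged. It illustrates the idea and is not the `diffusers` implementation; the linear schedule endpoints are assumed values.

```python
import math


def linear_betas(n, beta_start=1e-4, beta_end=2e-2):
    """Assumed linear beta schedule, used only to have something to rescale."""
    step = (beta_end - beta_start) / (n - 1)
    return [beta_start + i * step for i in range(n)]


def rescale_zero_terminal_snr(betas):
    """Rescale ``betas`` so the cumulative alpha product reaches exactly zero."""
    # sqrt of the cumulative product of alphas = signal scale at each step
    sqrt_abar, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        sqrt_abar.append(math.sqrt(prod))

    first, last = sqrt_abar[0], sqrt_abar[-1]
    # Shift so the last value is zero, then scale so the first is preserved.
    scale = first / (first - last)
    sqrt_abar = [(s - last) * scale for s in sqrt_abar]

    abar = [s * s for s in sqrt_abar]
    # Recover per-step alphas from ratios of consecutive cumulative products.
    alphas = [abar[0]] + [abar[i] / abar[i - 1] for i in range(1, len(abar))]
    return [1.0 - a for a in alphas]


betas = rescale_zero_terminal_snr(linear_betas(1000))
print(betas[0], betas[-1])  # first beta essentially unchanged, last beta is 1.0
```

A final beta of 1 means the last training timestep is pure noise, which is one reason this option is often paired with `v_prediction` rather than epsilon prediction.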

+# --- optimizer ---
+optimizer:
+  class: AdamW
+  kwargs:
+    lr: 1.0e-4
+    weight_decay: 1.0e-2
+
+# --- lr scheduler ---
+lr_scheduler:
+  class: ConstantLR
+  kwargs:
+    factor: 1.0
+    total_iters: 0
+
+# --- training ---
+train:
+  num_epochs: 50
+  batch_size: 16
+  shuffle: true
+  checkpoint_every_n_epochs: 5
+  mixed_precision: fp16
+  gradient_accumulation_steps: 1
+  dataloader_num_workers: 4
+  max_grad_norm: 1.0
+  conditioning: discrete # 'discrete' (class_labels) or 'continuous' (encoder_hidden_states)
+  cfg_dropout: 0.0 # fraction of labels dropped for CFG training (0 disables)
+  ema_sigma_rels: [.02, .10] # e.g. [0.02, 0.10] to enable post-hoc EMA tracking
+  ema_update_every: 1 # optimizer steps between EMA updates (Karras assumes per-step)
+  ema_burn_in: 1000 # skip first N optimizer steps before starting EMA tracking
+  min_snr_gamma: null # set to e.g. 5.0 to enable Min-SNR loss weighting (Hang et al. 2023)
+  sigma_log_normal: null # set to e.g. [-1.2, 1.2] for EDM-style log-σ ~ Normal(P_mean, P_std) sampling
+  verbose: true
+  force_cpu: false
+  pin_memory: false
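Two of the training options above, `min_snr_gamma` and `sigma_log_normal`, reduce to short formulas. The sketch below shows the Min-SNR weights as commonly stated for each prediction type, and an EDM-style draw where the log of sigma is normally distributed. The function names and the `prediction_type` strings are assumptions for illustration; how `cosmodiff` wires these into its loss is not shown in this diff.

```python
import math
import random


def min_snr_weight(snr, gamma=5.0, prediction_type="epsilon"):
    """Per-timestep loss weight with the SNR clipped at ``gamma``."""
    clipped = min(snr, gamma)
    if prediction_type == "epsilon":
        return clipped / snr
    if prediction_type == "v_prediction":
        return clipped / (snr + 1.0)
    if prediction_type == "sample":
        return clipped
    raise ValueError(f"unknown prediction_type: {prediction_type!r}")


def sample_sigma_log_normal(p_mean=-1.2, p_std=1.2, rng=random):
    """Draw sigma such that log(sigma) ~ Normal(p_mean, p_std)."""
    return math.exp(rng.gauss(p_mean, p_std))


print(min_snr_weight(10.0, gamma=5.0))                  # 0.5: high-SNR steps are down-weighted
print(min_snr_weight(1.0, gamma=5.0))                   # 1.0: low-SNR steps are untouched
print(min_snr_weight(4.0, gamma=5.0, prediction_type="v_prediction"))  # 0.8
print(sample_sigma_log_normal(rng=random.Random(0)))    # a positive noise level
```

The clipping keeps nearly clean (high-SNR) timesteps from dominating the loss, while the log-normal draw concentrates training on intermediate noise levels around `exp(P_mean)`.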

+# --- generation ---
+generate:
+  scheduler: null # new scheduler for inference, default is train scheduler
+  num_steps: null # number of inference steps, default is training steps
+  s_churn: null # EDM stochasticity (Euler/Heun-family only); 0=ODE, larger=more SDE-like
+  s_tmin: null # restrict churn to t >= s_tmin
+  s_tmax: null # restrict churn to t <= s_tmax
+  s_noise: null # noise magnitude multiplier during churn (default 1.0)
+  n_samples: 64 # total number of samples to generate
+  batch_size: null # samples per forward pass; null → single batch of n_samples
+  image_shape: null # e.g. [1, 64, 64]; null → inferred from model.config
+  conditioning: discrete # 'discrete' or 'continuous' (must match training)
+  labels: null # discrete: list of int class labels (length 1 or n_samples)
+  continuous_labels: null # continuous: path to .npy of shape (n_samples, D) or (1, D)
+  guidance_scale: null # CFG guidance scale; null disables amplification
+  ema_sigma_rel: null # synthesize EMA at this target before sampling; null disables
+  seed: null # set for reproducibility
+  device: null # null → cuda if available else cpu
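`cfg_dropout` (training) and `guidance_scale` (generation) are the two halves of classifier-free guidance: labels are randomly replaced by a null token during training so the model also learns an unconditional prediction, and at sampling time the conditional and unconditional predictions are extrapolated. A minimal sketch on plain lists follows; the `null_label` value and both function names are hypothetical, and the exact convention used by `cosmodiff.optim` is not visible in this diff.

```python
import random


def drop_labels(labels, p, null_label=-1, rng=random):
    """Training side: replace each label with ``null_label`` with probability ``p``."""
    return [null_label if rng.random() < p else y for y in labels]


def cfg_combine(cond, uncond, guidance_scale):
    """Sampling side: uncond + scale * (cond - uncond), element-wise."""
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]


print(drop_labels([0, 1, 2, 3], p=0.0))           # [0, 1, 2, 3]: dropout disabled
print(drop_labels([0, 1, 2, 3], p=1.0))           # [-1, -1, -1, -1]: fully unconditional

cond, uncond = [1.0, 2.0], [0.0, 1.0]
print(cfg_combine(cond, uncond, 1.0))             # [1.0, 2.0]: plain conditional prediction
print(cfg_combine(cond, uncond, 3.0))             # [3.0, 4.0]: pushed away from unconditional
```

With this formula a scale of 1 reproduces the conditional prediction, so guidance only has an effect for scales above 1 and only if the model was trained with a non-zero `cfg_dropout`.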