Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 97 additions & 35 deletions cosmodiff/optim.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pickle
import yaml
import numpy as np
import torch
import torch.nn.functional as F
Expand All @@ -14,14 +15,24 @@
from . import utils


def _to_yaml_safe(obj):
"""Recursively convert tuples to lists so yaml.safe_load can round-trip."""
if isinstance(obj, dict):
return {k: _to_yaml_safe(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_to_yaml_safe(v) for v in obj]
return obj


def train(
dataset,
model,
model=None,
*,
noise_scheduler=None, # None → DDPMScheduler(num_train_timesteps=1000)
optimizer=None, # None → AdamW
lr_scheduler=None, # None → ConstantLR()
output_dir: str = "checkpoints",
resume_from_checkpoint: Optional[str] = None,
num_epochs: int = 50,
batch_size: int = 16,
shuffle: bool = True,
Expand All @@ -36,11 +47,10 @@ def train(
):
"""Train a diffusers diffusion model.

To resume from a checkpoint, load it first with ``load_checkpoint()`` and
pass the returned objects directly into this function. Augmentations are
expected to be built into the dataset object directly (e.g. via
``ArrayDataset.augmentations``), and are checkpointed automatically if the
dataset exposes an ``augmentations`` attribute.
``model`` is optional when ``resume_from_checkpoint`` is set — the model
(and any unspecified scheduler/optimizer/augmentations) are loaded from the
checkpoint automatically. Augmentations are checkpointed automatically if
the dataset exposes an ``augmentations`` attribute.

The model's forward call is dispatched automatically: if the batch contains
``"labels"``, they are passed as a keyword argument (for
Expand All @@ -54,8 +64,9 @@ def train(
``"labels"`` key (LongTensor of shape ``(batch_size,)``) for
class-conditional DiT training. Augmentations should be applied
inside the dataset's ``__getitem__``.
model (nn.Module): Pre-instantiated diffusers model (e.g.
``UNet2DModel``, ``DiTTransformer2DModel``).
model (nn.Module, optional): Pre-instantiated diffusers model (e.g.
``UNet2DModel``, ``DiTTransformer2DModel``). May be omitted when
``resume_from_checkpoint`` is provided.
noise_scheduler (optional): Pre-instantiated diffusers noise scheduler.
Defaults to ``DDPMScheduler(num_train_timesteps=1000)``.
optimizer (torch.optim.Optimizer, optional): Optimizer for ``model``.
Expand All @@ -65,6 +76,12 @@ def train(
fixed learning rate for the entire run.
output_dir (str): Root directory for checkpoints and TensorBoard logs.
Defaults to ``"checkpoints"``.
resume_from_checkpoint (str, optional): Path to a checkpoint directory
produced by a previous call to ``train()``. Objects not explicitly
passed (model, noise_scheduler, optimizer, lr_scheduler) are loaded
from the checkpoint. After ``accelerator.prepare()`` the full
training state (optimizer moments, grad scaler, RNG) is restored
via ``accelerator.load_state()``.
num_epochs (int): Total number of training epochs. Defaults to ``50``.
batch_size (int): Per-device batch size. Defaults to ``16``.
shuffle (bool): Shuffle the dataset each epoch. Defaults to ``True``.
Expand Down Expand Up @@ -104,23 +121,29 @@ def train(
# dataset must return dicts with "images" and "labels" keys
train(my_dataset, model)

Resume from a checkpoint::
Resume from a checkpoint (model loaded automatically)::

model, noise_scheduler, optimizer, lr_scheduler, augmentations = (
load_checkpoint("checkpoints/checkpoint-epoch-10")
)
dataset.augmentations = augmentations
train(
my_dataset,
model,
noise_scheduler=noise_scheduler,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
train(my_dataset, resume_from_checkpoint="checkpoints/checkpoint-epoch-0010")
"""
# ------------------------------------------------------------------ #
# 1. Defaults #
# 1. Defaults / checkpoint loading #
# ------------------------------------------------------------------ #
start_epoch = 0

if model is None and resume_from_checkpoint is None:
raise ValueError(
"Either `model` or `resume_from_checkpoint` must be provided."
)

if resume_from_checkpoint is not None:
model, noise_scheduler, optimizer, lr_scheduler, _aug = (
utils.load_checkpoint(resume_from_checkpoint)
)
if isinstance(dataset, utils.ArrayDataset):
dataset.augmentations = _aug

start_epoch = int(resume_from_checkpoint.split('-')[-1]) + 1

if noise_scheduler is None:
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

Expand All @@ -142,6 +165,26 @@ def train(
)
accelerator.init_trackers(project_name="cosmodiff")

# Register hooks so save_state() delegates model serialisation to
# save_pretrained() and load_state() restores via from_pretrained(),
# avoiding a redundant second copy of the weights on disk.
def _save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for m in models:
m.save_pretrained(output_dir)
weights.clear()

def _load_model_hook(models, input_dir):
for _ in range(len(models)):
m = models.pop()
loaded = m.__class__.from_pretrained(input_dir)
m.register_to_config(**loaded.config)
m.load_state_dict(loaded.state_dict())
del loaded

accelerator.register_save_state_pre_hook(_save_model_hook)
accelerator.register_load_state_pre_hook(_load_model_hook)

# ------------------------------------------------------------------ #
# 3. DataLoader #
# ------------------------------------------------------------------ #
Expand All @@ -160,6 +203,10 @@ def train(
model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, dataloader, lr_scheduler
)

if resume_from_checkpoint is not None:
accelerator.load_state(resume_from_checkpoint)

# ------------------------------------------------------------------ #
# 5. Training loop #
# ------------------------------------------------------------------ #
Expand All @@ -174,10 +221,10 @@ def train(
"epoch_lr": [],
}

for epoch in range(num_epochs):
for epoch in range(start_epoch, start_epoch + num_epochs):
progress = tqdm(
dataloader,
desc=f"Epoch {epoch}/{num_epochs - 1}",
desc=f"Epoch {epoch}/{start_epoch + num_epochs - 1}",
disable=not verbose or not accelerator.is_local_main_process,
)

Expand Down Expand Up @@ -258,17 +305,35 @@ def train(
# ---------------------------------------------------------------- #
# 6. Checkpointing #
# ---------------------------------------------------------------- #
if (epoch + 1) % checkpoint_every_n_epochs == 0 or epoch == num_epochs - 1:
if (epoch + 1) % checkpoint_every_n_epochs == 0 or epoch == (start_epoch + num_epochs - 1):
if accelerator.is_main_process:
ckpt_save_path = os.path.join(output_dir, f"checkpoint-epoch-{epoch:04d}")

# Noise scheduler config (needed by SchedulerClass.from_pretrained)
noise_scheduler.save_pretrained(ckpt_save_path)

# Class names and constructor kwargs for fresh reconstruction:
# this is only needed when resuming training from a checkpoint.
raw_opt = optimizer.optimizer
raw_sched = lr_scheduler.scheduler
ckpt_cfg = {
"noise_scheduler": {
"class": f"{noise_scheduler.__class__.__module__}.{noise_scheduler.__class__.__name__}",
},
"optimizer": {
"class": f"{raw_opt.__class__.__module__}.{raw_opt.__class__.__name__}",
},
"lr_scheduler": {
"class": f"{raw_sched.__class__.__module__}.{raw_sched.__class__.__name__}",
"kwargs": utils._get_lr_scheduler_kwargs(raw_sched),
},
}
with open(os.path.join(ckpt_save_path, "checkpoint_config.yaml"), "w") as f:
yaml.dump(_to_yaml_safe(ckpt_cfg), f)

# Model weights (via hook) + optimizer moments + grad scaler + RNG
accelerator.save_state(ckpt_save_path)
accelerator.unwrap_model(model).save_pretrained(ckpt_save_path)
with open(os.path.join(ckpt_save_path, "optimizer.pkl"), "wb") as f:
pickle.dump(optimizer.optimizer, f)
with open(os.path.join(ckpt_save_path, "noise_scheduler.pkl"), "wb") as f:
pickle.dump(noise_scheduler, f)
with open(os.path.join(ckpt_save_path, "lr_scheduler.pkl"), "wb") as f:
pickle.dump(lr_scheduler.scheduler, f)

if hasattr(dataset, "augmentations") and dataset.augmentations is not None:
with open(os.path.join(ckpt_save_path, "augmentations.pkl"), "wb") as f:
pickle.dump(dataset.augmentations, f)
Expand Down Expand Up @@ -356,10 +421,7 @@ def generate(
else:
noise_pred = model(images, timesteps, return_dict=False)[0]

step_kwargs = {}
if "generator" in noise_scheduler.step.__code__.co_varnames:
step_kwargs["generator"] = generator
images = noise_scheduler.step(noise_pred, t, images, **step_kwargs).prev_sample
images = noise_scheduler.step(noise_pred, t, images, generator=generator).prev_sample

if renorm is not None:
images = renorm(images)
Expand Down
28 changes: 17 additions & 11 deletions cosmodiff/tests/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch
from diffusers import UNet2DModel, DDPMScheduler, DDIMScheduler, DiTTransformer2DModel
from cosmodiff.utils import load_checkpoint, ArrayDataset
from cosmodiff.utils import load_checkpoint, ArrayDataset, find_latest_checkpoint
from cosmodiff.optim import train, generate, compute_fid, compute_kid, build_pca_encoder
from cosmodiff.augment import RandomRoll, RandomFlip

Expand Down Expand Up @@ -42,9 +42,9 @@ def test_train_basic():

ckpt_path = os.path.join(tmp_dir, "checkpoint-epoch-0001")
assert os.path.isdir(ckpt_path)
assert os.path.exists(os.path.join(ckpt_path, "noise_scheduler.pkl"))
assert os.path.exists(os.path.join(ckpt_path, "optimizer.pkl"))
assert os.path.exists(os.path.join(ckpt_path, "lr_scheduler.pkl"))
assert os.path.exists(os.path.join(ckpt_path, "config.json"))
assert os.path.exists(os.path.join(ckpt_path, "scheduler_config.json"))
assert os.path.exists(os.path.join(ckpt_path, "checkpoint_config.yaml"))
assert os.path.exists(os.path.join(ckpt_path, "augmentations.pkl"))
assert os.path.exists(os.path.join(ckpt_path, "metrics.json"))

Expand All @@ -64,12 +64,11 @@ def test_train_basic():
assert _lr_scheduler is not None
assert _augmentations is not None

# continue training
# continue training from checkpoint
initial_weights = model.conv_in.weight.data.clone()
metrics = train(
dataset,
model,
noise_scheduler=DDPMScheduler(num_train_timesteps=10),
resume_from_checkpoint=ckpt_path,
num_epochs=2,
batch_size=4,
checkpoint_every_n_epochs=2,
Expand All @@ -79,10 +78,17 @@ def test_train_basic():
verbose=False,
)

# get new checkpoint: ensure it is epoch-0003
ckpt_path2 = find_latest_checkpoint(tmp_dir)
assert int(ckpt_path2.split('-')[-1]) == 3
_model2, _noise_scheduler2, _optimizer2, _lr_scheduler2, _augmentations2 = (
load_checkpoint(ckpt_path2)
)

# training checks: finite output, and weights changed
assert all(torch.isfinite(torch.tensor(v)) for v in metrics["loss"])
assert all(torch.isfinite(torch.tensor(v)) for v in metrics["epoch_loss"])
assert not torch.allclose(model.conv_in.weight.data, initial_weights)
assert not torch.allclose(_model2.conv_in.weight.data, initial_weights)


def test_train_conditional_dit():
Expand Down Expand Up @@ -121,9 +127,9 @@ def test_train_conditional_dit():

ckpt_path = os.path.join(tmp_dir, "checkpoint-epoch-0001")
assert os.path.isdir(ckpt_path)
assert os.path.exists(os.path.join(ckpt_path, "noise_scheduler.pkl"))
assert os.path.exists(os.path.join(ckpt_path, "optimizer.pkl"))
assert os.path.exists(os.path.join(ckpt_path, "lr_scheduler.pkl"))
assert os.path.exists(os.path.join(ckpt_path, "config.json"))
assert os.path.exists(os.path.join(ckpt_path, "scheduler_config.json"))
assert os.path.exists(os.path.join(ckpt_path, "checkpoint_config.yaml"))
assert os.path.exists(os.path.join(ckpt_path, "augmentations.pkl"))
assert os.path.exists(os.path.join(ckpt_path, "metrics.json"))

Expand Down
Loading
Loading