Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions cosmodiff/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,125 @@ def generate(
images = renorm(images)

return images


class PCAEncoder(torch.nn.Module):
"""Linear PCA encoder as an nn.Module.

Projects ``(N, C, H, W)`` images onto the top-k principal components,
returning ``(N, rank)`` feature vectors.
"""

def __init__(self, V: torch.Tensor, center: torch.Tensor):
super().__init__()
self.register_buffer("V", V)
self.register_buffer("center", center)

def forward(self, images: torch.Tensor) -> torch.Tensor:
"""Args:
images: ``(N, C, H, W)`` float64 tensor.
"""
return (images.flatten(1) - self.center) @ self.V


def build_pca_encoder(train_images: torch.Tensor, rank: int = 16) -> PCAEncoder:
"""Build a PCAEncoder from training images via low-rank SVD.

A lightweight alternative to a learned encoder for use with
``compute_fid`` / ``compute_kid`` when no domain-specific encoder is
available.

Args:
train_images: ``(N, C, H, W)`` float64 tensor of training images.
rank: Number of principal components to retain.

Returns:
PCAEncoder mapping ``(N, C, H, W)`` images to ``(N, rank)`` features.
"""
X = train_images.flatten(1)
center = X.mean(0)
_, _, Vt = torch.linalg.svd(X - center, full_matrices=False)
V = Vt[:rank].T # (D, rank)
return PCAEncoder(V, center)


def _sqrtm_sym(A: torch.Tensor) -> torch.Tensor:
"""Matrix square root for a symmetric PSD matrix via eigendecomposition."""
L, V = torch.linalg.eigh(A)
return V @ torch.diag(L.clamp(min=0).sqrt()) @ V.mT


def compute_fid(feats_real: torch.Tensor, feats_fake: torch.Tensor) -> float:
"""Fréchet Inception Distance from pre-computed feature vectors.

Args:
feats_real: ``(N, d)`` float64 feature tensor for real samples.
feats_fake: ``(M, d)`` float64 feature tensor for generated samples.

Returns:
Scalar FID value.
"""
mu_r, mu_g = feats_real.mean(0), feats_fake.mean(0)
r, g = feats_real, feats_fake
sigma_r = (r - mu_r).T @ (r - mu_r) / (len(r) - 1)
sigma_g = (g - mu_g).T @ (g - mu_g) / (len(g) - 1)

sqrt_sigma_r = _sqrtm_sym(sigma_r)
M = sqrt_sigma_r @ sigma_g @ sqrt_sigma_r
trace_covmean = torch.linalg.eigvalsh(M).clamp(min=0).sqrt().sum()

diff = mu_r - mu_g
fid = diff @ diff + torch.trace(sigma_r) + torch.trace(sigma_g) - 2 * trace_covmean
return fid.item()


def compute_kid(
feats_real: torch.Tensor,
feats_fake: torch.Tensor,
degree: int = 3,
gamma: Optional[float] = None,
coef: float = 1.0,
subset_size: int = 1000,
n_subsets: int = 10,
) -> tuple[float, float]:
"""Kernel Inception Distance (polynomial MMD) from pre-computed features.

Preferred over FID when sample counts are small (~2k), as the estimator is
unbiased and has lower variance than FID's covariance-based approach.

Args:
feats_real: ``(N, d)`` float64 feature tensor for real samples.
feats_fake: ``(M, d)`` float64 feature tensor for generated samples.
degree: Polynomial kernel degree. Defaults to ``3``.
gamma: Kernel scale; defaults to ``1/d``.
coef: Kernel offset. Defaults to ``1.0``.
subset_size: Samples per subset for the MMD estimate.
n_subsets: Number of random subsets to average over.

Returns:
``(mean_kid, std_kid)`` across subsets.
"""
_gamma = 1.0 / feats_real.shape[1] if gamma is None else gamma

def poly_kernel(x, y):
return (_gamma * (x @ y.mT) + coef) ** degree

scores = []
for _ in range(n_subsets):
ri = feats_real[torch.randperm(len(feats_real), device=feats_real.device)[:subset_size]]
gi = feats_fake[torch.randperm(len(feats_fake), device=feats_fake.device)[:subset_size]]
n = subset_size

kxx = poly_kernel(ri, ri)
kyy = poly_kernel(gi, gi)
kxy = poly_kernel(ri, gi)

mmd2 = (
(kxx.sum() - kxx.trace()) / (n * (n - 1))
+ (kyy.sum() - kyy.trace()) / (n * (n - 1))
- 2 * kxy.mean()
)
scores.append(mmd2.item())

scores_t = torch.tensor(scores)
return float(scores_t.mean()), float(scores_t.std())
74 changes: 73 additions & 1 deletion cosmodiff/tests/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from diffusers import UNet2DModel, DDPMScheduler, DDIMScheduler, DiTTransformer2DModel
from cosmodiff.utils import load_checkpoint, ArrayDataset
from cosmodiff.optim import train, generate
from cosmodiff.optim import train, generate, compute_fid, compute_kid, build_pca_encoder
from cosmodiff.augment import RandomRoll, RandomFlip


Expand Down Expand Up @@ -224,4 +224,76 @@ def test_generate_reproducible():
assert torch.allclose(out1, out2)


# ------------------------------------------------------------------ #
# Helpers for FID / KID tests #
# ------------------------------------------------------------------ #


def _make_features(n: int, seed: int, mean: float = 0.0, std: float = 1.0):
torch.manual_seed(seed)
return (torch.randn(n, 1, 16, 16) * std + mean).double()


# ------------------------------------------------------------------ #
# FID tests #
# ------------------------------------------------------------------ #

def test_fid_finite_and_nonneg():
"""FID is finite and non-negative."""
train_imgs = _make_features(500, seed=0)
fake_imgs = _make_features(500, seed=1)
encode = build_pca_encoder(train_imgs, rank=16)
fid = compute_fid(encode(train_imgs), encode(fake_imgs))
assert torch.isfinite(torch.tensor(fid))
assert fid >= 0.0


def test_fid_smaller_same_dist():
"""FID is smaller when fake comes from the same distribution vs a shifted one."""
torch.manual_seed(0)
train_imgs = torch.randn(500, 1, 16, 16).double()
encode = build_pca_encoder(train_imgs, rank=16)
feats_real = encode(train_imgs)

torch.manual_seed(1)
feats_same = encode(torch.randn(500, 1, 16, 16).double())

torch.manual_seed(2)
feats_diff = encode((torch.randn(500, 1, 16, 16) * 3 + 5).double())

assert compute_fid(feats_real, feats_same) < compute_fid(feats_real, feats_diff)


# ------------------------------------------------------------------ #
# KID tests #
# ------------------------------------------------------------------ #

def test_kid_finite_and_nonneg():
"""KID mean is finite; std is non-negative."""
train_imgs = _make_features(500, seed=0)
fake_imgs = _make_features(500, seed=1)
encode = build_pca_encoder(train_imgs, rank=16)
mean_kid, std_kid = compute_kid(encode(train_imgs), encode(fake_imgs), subset_size=200, n_subsets=5)
assert torch.isfinite(torch.tensor(mean_kid))
assert std_kid >= 0.0


def test_kid_smaller_same_dist():
"""KID mean is smaller when fake comes from the same distribution vs a shifted one."""
torch.manual_seed(0)
train_imgs = torch.randn(500, 1, 16, 16).double()
encode = build_pca_encoder(train_imgs, rank=16)
feats_real = encode(train_imgs)

torch.manual_seed(1)
feats_same = encode(torch.randn(500, 1, 16, 16).double())

torch.manual_seed(2)
feats_diff = encode((torch.randn(500, 1, 16, 16) * 3 + 5).double())

kid_same, _ = compute_kid(feats_real, feats_same, subset_size=200, n_subsets=5)
kid_diff, _ = compute_kid(feats_real, feats_diff, subset_size=200, n_subsets=5)
assert kid_same < kid_diff



Loading