Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cosmodiff/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ data:
label_path: /path/to/labels.npy # optional
label_read_fn: npy_read_fn # optional
log: false
minmax: true
norm: center-scale
two_dim: true
zthin: 4
n_samples: null
Expand Down
24 changes: 12 additions & 12 deletions cosmodiff/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def _make_array(n=20, nz=4, nx=8, ny=8):

def test_n_samples():
arr = _make_array(n=20)
images, _ = load_data(arr, img_read_fn=None, minmax=False, two_dim=False)
images, _ = load_data(arr, img_read_fn=None, norm=None, two_dim=False)
assert images.shape[0] == 20

images, _ = load_data(arr, img_read_fn=None, n_samples=7, minmax=False, two_dim=False)
images, _ = load_data(arr, img_read_fn=None, n_samples=7, norm=None, two_dim=False)
assert images.shape[0] == 7

arr_t = torch.as_tensor(arr)
Expand All @@ -47,7 +47,7 @@ def test_n_samples_labels_in_sync():
arr, img_read_fn=None,
label_path=labels, label_read_fn=None,
n_samples=7, seed=0,
minmax=False, two_dim=False,
norm=None, two_dim=False,
)
assert images.shape[0] == 7
assert out_labels.shape[0] == 7
Expand All @@ -60,9 +60,9 @@ def test_n_samples_labels_in_sync():

def test_seed():
    """Check that equal seeds give identical subsamples and different seeds do not."""
    arr = _make_array(n=50)
    # norm=None disables normalization so only the seeded subsampling is compared
    imgs1, _ = load_data(arr, img_read_fn=None, n_samples=10, seed=42, norm=None, two_dim=False)
    imgs2, _ = load_data(arr, img_read_fn=None, n_samples=10, seed=42, norm=None, two_dim=False)
    imgs3, _ = load_data(arr, img_read_fn=None, n_samples=10, seed=99, norm=None, two_dim=False)
    assert torch.allclose(imgs1, imgs2)
    assert not torch.allclose(imgs1, imgs3)

Expand All @@ -72,8 +72,8 @@ def test_memmap():
with tempfile.NamedTemporaryFile(suffix=".npy") as f:
np.save(f.name, arr)
mmap = np.load(f.name, mmap_mode="r")
images_all, _ = load_data(mmap, img_read_fn=None, minmax=False, two_dim=False)
images_sub, _ = load_data(mmap, img_read_fn=None, n_samples=5, minmax=False, two_dim=False)
images_all, _ = load_data(mmap, img_read_fn=None, norm=None, two_dim=False)
images_sub, _ = load_data(mmap, img_read_fn=None, n_samples=5, norm=None, two_dim=False)
assert images_all.shape[0] == 10
assert images_sub.shape[0] == 5

Expand Down Expand Up @@ -109,9 +109,9 @@ def test_minmax_norm():

def test_center_scale_norm():
    """Check that center_scale_norm output is mean-centered and bounded by 1 in magnitude."""
    x = torch.randn(100)
    out = center_scale_norm(x.clone())
    # centered on the mean
    assert abs(out.mean().item()) < 0.1
    # scaled by max-abs; small tolerance for float round-off
    assert out.abs().max().item() <= 1.0 + 1e-6


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_parse_config_data():
"img_path": img_path,
"img_read_fn": "npy_read_fn",
"log": False,
"minmax": True,
"norm": "min-max",
"two_dim": True,
"zthin": 1,
"keep_on_cpu": True,
Expand Down
36 changes: 20 additions & 16 deletions cosmodiff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def load_data(
label_path: str | np.ndarray | None = None,
label_read_fn: Optional[callable] = None,
log: bool = False,
minmax: bool = True,
norm: str | None = None,
two_dim: bool = True,
zthin: int = 1,
n_samples: int | None = None,
Expand Down Expand Up @@ -107,8 +107,8 @@ def load_data(
if ``label_path`` is already a numpy array. Defaults to ``None``.
log (bool): Apply a log transform to images before normalization.
Defaults to ``False``.
minmax (bool): Normalize images to ``[-1, 1]`` via min-max scaling.
Defaults to ``True``.
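norm (str | None): Image normalization. ``"min-max"`` scales to ``[-1, 1]``;
``"center-scale"`` subtracts the mean and divides by the maximum absolute
value; ``None`` skips normalization. Defaults to ``None``.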
two_dim (bool): If ``True``, reshape the data to treat each z-slice
as an independent 2D image. If ``False``, treat each sample as a
3D volume. Defaults to ``True``.
Expand Down Expand Up @@ -179,9 +179,11 @@ def read_images(path):
if log:
images = images.log()

if minmax:
images = images - images.min()
images = images / images.max() * 2 - 1.0
if norm is not None:
if norm == 'center-scale':
images = center_scale_norm(images)
elif norm == 'min-max':
images = minmax_norm(images)

# --- reshape --------------------------------------------------------
if two_dim:
Expand Down Expand Up @@ -266,27 +268,29 @@ def minmax_norm(x: torch.Tensor) -> torch.Tensor:
return x * 2 - 1


def center_scale_norm(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    """Center a tensor on its mean and scale it by its maximum absolute value.

    After centering, every element is divided by the largest absolute
    value, so the output lies in ``[-1, 1]``. A constant input has zero
    spread after centering and maps to all zeros instead of NaN.

    Args:
        x (torch.Tensor): Floating-point input tensor of any shape.
        inplace (bool): If ``True``, modify ``x`` in place and return it.
            If ``False``, leave ``x`` untouched and return a new tensor.
            Defaults to ``False``.

    Returns:
        torch.Tensor: The centered and scaled tensor.
    """
    # center
    avg = x.mean()
    if inplace:
        x -= avg
    else:
        # out-of-place subtraction allocates a fresh tensor, so the in-place
        # division below never touches the caller's data
        x = x - avg

    # scale by max-abs; the floor keeps the divisor non-zero for constant input
    max_abs = x.abs().max().clamp_min(torch.finfo(x.dtype).tiny)
    x /= max_abs

    return x


def parse_config_model(config: dict):
Expand Down Expand Up @@ -387,7 +391,7 @@ def parse_config_data(config: dict):
device=device,
dtype=dtype,
label_read_fn=label_read_fn,
minmax=data_cfg.get("minmax", True),
norm=data_cfg.get("norm", 'center-scale'),
two_dim=data_cfg.get("two_dim", True),
zthin=data_cfg.get("zthin", 1),
)
Expand Down
Loading