Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 133 additions & 21 deletions data_handler/AnomalyDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,26 @@ def save_numpy_as_npy(

class AnomalyDataset(Dataset):
"""
PyTorch Dataset that loads 3D samples stored as `.npy` files from a folder.
PyTorch Dataset that loads 2D or 3D samples stored as `.npy` files from a folder.

Key behavior:
- The dataset is populated *only* from a folder (no manual add_sample/add_path).
- File format: `.npy` (not NIfTI).
- Optional: preload everything into RAM (load_to_ram=True).

Return format:
- return_filename=False -> x
- return_filename=True -> (x, fname)
- If return_filename=False -> returns x (or x, org_mask, tgt_mask)
- If return_filename=True -> returns (x, fname) (or x, org_mask, tgt_mask, fname)
where fname is the basename including extension (e.g., "sample_0.npy").
"""

def __init__(
self,
folder: Union[str, os.PathLike],
org_mask_folder: Union[str, os.PathLike],
tgt_mask_folder: Union[str, os.PathLike],
*,
conditional: bool = False,
return_filename: bool = True,
dtype: torch.dtype = torch.float32,
transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
Expand All @@ -89,6 +92,7 @@ def __init__(
load_to_ram: bool = False,
mmap_mode: Optional[str] = None,
numpy_mode: bool = False,
skip_missing_masks: bool = False,
) -> None:
"""
Initialize the dataset and discover `.npy` files.
Expand All @@ -97,8 +101,15 @@ def __init__(
------
folder:
Path to a directory containing `.npy` files.
org_mask_folder:
Path to a directory containing the original masks as `.npy` files.
tgt_mask_folder:
Path to a directory containing the target masks as `.npy` files.
conditional:
If True, original and target masks are loaded and returned alongside
the input data.
return_filename:
If True, __getitem__ returns (x, filename). Otherwise returns x only.
If True, __getitem__ returns the filename at the end of the tuple.
dtype:
Torch dtype used when converting numpy arrays into torch tensors.
(Ignored if numpy_mode=True.)
Expand All @@ -118,20 +129,25 @@ def __init__(
NOTE: If load_to_ram=True, mmap_mode is ignored (set to None).
numpy_mode:
If True, __getitem__ returns numpy arrays instead of torch tensors.

Outputs
-------
None
Side effects:
- scans the folder and stores file paths
- builds fast lookup dicts for basename and stem
- optionally preloads arrays into RAM
skip_missing_masks:
If True and conditional=True, samples without matching mask files
are omitted from the dataset.
"""
# Resolve folder path
# Resolve folder paths
self.folder = Path(folder).expanduser().resolve()
if not self.folder.is_dir():
raise FileNotFoundError(str(self.folder))

self.org_mask_folder = Path(org_mask_folder).expanduser().resolve()
if not self.org_mask_folder.is_dir():
raise FileNotFoundError(f"org_mask_folder not found: {self.org_mask_folder}")

self.tgt_mask_folder = Path(tgt_mask_folder).expanduser().resolve()
if not self.tgt_mask_folder.is_dir():
raise FileNotFoundError(f"tgt_mask_folder not found: {self.tgt_mask_folder}")

self.conditional = conditional

# Store configuration flags
self.numpy_mode = numpy_mode
self.return_filename = return_filename
Expand All @@ -141,12 +157,15 @@ def __init__(
self.extensions = tuple(e.lower() for e in extensions)
self.sort = sort
self.load_to_ram = load_to_ram
self.skip_missing_masks = skip_missing_masks

# mmap_mode is only relevant when we DO NOT preload into RAM
self.mmap_mode = None if load_to_ram else mmap_mode

# Collect all .npy files into a list
self._paths: List[str] = self._collect_paths()
if self.conditional and self.skip_missing_masks:
self._paths = self._filter_paths_with_existing_masks(self._paths)

# Fast lookup tables:
# - basename -> [indices]
Expand All @@ -163,11 +182,37 @@ def __init__(

# Optional preload into RAM for faster access during training/inference
self._ram_arrays: Optional[List[np.ndarray]] = None
self._ram_org_masks: Optional[List[np.ndarray]] = None
self._ram_tgt_masks: Optional[List[np.ndarray]] = None

if self.load_to_ram:
self._ram_arrays = []

if self.conditional:
self._ram_org_masks = []
self._ram_tgt_masks = []
same_mask_folders = (self.org_mask_folder == self.tgt_mask_folder)

for p in self._paths:
# allow_pickle=False for safety; loads full array into memory
self._ram_arrays.append(np.load(p, allow_pickle=False))

# Load masks if conditional is True
if self.conditional:
org_mask_path = os.path.join(self.org_mask_folder, os.path.basename(p))
if not os.path.exists(org_mask_path):
raise FileNotFoundError(f"Expected org_mask file missing: {org_mask_path}")

loaded_org_mask = np.load(org_mask_path, allow_pickle=False)
self._ram_org_masks.append(loaded_org_mask)

if same_mask_folders:
self._ram_tgt_masks.append(loaded_org_mask)
else:
tgt_mask_path = os.path.join(self.tgt_mask_folder, os.path.basename(p))
if not os.path.exists(tgt_mask_path):
raise FileNotFoundError(f"Expected tgt_mask file missing: {tgt_mask_path}")
self._ram_tgt_masks.append(np.load(tgt_mask_path, allow_pickle=False))

def _collect_paths(self) -> List[str]:
"""
Expand Down Expand Up @@ -209,6 +254,36 @@ def _collect_paths(self) -> List[str]:

return paths

def _filter_paths_with_existing_masks(self, paths: List[str]) -> List[str]:
    """
    Drop every sample path that has no same-named mask file.

    A path is kept only if a file with the same basename exists in
    `self.org_mask_folder` and in `self.tgt_mask_folder`. When both
    attributes compare equal, a single lookup covers both masks.
    (Masks can be missing for anomalies that were filtered out by e.g.
    min_anomaly_percentage.)

    Inputs
    ------
    paths:
        Candidate sample file paths.

    Outputs
    -------
    The subset of `paths` that has matching masks, in the original order.

    Raises
    ------
    FileNotFoundError:
        If no path at all has matching mask files.
    """
    shared_mask_folder = self.org_mask_folder == self.tgt_mask_folder

    def has_masks(sample_path: str) -> bool:
        # Masks are matched purely by basename (e.g. "sample_0.npy").
        name = os.path.basename(sample_path)
        if not os.path.exists(os.path.join(self.org_mask_folder, name)):
            return False
        if shared_mask_folder:
            # Same folder -> the org mask doubles as the tgt mask.
            return True
        return os.path.exists(os.path.join(self.tgt_mask_folder, name))

    # Evaluate each candidate exactly once, then split into kept / dropped.
    checked = [(sample_path, has_masks(sample_path)) for sample_path in paths]
    kept = [sample_path for sample_path, ok in checked if ok]
    n_dropped = len(checked) - len(kept)

    if n_dropped > 0:
        print(f"[AnomalyDataset] Skipped {n_dropped} anomaly file(s) without matching mask files.")

    if len(kept) == 0:
        raise FileNotFoundError(
            f"No anomaly files with matching mask files found in: {self.folder}"
        )

    return kept

def __len__(self) -> int:
"""
Number of samples in the dataset.
Expand Down Expand Up @@ -279,11 +354,9 @@ def __getitem__(self, idx: int):
Outputs
-------
If return_filename == True:
(x, fname)
- x: torch.Tensor (default) or np.ndarray (if numpy_mode=True)
- fname: str basename (e.g. "foo.npy")
(x, [org_mask, tgt_mask,] fname)
Else:
x only
x or (x, org_mask, tgt_mask)

Notes
-----
Expand All @@ -308,14 +381,52 @@ def __getitem__(self, idx: int):
else:
x = img_np

# Optional transform hook
# Optional transform hook (Intensity transforms only, don't apply to masks)
if self.transform is not None:
x = self.transform(x)

# Return format
if not self.conditional:
if self.return_filename:
return x, fname
return x

# conditional -> loads masks
if self._ram_org_masks is not None:
org_mask_np = self._ram_org_masks[idx]
tgt_mask_np = self._ram_tgt_masks[idx]
else:
org_mask_path = os.path.join(self.org_mask_folder, fname)
org_mask_np = np.load(org_mask_path, allow_pickle=False, mmap_mode=self.mmap_mode)

if self.org_mask_folder == self.tgt_mask_folder:
tgt_mask_np = org_mask_np
else:
tgt_mask_path = os.path.join(self.tgt_mask_folder, fname)
tgt_mask_np = np.load(tgt_mask_path, allow_pickle=False, mmap_mode=self.mmap_mode)

org_mask_tensor = self._to_tensor(org_mask_np)
tgt_mask_tensor = self._to_tensor(tgt_mask_np)

# remove channel dim for masks
org_mask_tensor = org_mask_tensor.long()
if org_mask_tensor.shape[0] == 1:
org_mask_tensor = org_mask_tensor.squeeze(0)

tgt_mask_tensor = tgt_mask_tensor.long()
if tgt_mask_tensor.shape[0] == 1:
tgt_mask_tensor = tgt_mask_tensor.squeeze(0)

if self.numpy_mode:
org_mask_out = org_mask_tensor.numpy()
tgt_mask_out = tgt_mask_tensor.numpy()
else:
org_mask_out = org_mask_tensor
tgt_mask_out = tgt_mask_tensor

if self.return_filename:
return x, fname
return x
return x, org_mask_out, tgt_mask_out, fname

return x, org_mask_out, tgt_mask_out
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lass uns das alles vereinfachen: Bisher war die Rückgabe immer x oder x, filename. Jetzt brauchen wir Masken. Also würde ich als neue Rückgabe x, mask oder x, mask, filename vorschlagen. Beim Training brauchst du nur x, ori_mask, und beim Generieren brauchst du nur x, tgt_mask. Wir brauchen zu keinem Zeitpunkt beides, oder? -> Somit können wir mit einem einfachen if im AnomalyDataset beide Fälle abdecken. Wenn wir uns generell, egal ob condition oder nicht, auf eine Rückgabeform einigen, sparen wir uns in der Pipeline viele Verschachtelungen, die unseren Hybrid Data Generator unnötig komplex machen.


def load_numpy_by_basename(self, basename: str) -> np.ndarray:
"""
Expand Down Expand Up @@ -371,3 +482,4 @@ def load_numpy_by_basename(self, basename: str) -> np.ndarray:
return self._ram_arrays[idx]

return np.load(self._paths[idx], allow_pickle=False, mmap_mode=self.mmap_mode)

4 changes: 2 additions & 2 deletions models/VAE_ConvNeXt_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,8 @@ def generate_synth_sample_prior(

if __name__ == "__main__":
# Quick sanity check
cfg = Config(n_res_blocks=2, n_levels=4, z_channels=64, bottleneck_dim=64)
model = ConvNeXtVAE2D(in_channels=1, cfg=cfg)
cfg = Config(in_channels=1, n_res_blocks=2, n_levels=4, z_channels=64, bottleneck_dim=64)
model = ConvNeXtVAE2D(cfg=cfg)
x = torch.randn(2, 1, 128, 128)
out = model(x)
print({k: tuple(v.shape) for k, v in out.items()})
Expand Down
4 changes: 2 additions & 2 deletions models/VAE_ConvNeXt_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,8 +789,8 @@ def warmup(self, shape, device=None, dtype=None):

if __name__ == "__main__":
# Quick sanity check
cfg = Config(n_res_blocks=2, n_levels=4, z_channels=64, bottleneck_dim=64)
model = ConvNeXtVAE3D(in_channels=1, cfg=cfg)
cfg = Config(in_channels=1, n_res_blocks=2, n_levels=4, z_channels=64, bottleneck_dim=64)
model = ConvNeXtVAE3D(cfg=cfg)
x = torch.randn(1, 1, 64, 64, 64)
out = model(x)
print({k: tuple(v.shape) for k, v in out.items()})
Loading