Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/style.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ jobs:
# Check format. To fix, run "ruff format ."
ruff format --diff .
# Check PEP8 violations, logic/correctness errors, and sort imports. To fix, run "ruff check --fix ."
ruff check --diff .
ruff check .
2 changes: 1 addition & 1 deletion ScaFFold/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def main():
combined_config["point_num"] = int(combined_config["vol_size"] ** 3 / 256)

# Handle Restart / Resume logic
if hasattr(args, "restart") and args.restart == True:
if hasattr(args, "restart") and args.restart:
print("Restart flag detected: Forcing train_from_scratch = False")
combined_config["train_from_scratch"] = False
combined_config["restart"] = True
Expand Down
4 changes: 2 additions & 2 deletions ScaFFold/datagen/get_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ def _git_commit_short() -> str:
)
except subprocess.CalledProcessError:
print(
f"Tried to get git commit id in non-git repo. No commit id will be enforced for dataset reuse."
"Tried to get git commit id in non-git repo. No commit id will be enforced for dataset reuse."
)
return "no-commit-id"
except Exception:
print(
f"Exception when trying to get git commit for dataset. No commit id will be enforced for dataset reuse."
"Exception when trying to get git commit for dataset. No commit id will be enforced for dataset reuse."
)
return "no-commit-id"

Expand Down
6 changes: 2 additions & 4 deletions ScaFFold/utils/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
#
# SPDX-License-Identifier: (Apache-2.0)

import copy
import math
import random
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, Optional
Expand Down Expand Up @@ -238,7 +236,7 @@ def save_checkpoint(
self.best_ckpt_path,
is_best,
)
self._log(f"Async checkpoint offloaded to background thread.")
self._log("Async checkpoint offloaded to background thread.")
else:
# Synchronous Save
self._write_to_disk(
Expand Down Expand Up @@ -314,7 +312,7 @@ def _get_rng_snapshot(self) -> Dict[str, Any]:
pass
try:
snap["rng_state_python"] = random.getstate()
except:
except Exception:
pass
return snap

Expand Down
2 changes: 1 addition & 1 deletion ScaFFold/utils/create_restart_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import stat
import sys
from pathlib import Path
from typing import List, Literal, Optional, Union
from typing import List, Literal, Union

import torch

Expand Down
2 changes: 1 addition & 1 deletion ScaFFold/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def force_cuda_visible_devices(force: bool = False) -> None:
other GPUs.

"""
print(f"force_cuda_visible_devices is deprecated. Skipping...")
print("force_cuda_visible_devices is deprecated. Skipping...")


def get_device() -> torch.device:
Expand Down
6 changes: 1 addition & 5 deletions ScaFFold/utils/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,15 @@

import math

import numpy as np
import torch
import torch.nn.functional as F
from distconv import DCTensor
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
from torch.distributed.tensor import DTensor, Replicate, Shard
from tqdm import tqdm

from ScaFFold.utils.dice_score import (
SpatialAllReduce,
compute_sharded_dice,
dice_coeff,
dice_loss,
multiclass_dice_coeff,
)
from ScaFFold.utils.perf_measure import annotate

Expand Down
14 changes: 5 additions & 9 deletions ScaFFold/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,14 @@
# SPDX-License-Identifier: (Apache-2.0)

# Standard library
import json
import math
import os
import random
import shutil
import time
from pathlib import Path

# Third party
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from distconv import DCTensor
Expand All @@ -35,7 +31,7 @@

from ScaFFold.utils.checkpointing import CheckpointManager
from ScaFFold.utils.data_loading import FractalDataset
from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice, dice_loss
from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice
from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size

# Local
Expand Down Expand Up @@ -389,9 +385,9 @@ def warmup(self):
enabled=self.config.torch_amp,
):
# Forward on DCTensor
self.log.debug(f" warmup: running forward pass")
self.log.debug(" warmup: running forward pass")
masks_pred_dc = self.model(images_dc)
self.log.debug(f" warmup: forward pass complete")
self.log.debug(" warmup: forward pass complete")

# Extract the underlying PyTorch local tensors
local_preds = masks_pred_dc
Expand Down Expand Up @@ -442,14 +438,14 @@ def warmup(self):
loss = loss_ce + loss_dice

self.log.debug(
f" warmup: loss calculation complete. Proceeding to backward pass"
" warmup: loss calculation complete. Proceeding to backward pass"
)

# Backward pass
self.grad_scaler.scale(loss).backward()
self.grad_scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.log.debug(f" warmup: backward pass complete. Stepping optimizer")
self.log.debug(" warmup: backward pass complete. Stepping optimizer")

self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
Expand Down
3 changes: 1 addition & 2 deletions ScaFFold/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from ScaFFold.unet import UNet
from ScaFFold.utils.distributed import (
get_device,
get_job_id,
get_local_rank,
get_local_size,
get_world_rank,
Expand Down Expand Up @@ -266,7 +265,7 @@ def main(kwargs_dict: dict = {}):
# Generate plots
#
if rank == 0:
log.info(f"Generating figures on rank 0...")
log.info("Generating figures on rank 0...")
begin_code_region("generate_figures")
standard_viz.main(config)
end_code_region("generate_figures")
Expand Down
Loading