Commit b4d202c
Authored by Patrick Miles (PatrickRMiles) and michaelmckinsey1
Warmup changes: only warm a few batches; extract to separate method in trainer class (#43)

* apply optimizer every batch, not every epoch; unscale gradients before clipping
* trainer tweaks
* apply optimizer every batch, not every epoch; unscale gradients before clipping
* extract warmup to separate method; switch to warming up set number of batches (user configurable)
* whitespace; num_workers revert
* ruff
* make parallelstrategy, spatial_mesh, ddp_placements attrs of trainer; other small tweaks
* remove deprecated config attrs
* ruff
* get device mesh from ps class attr
* ruff
* missing self. on some ps accesses
* Fix imports and missing self.ps
* rm legacy warmup_epochs
* Move attributes to base class for clarity
* remove warmup_epochs -- not useful to keep support for this
* call cleanup_or_resume trainer method directly
* rm unused vars

Co-authored-by: Patrick Miles <miles30@tioga.llnl.gov>
Co-authored-by: Michael McKinsey <michaelmckinsey1@gmail.com>
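The core change in the bullets above is warming up a fixed, user-configurable number of batches instead of whole epochs, with the logic pulled out into its own trainer method. A minimal sketch of that shape (class, method, and attribute names here are illustrative assumptions, not the actual ScaFFold trainer code):

```python
import itertools


class Trainer:
    """Hypothetical trainer sketch with batch-count-based warmup."""

    def __init__(self, warmup_batches=5):
        # Mirrors the new `warmup_batches` config knob (per-rank batch count).
        self.warmup_batches = warmup_batches

    def _run_batch(self, batch):
        # Stand-in for a real forward/backward/optimizer step.
        return sum(batch)

    def warmup(self, loader):
        """Run a fixed number of batches before timed training; return the count."""
        if not self.warmup_batches:
            return 0
        count = 0
        # islice stops after warmup_batches even if the loader has more data.
        for batch in itertools.islice(loader, self.warmup_batches):
            self._run_batch(batch)
            count += 1
        return count


loader = iter([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
trainer = Trainer(warmup_batches=5)
print(trainer.warmup(loader))  # 5 -- warms exactly 5 batches, not the full loader
```

Counting batches rather than epochs keeps warmup cost fixed as dataset size grows, which matters for a benchmark harness.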
Parent: 875f2fd

6 files changed: 174 additions, 172 deletions

File tree

ScaFFold/cli.py (5 additions, 0 deletions)

@@ -140,6 +140,11 @@ def main():
     benchmark_parser.add_argument(
         "--batch-size", type=int, nargs="+", help="Batch sizes for each volume size."
     )
+    benchmark_parser.add_argument(
+        "--warmup-batches",
+        type=int,
+        help="Number of warmup batches to run per rank before training.",
+    )
     benchmark_parser.add_argument(
         "--optimizer",
         type=str,
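The new flag is a plain optional integer argument, so it defaults to `None` when omitted. A self-contained sketch of that parsing behavior (the parser below is a stand-in for illustration, not the real ScaFFold CLI):

```python
import argparse

# Illustrative parser; ScaFFold builds its own subcommand parsers.
parser = argparse.ArgumentParser(prog="scaffold-benchmark")
parser.add_argument(
    "--warmup-batches",
    type=int,
    help="Number of warmup batches to run per rank before training.",
)

args = parser.parse_args(["--warmup-batches", "5"])
print(args.warmup_batches)  # 5

omitted = parser.parse_args([])
print(omitted.warmup_batches)  # None when the flag is not given
```

The `None` default lets downstream code distinguish "not requested" from an explicit count of zero.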

ScaFFold/configs/benchmark_default.yml (2 additions, 2 deletions)

@@ -29,6 +29,6 @@ framework: "torch" # The DL framework to train with. Only valid
 checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
 loss_freq: 1 # Number of epochs between logging the overall loss.
 normalize: 1 # Cateogry search normalization parameter
-warmup_epochs: 1 # How many warmup epochs before training
+warmup_batches: 5 # How many warmup batches per rank to run before training.
 dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
-target_dice: 0.95
+target_dice: 0.95

ScaFFold/configs/benchmark_testing.yml (1 addition, 1 deletion)

@@ -29,6 +29,6 @@ framework: "torch" # The DL framework to train with. Only valid
 checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
 loss_freq: 1 # Number of epochs between logging the overall loss.
 normalize: 1 # Cateogry search normalization parameter
-warmup_epochs: 1 # How many warmup epochs before training
+warmup_batches: 5 # How many warmup batches per rank to run before training.
 dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
 target_dice: 0.95

ScaFFold/utils/config_utils.py (1 addition, 1 deletion)

@@ -66,7 +66,7 @@ def __init__(self, config_dict):
         self.loss_freq = config_dict["loss_freq"]
         self.checkpoint_dir = config_dict["checkpoint_dir"]
         self.normalize = config_dict["normalize"]
-        self.warmup_epochs = config_dict["warmup_epochs"]
+        self.warmup_batches = config_dict.get("warmup_batches")
         self.dataset_reuse_enforce_commit_id = config_dict[
             "dataset_reuse_enforce_commit_id"
         ]
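Note the switch from bracket access to `config_dict.get(...)`: older config files that still lack a `warmup_batches` key now load with `None` instead of raising `KeyError`. A quick standalone illustration of the difference:

```python
config_dict = {"loss_freq": 1}  # an old-style config without the new key

# Bracket access raises for a missing key:
try:
    config_dict["warmup_batches"]
    raised = False
except KeyError:
    raised = True
print(raised)  # True

# .get() returns None instead, so older configs keep loading:
print(config_dict.get("warmup_batches"))  # None

# .get() also accepts an explicit fallback:
print(config_dict.get("warmup_batches", 5))  # 5
```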
