Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mipcandy/common/optim/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.T
if not self.include_background:
outputs = outputs[:, 1:]
labels = labels[:, 1:]
dice = soft_dice(outputs, labels, smooth=self.smooth)
dice = soft_dice(outputs.float(), labels.float(), smooth=self.smooth)
metrics = {"soft dice": dice.item(), "ce loss": ce.item()}
c = self.lambda_ce * ce + self.lambda_soft_dice * (1 - dice)
return c, metrics
Expand Down Expand Up @@ -88,10 +88,10 @@ def __init__(self, *, lambda_bce: float = 1, lambda_soft_dice: float = 1,
self.min_percentage_per_class: float | None = min_percentage_per_class

def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
    """Combined BCE + soft-Dice loss for multi-label logits.

    Args:
        outputs: raw (pre-sigmoid) logits.
        labels: binary targets; cast to the logits' dtype before the loss.

    Returns:
        The weighted scalar loss and a dict of detached metric values.
    """
    # Cast the targets rather than the outputs so gradients flow through the
    # logits unchanged.
    labels = labels.to(dtype=outputs.dtype)
    # with_logits fuses sigmoid + BCE in one numerically stable op — this is
    # presumably what makes the loss safe under autocast/AMP.
    bce = nn.functional.binary_cross_entropy_with_logits(outputs, labels)
    outputs = outputs.sigmoid()
    # Dice is accumulated in float32; NOTE(review): presumably to avoid
    # half-precision issues when AMP is active — confirm.
    dice = soft_dice(outputs.float(), labels.float(), smooth=self.smooth)
    metrics = {"soft dice": dice.item(), "bce loss": bce.item()}
    c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - dice)
    return c, metrics
Expand Down
4 changes: 3 additions & 1 deletion mipcandy/presets/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT
outputs = list(torch.unbind(outputs, dim=1))
labels = self.prepare_deep_supervision_targets(labels, [m.shape[2:] for m in outputs])
loss, metrics = toolbox.criterion(outputs, labels)
loss.backward()
self._do_backward(loss, toolbox)
if toolbox.scaler:
toolbox.scaler.unscale_(toolbox.optimizer)
nn.utils.clip_grad_norm_(toolbox.model.parameters(), 12)
return loss.item(), metrics

Expand Down
52 changes: 42 additions & 10 deletions mipcandy/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class TrainerToolbox(object):
scheduler: optim.lr_scheduler.LRScheduler
criterion: nn.Module
ema: nn.Module | None = None
scaler: torch.amp.GradScaler | None = None


@dataclass
Expand Down Expand Up @@ -85,11 +86,14 @@ def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: Trainer
**training_arguments) -> None:
if self._unrecoverable:
return
torch.save({
state_dicts = {
"optimizer": toolbox.optimizer.state_dict(),
"scheduler": toolbox.scheduler.state_dict(),
"criterion": toolbox.criterion.state_dict()
}, f"{self.experiment_folder()}/state_dicts.pth")
}
if toolbox.scaler:
state_dicts["scaler"] = toolbox.scaler.state_dict()
torch.save(state_dicts, f"{self.experiment_folder()}/state_dicts.pth")
with open(f"{self.experiment_folder()}/state_orb.json", "w") as f:
dump({"tracker": asdict(tracker), "training_arguments": training_arguments}, f)

Expand All @@ -116,6 +120,9 @@ def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_m
toolbox.optimizer.load_state_dict(state_dicts["optimizer"])
toolbox.scheduler.load_state_dict(state_dicts["scheduler"])
toolbox.criterion.load_state_dict(state_dicts["criterion"])
if "scaler" in state_dicts:
toolbox.scaler = torch.amp.GradScaler(self._device_type())
Comment thread
perctrix marked this conversation as resolved.
toolbox.scaler.load_state_dict(state_dicts["scaler"])
return toolbox

def recover_from(self, experiment_id: str) -> Self:
Expand Down Expand Up @@ -391,6 +398,12 @@ def empty_cache(self) -> None:

# Training methods

def _do_backward(self, loss: torch.Tensor, toolbox: TrainerToolbox) -> None:
if toolbox.scaler:
toolbox.scaler.scale(loss).backward()
else:
loss.backward()

def sanity_check(self, template_model: nn.Module, example_shape: AmbiguousShape) -> SanityCheckResult:
try:
return sanity_check(template_model, example_shape, device=self._device)
Expand All @@ -402,14 +415,27 @@ def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT
str, float]]:
raise NotImplementedError

def _device_type(self) -> str:
return self._device.type if isinstance(self._device, torch.device) else str(self._device).split(":")[0]

def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
    str, float]]:
    """Run one optimization step on a single batch.

    Computes loss and gradients via `self.backward(...)` (under autocast when a
    GradScaler is present, i.e. AMP is enabled), then advances the optimizer,
    the LR scheduler and the optional EMA model.

    Returns:
        The batch loss as a float and the metrics dict produced by `backward`.
    """
    toolbox.optimizer.zero_grad()
    # NOTE(review): autocast wraps self.backward, which runs the backward pass
    # too; torch recommends autocast around the forward/loss only — confirm
    # this is intentional for the Trainer's backward contract.
    with torch.amp.autocast(self._device_type(), enabled=toolbox.scaler is not None):
        loss, metrics = self.backward(images, labels, toolbox)
    if toolbox.scaler:
        # GradScaler.step silently skips the optimizer step on inf/NaN grads
        # and then lowers the scale; compare scales so the LR scheduler is not
        # advanced for a skipped step.
        old_scale = toolbox.scaler.get_scale()
        toolbox.scaler.step(toolbox.optimizer)
        toolbox.scaler.update()
        if old_scale <= toolbox.scaler.get_scale():
            toolbox.scheduler.step()
    else:
        toolbox.optimizer.step()
        toolbox.scheduler.step()
    # EMA tracks the live parameters after each step in both paths.
    if toolbox.ema:
        toolbox.ema.update_parameters(toolbox.model)
    return loss, metrics

def train_epoch(self, toolbox: TrainerToolbox) -> dict[str, list[float]]:
Expand Down Expand Up @@ -440,7 +466,7 @@ def train_epoch(self, toolbox: TrainerToolbox) -> dict[str, list[float]]:
def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, compile_model: bool = True,
ema: bool = True, seed: int | None = None, early_stop_tolerance: int = 5,
val_score_prediction: bool = True, val_score_prediction_degree: int = 5, save_preview: bool = True,
preview_quality: float = .75) -> None:
preview_quality: float = .75, amp: bool = False) -> None:
training_arguments = self.filter_train_params(**locals())
self.init_experiment()
if note:
Expand Down Expand Up @@ -468,6 +494,12 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co
toolbox = (self.load_toolbox if self.recovery() else self.build_toolbox)(
num_epochs, example_shape, compile_model, ema
)
if amp and not toolbox.scaler:
toolbox.scaler = torch.amp.GradScaler(self._device_type())
self.log("Mixed precision training enabled")
Comment thread
perctrix marked this conversation as resolved.
elif not amp and toolbox.scaler:
toolbox.scaler = None
self.log("Mixed precision training disabled")
checkpoint_path = lambda v: f"{self.experiment_folder()}/checkpoint_{v}.pth"
es_tolerance = early_stop_tolerance
self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note,
Expand Down Expand Up @@ -550,7 +582,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co
def filter_train_params(**kwargs) -> dict[str, Setting]:
    """Keep only the kwargs that are recognized training settings.

    Used to persist `train(...)` arguments for experiment recovery; unknown
    keys (e.g. locals such as `self` or `num_epochs`) are dropped.
    """
    allowed = (
        "note", "num_checkpoints", "compile_model", "ema", "seed", "early_stop_tolerance",
        "val_score_prediction", "val_score_prediction_degree", "save_preview", "preview_quality", "amp"
    )
    return {k: v for k, v in kwargs.items() if k in allowed}

def train_with_settings(self, num_epochs: int, **kwargs) -> None:
Expand Down Expand Up @@ -580,7 +612,7 @@ def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float
worst_score = float("+inf")
metrics = {}
num_cases = len(self._validation_dataloader)
with torch.no_grad(), Progress(
with torch.no_grad(), torch.amp.autocast(self._device_type(), enabled=toolbox.scaler is not None), Progress(
*Progress.get_default_columns(), SpinnerColumn(), console=self._console
) as progress:
task = progress.add_task(f"Validating", total=num_cases)
Expand Down