Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8e5fe51
feat: implement ignore_index support in metrics and losses with dedic…
Rusheel86 Feb 28, 2026
a1b0a4f
test: run compute tests, format and lint loss/metric updates
Rusheel86 Mar 1, 2026
f2caaf8
feat: implement ignore_index support for losses and metrics
Rusheel86 Mar 9, 2026
d075009
chore: trigger CI rerun
Rusheel86 Mar 9, 2026
941a73b
chore: trigger CI rerun
Rusheel86 Mar 9, 2026
0f6e05a
DCO Remediation Commit for Rusheel Sharma <rusheelhere@gmail.com>
Rusheel86 Mar 9, 2026
a1f6ef4
fix: revert GWDL reduction handling and apply black formatting
Rusheel86 Mar 9, 2026
f01cbc4
fix: resolve shape issues and CI fails
Rusheel86 Mar 10, 2026
af83422
style: reformat with black 25.11.0
Rusheel86 Mar 10, 2026
187da14
fix: resolve mypy type error in utils.py
Rusheel86 Mar 11, 2026
780b567
fix: complete ignore_index implementation with proper one-hot masking
Rusheel86 Mar 11, 2026
91bb2e5
fix: resolve mypy union-attr error in unified_focal_loss
Rusheel86 Mar 12, 2026
57c2f78
chore: trigger CI with fresh runner
Rusheel86 Mar 12, 2026
9863a93
chore: retrigger CI (previous runs had disk space errors)
Rusheel86 Mar 12, 2026
170f34a
fix: address CodeRabbit minor issues
Rusheel86 Mar 12, 2026
c80eeeb
fix: address CodeRabbit critical and major issues
Rusheel86 Mar 12, 2026
1114907
fix: resolve all mypy and CodeRabbit issues
Rusheel86 Mar 13, 2026
3bd76e7
fix:CodeRabbit issues
Rusheel86 Mar 13, 2026
610756d
refactor: centralize ignore_index masking into create_ignore_mask helper
Rusheel86 Mar 19, 2026
754407b
Fix docstring indentation in create_ignore_mask
Rusheel86 Mar 19, 2026
b6d2362
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
edc9e4e
Fix NoneType error in AsymmetricFocalTverskyLoss for None ignore_index
Rusheel86 Mar 19, 2026
4457e37
Merge branch 'feat-ignore-index-support' of https://github.com/Rushee…
Rusheel86 Mar 19, 2026
c2612ea
style: fix import sorting with isort
Rusheel86 Mar 19, 2026
cfc54ec
fix: CI errors
Rusheel86 Mar 19, 2026
eeda3c7
chore: format and lint code
Rusheel86 Mar 19, 2026
64421e1
Fix : lint and format
Rusheel86 Mar 19, 2026
df0833b
chore: trigger CI
Rusheel86 Mar 19, 2026
9cb6592
style: format with black --skip-magic-trailing-comma for Python 3.9 c…
Rusheel86 Mar 20, 2026
03e5e9b
fix: add type ignore comment for mypy no-any-return in utils.py
Rusheel86 Mar 20, 2026
5fb4d4f
style: reformat utils.py after adding type ignore comment
Rusheel86 Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from monai.losses.focal_loss import FocalLoss
from monai.losses.spatial_mask import MaskedLoss
from monai.losses.utils import compute_tp_fp_fn
from monai.metrics.utils import create_ignore_mask
from monai.networks import one_hot
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option

Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
ignore_index: class index to ignore from the loss computation.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -122,6 +125,7 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.ignore_index = ignore_index
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
Expand Down Expand Up @@ -163,12 +167,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.other_act is not None:
input = self.other_act(input)

original_target = target if self.ignore_index is not None else None

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
target = one_hot(target, num_classes=n_pred_ch)

mask = create_ignore_mask(original_target if original_target is not None else target, self.ignore_index)
if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
Expand All @@ -180,6 +187,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

if mask is not None:
input = input * mask
target = target * mask

# reducing only spatial dimensions (not batch nor channels)
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
Expand Down
9 changes: 9 additions & 0 deletions monai/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.metrics.utils import create_ignore_mask
from monai.networks import one_hot
from monai.utils import LossReduction

Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
weight: Sequence[float] | float | int | torch.Tensor | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand All @@ -99,6 +101,7 @@ def __init__(

use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.
ignore_index: class index to ignore from the loss computation.

Example:
>>> import torch
Expand All @@ -124,6 +127,7 @@ def __init__(
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -161,6 +165,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

mask = create_ignore_mask(target, self.ignore_index)
if mask is not None:
input = input * mask
target = target * mask

loss: torch.Tensor | None = None
input = input.float()
target = target.float()
Expand Down
12 changes: 12 additions & 0 deletions monai/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch.nn.modules.loss import _Loss

from monai.losses.utils import compute_tp_fp_fn
from monai.metrics.utils import create_ignore_mask
from monai.networks import one_hot
from monai.utils import LossReduction

Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand All @@ -77,6 +79,7 @@ def __init__(
before any `reduction`.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
ignore_index: index of the class to ignore during calculation.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -101,6 +104,7 @@ def __init__(
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -129,8 +133,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
original_target = target
target = one_hot(target, num_classes=n_pred_ch)

if self.ignore_index is not None:
mask_src = original_target if self.to_onehot_y and n_pred_ch > 1 else target
mask = create_ignore_mask(mask_src, self.ignore_index)
if mask is not None:
input = input * mask
target = target * mask

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
Expand Down
115 changes: 89 additions & 26 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch
from torch.nn.modules.loss import _Loss

from monai.metrics.utils import create_ignore_mask
from monai.networks import one_hot
from monai.utils import LossReduction

Expand All @@ -39,48 +40,69 @@ def __init__(
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
ignore_index: class index to ignore from the loss computation.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.ignore_index = ignore_index

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Compute the asymmetric focal Tversky loss.

    Args:
        y_pred: predicted probabilities of shape (B, C, spatial...).
        y_true: ground truth; class indices when ``self.to_onehot_y`` is True,
            otherwise one-hot with the same shape as ``y_pred``.

    Returns:
        The loss reduced according to ``self.reduction`` (mean or sum), or the
        per-class loss tensor of shape (B, 2) when no reduction is requested.

    Raises:
        ValueError: when ``y_true`` and ``y_pred`` shapes differ after any
            one-hot conversion.
    """
    n_pred_ch = y_pred.shape[1]

    # Keep the raw labels so the ignore mask is built from class indices,
    # not from the one-hot encoding produced below.
    original_y_true = y_true if self.ignore_index is not None else None

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            if self.ignore_index is not None:
                # Map ignored voxels to a valid class so one_hot succeeds;
                # their contribution is removed later via the mask.
                y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
            y_true = one_hot(y_true, num_classes=n_pred_ch)

    if y_true.shape != y_pred.shape:
        raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

    mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index)
    if mask is not None:
        mask = mask.expand_as(y_true)

    # clip the prediction to avoid NaN
    y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
    axis = list(range(2, len(y_pred.shape)))

    # Calculate true positives (tp), false negatives (fn) and false positives (fp),
    # excluding ignored voxels when a mask is present.
    # NOTE: the duplicated unmasked tp/fn/fp computation that preceded this
    # branch (a merge/diff artifact, immediately overwritten) has been removed.
    if mask is not None:
        tp = torch.sum(y_true * y_pred * mask, dim=axis)
        fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis)
        fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis)
    else:
        tp = torch.sum(y_true * y_pred, dim=axis)
        fn = torch.sum(y_true * (1 - y_pred), dim=axis)
        fp = torch.sum((1 - y_true) * y_pred, dim=axis)
    dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

    # Calculate losses separately for each class; the foreground class is
    # focally weighted by the exponent (1 - gamma).
    back_dice = 1 - dice_class[:, 0]
    fore_dice = torch.pow(1 - dice_class[:, 1], 1 - self.gamma)

    # Per-class loss of shape (B, 2); reduce per the configured mode.
    loss = torch.stack([back_dice, fore_dice], dim=-1)
    if self.reduction == LossReduction.MEAN.value:
        return torch.mean(loss)
    if self.reduction == LossReduction.SUM.value:
        return torch.sum(loss)
    return loss


Expand All @@ -103,42 +125,66 @@ def __init__(
gamma: float = 2,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
ignore_index: class index to ignore from the loss computation.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.ignore_index = ignore_index

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Compute the asymmetric focal (cross-entropy) loss.

    Args:
        y_pred: predicted probabilities of shape (B, C, spatial...).
        y_true: ground truth; class indices when ``self.to_onehot_y`` is True,
            otherwise one-hot with the same shape as ``y_pred``.

    Returns:
        The loss reduced according to ``self.reduction`` (mean or sum), or the
        per-pixel stacked loss tensor when no reduction is requested.

    Raises:
        ValueError: when ``y_true`` and ``y_pred`` shapes differ after any
            one-hot conversion.
    """
    n_pred_ch = y_pred.shape[1]

    # Keep the raw labels so the ignore mask is built from class indices,
    # not from the one-hot encoding produced below.
    original_y_true = y_true if self.ignore_index is not None else None

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            if self.ignore_index is not None:
                # Map ignored voxels to a valid class so one_hot succeeds;
                # their contribution is removed later via the mask.
                y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
            y_true = one_hot(y_true, num_classes=n_pred_ch)

    if y_true.shape != y_pred.shape:
        raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

    mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index)
    if mask is not None:
        mask = mask.expand_as(y_true)

    # Clamp to avoid log(0) producing NaN/inf.
    y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
    cross_entropy = -y_true * torch.log(y_pred)

    if mask is not None:
        cross_entropy = cross_entropy * mask

    # Background channel gets the focal modulation; foreground is weighted by delta.
    back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
    back_ce = (1 - self.delta) * back_ce

    fore_ce = cross_entropy[:, 1]
    fore_ce = self.delta * fore_ce

    # Per-pixel loss of shape (B, 2, spatial...).
    # NOTE: a stale pre-change line that computed a scalar mean here (a
    # merge/diff artifact, immediately overwritten) has been removed.
    loss = torch.stack([back_ce, fore_ce], dim=1)  # [B, 2, H, W]

    if self.reduction == LossReduction.MEAN.value:
        if mask is not None:
            # Average only over non-ignored elements; clamp guards against an
            # all-ignored batch dividing by zero.
            masked_loss = loss * mask
            return masked_loss.sum() / mask.expand_as(loss).sum().clamp(min=1e-5)
        return loss.mean()
    if self.reduction == LossReduction.SUM.value:
        return loss.sum()
    return loss


Expand All @@ -162,6 +208,7 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
):
"""
Args:
Expand All @@ -170,8 +217,7 @@ def __init__(
weight : weight for each loss function. Defaults to 0.5.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
delta : weight of the background. Defaults to 0.7.


ignore_index: class index to ignore from the loss computation.

Example:
>>> import torch
Expand All @@ -187,10 +233,14 @@ def __init__(
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_loss = AsymmetricFocalLoss(
to_onehot_y=False, gamma=self.gamma, delta=self.delta, ignore_index=ignore_index
)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
to_onehot_y=False, gamma=self.gamma, delta=self.delta, ignore_index=ignore_index
)
self.ignore_index = ignore_index

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
Expand All @@ -207,28 +257,41 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
ValueError: When num_classes
ValueError: When the number of classes entered does not match the expected number
"""
if y_pred.shape != y_true.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

# Transform binary inputs to 2-channel space
if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
# Save original for masking before one-hot conversion
original_y_true = y_true if self.ignore_index is not None else None

n_pred_ch = y_pred.shape[1]
if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
if self.ignore_index is not None:
# Replace ignore_index with valid class before one_hot
y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
y_true = one_hot(y_true, num_classes=self.num_classes)
elif y_true.shape[1] == 1 and y_pred.shape[1] == 2:
y_true = torch.cat([1 - y_true, y_true], dim=1)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
if self.ignore_index is None and torch.max(y_true) > self.num_classes - 1:
raise ValueError(f"Invalid class index found. Maximum class should be {self.num_classes - 1}")

mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index)

if mask is not None:
mask_expanded = mask.expand_as(y_true)
y_pred_masked = y_pred * mask_expanded
y_true_masked = y_true * mask_expanded
else:
y_pred_masked = y_pred
y_true_masked = y_true

asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
asy_focal_loss = self.asy_focal_loss(y_pred_masked, y_true_masked)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred_masked, y_true_masked)

loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

Expand Down
Loading
Loading