Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ Metrics
.. autoclass:: PSNRMetric
:members:

`Mean absolute percentage error`
---------------------------------
.. autoclass:: MAPEMetric
:members:

`Structural similarity index measure`
-------------------------------------
.. autoclass:: monai.metrics.regression.SSIMMetric
Expand Down
1 change: 1 addition & 0 deletions monai/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality
from .regression import (
MAEMetric,
MAPEMetric,
MSEMetric,
MultiScaleSSIMMetric,
PSNRMetric,
Expand Down
50 changes: 50 additions & 0 deletions monai/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,39 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
return compute_mean_error_metrics(y_pred, y, func=self.abs_func)


class MAPEMetric(RegressionMetric):
r"""Compute Mean Absolute Percentage Error between two tensors using function:

.. math::
\operatorname {MAPE}\left(Y, \hat{Y}\right) =\frac {100}{n}\sum _{i=1}^{n}\left|\frac{y_i-\hat{y_i}}{y_i}\right|.

More info: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error

Input `y_pred` is compared with ground truth `y`.
Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.
Note: Tackling the undefined error, a tiny epsilon value is added to the denominator part.

Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

Args:
reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
epsilon: float. Defaults to 1e-7.

"""

def __init__(
self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, epsilon: float = 1e-7
) -> None:
super().__init__(reduction=reduction, get_not_nans=get_not_nans)
self.epsilon = epsilon

def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return compute_mape_metric(y_pred, y, epsilon=self.epsilon)


class RMSEMetric(RegressionMetric):
r"""Compute Root Mean Squared Error between two tensors using function:

Expand Down Expand Up @@ -220,6 +253,23 @@ def compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func: Call
return torch.mean(flt(func(y - y_pred)), dim=-1, keepdim=True)


def compute_mape_metric(y_pred: torch.Tensor, y: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:
"""
Compute Mean Absolute Percentage Error.

Args:
y_pred: predicted values
y: ground truth values
epsilon: small value to avoid division by zero

Returns:
MAPE value as percentage
"""
flt = partial(torch.flatten, start_dim=1)
percentage_error = torch.abs(y - y_pred) / torch.clamp(torch.abs(y), min=epsilon) * 100.0
return torch.mean(flt(percentage_error), dim=-1, keepdim=True)


class KernelType(StrEnum):
GAUSSIAN = "gaussian"
UNIFORM = "uniform"
Expand Down
25 changes: 15 additions & 10 deletions tests/metrics/test_compute_regression_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import torch

from monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric
from monai.metrics import MAEMetric, MAPEMetric, MSEMetric, PSNRMetric, RMSEMetric
from monai.utils import set_determinism


Expand All @@ -44,14 +44,19 @@ def psnrmetric_np(max_val, y_pred, y):
return np.mean(20 * np.log10(max_val) - 10 * np.log10(mse))


def mapemetric_np(y_pred, y, epsilon=1e-7):
percentage_error = np.abs(y - y_pred) / np.clip(np.abs(y), a_min=epsilon, a_max=None) * 100.0
return np.mean(flatten(percentage_error))


class TestRegressionMetrics(unittest.TestCase):

def test_shape_reduction(self):
set_determinism(seed=123)
device = "cuda" if torch.cuda.is_available() else "cpu"

# regression metrics to check
metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]
metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]

# define variations in batch/base_dims/spatial_dims
batch_dims = [1, 2, 4, 16]
Expand Down Expand Up @@ -94,8 +99,8 @@ def test_compare_numpy(self):
device = "cuda" if torch.cuda.is_available() else "cpu"

# regression metrics to check + truth metric function in numpy
metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]
metrics_np = [msemetric_np, maemetric_np, rmsemetric_np, partial(psnrmetric_np, max_val=1.0)]
metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]
metrics_np = [msemetric_np, maemetric_np, mapemetric_np, rmsemetric_np, partial(psnrmetric_np, max_val=1.0)]

# define variations in batch/base_dims/spatial_dims
batch_dims = [1, 2, 4, 16]
Expand All @@ -117,14 +122,14 @@ def test_compare_numpy(self):
out_tensor = mt.aggregate(reduction="mean")
out_np = mt_fn_np(y_pred=in_tensor_a.cpu().numpy(), y=in_tensor_b.cpu().numpy())

np.testing.assert_allclose(out_tensor.cpu().numpy(), out_np, atol=1e-4)
np.testing.assert_allclose(out_tensor.cpu().numpy(), out_np, atol=1e-3, rtol=1e-4)

def test_ill_shape(self):
set_determinism(seed=123)
device = "cuda" if torch.cuda.is_available() else "cpu"

# regression metrics to check + truth metric function in numpy
metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]
metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]
basedim = 10

# too small shape
Expand All @@ -143,8 +148,8 @@ def test_ill_shape(self):
def test_same_input(self):
set_determinism(seed=123)
device = "cuda" if torch.cuda.is_available() else "cpu"
metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]
results = [0.0, 0.0, 0.0, float("inf")]
metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]
results = [0.0, 0.0, 0.0, 0.0, float("inf")]

# define variations in batch/base_dims/spatial_dims
batch_dims = [1, 2, 4, 16]
Expand All @@ -168,8 +173,8 @@ def test_same_input(self):
def test_diff_input(self):
set_determinism(seed=123)
device = "cuda" if torch.cuda.is_available() else "cpu"
metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]
results = [1.0, 1.0, 1.0, 0.0]
metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]
results = [1.0, 1.0, 100.0, 1.0, 0.0]

# define variations in batch/base_dims/spatial_dims
batch_dims = [1, 2, 4, 16]
Expand Down
Loading