Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions callbacks/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,16 @@ def _evaluate_split(
):
break

loss_metric_data = self.loss.compute()
split_metric_data = {
f"{self.loss.metric_name}_{data_split}": self.loss.compute().item(),
f"{metric_name}_{data_split}": metric_value
for metric_name, metric_value in loss_metric_data.items()
}
self.loss.reset()
for metric in self.metrics:
split_metric_data[f"{metric.metric_name}_{data_split}"] = metric.compute().item()
metric_data = metric.compute()
for metric_name, metric_value in metric_data.items():
split_metric_data[f"{metric_name}_{data_split}"] = metric_value
metric.reset()

return split_metric_data
4 changes: 2 additions & 2 deletions metrics/AbstractMetric.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def update(self, generated_predictions: torch.Tensor, targets: torch.Tensor, **k
pass

@abstractmethod
def compute(self) -> torch.Tensor:
"""Compute the current aggregated metric value without resetting state."""
def compute(self) -> dict[str, float]:
"""Compute current aggregated metric stats without resetting state."""

pass

Expand Down
15 changes: 9 additions & 6 deletions metrics/DISTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,20 @@ def update(self, generated_predictions: torch.Tensor, targets: torch.Tensor, **k

self.forward(generated_predictions=generated_predictions, targets=targets, **kwargs)

def compute(self) -> torch.Tensor:
"""Compute averaged DISTS for currently accumulated state.
def compute(self) -> dict[str, float]:
"""Compute averaged DISTS and std for current state.

Returns:
Scalar tensor with current DISTS value.
Dictionary containing mean and std metric values.
"""

average_dists = self.dists_metric.compute().to(self.device)
if not torch.isfinite(average_dists):
average_dists = torch.tensor(0.0, device=self.device)
return average_dists
return {
self.metric_name: average_dists.item(),
f"{self.metric_name}_std": 0.0,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the docstring should note that, for this metric class, the std is not actually computed and is always reported as 0.0?

}

@property
def metric_name(self) -> str:
Expand All @@ -96,8 +99,8 @@ def metric_name(self) -> str:
return "dists_total"

def get_metric_data(self) -> dict[str, float]:
"""Backward-compatible helper that computes and resets state."""
"""Compute metric stats and reset state."""

metric_data = {self.metric_name: self.compute().item()}
metric_data = self.compute()
self.reset()
return metric_data
26 changes: 20 additions & 6 deletions metrics/L1.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def reset(self):
"""Reset running L1 accumulators used for epoch-level logging."""

self.total_abs_error = torch.tensor(0.0, device=self.device)
self.total_abs_error_sq = torch.tensor(0.0, device=self.device)
self.total_examples = torch.tensor(0.0, device=self.device)

def forward(
Expand Down Expand Up @@ -58,6 +59,7 @@ def forward(
per_sample_l1 = abs_error.mean(dim=1)

self.total_abs_error += per_sample_l1.sum().detach().to(self.device)
self.total_abs_error_sq += per_sample_l1.pow(2).sum().detach().to(self.device)
self.total_examples += torch.tensor(
per_sample_l1.numel(),
dtype=torch.float32,
Expand All @@ -70,21 +72,33 @@ def update(self, generated_predictions: torch.Tensor, targets: torch.Tensor, **k

self.forward(generated_predictions=generated_predictions, targets=targets, **kwargs)

def compute(self) -> torch.Tensor:
"""Compute averaged L1 value for currently accumulated state.
def compute(self) -> dict[str, float]:
"""Compute averaged L1 and population std for current state.

Returns:
Scalar tensor with current L1 value.
Dictionary containing mean and std metric values.
"""

average_l1 = torch.where(
self.total_examples > 0,
self.total_abs_error / self.total_examples,
torch.tensor(0.0, device=self.device),
)
variance_l1 = torch.where(
self.total_examples > 0,
(self.total_abs_error_sq / self.total_examples) - average_l1.pow(2),
torch.tensor(0.0, device=self.device),
)
std_l1 = torch.sqrt(torch.clamp(variance_l1, min=0.0))
if not torch.isfinite(average_l1):
average_l1 = torch.tensor(0.0, device=self.device)
return average_l1
if not torch.isfinite(std_l1):
std_l1 = torch.tensor(0.0, device=self.device)

return {
self.metric_name: average_l1.item(),
f"{self.metric_name}_std": std_l1.item(),
}

@property
def metric_name(self) -> str:
Expand All @@ -93,8 +107,8 @@ def metric_name(self) -> str:
return "l1_total"

def get_metric_data(self) -> dict[str, float]:
"""Backward-compatible helper that computes and resets state."""
"""Compute metric stats and reset state."""

metric_data = {self.metric_name: self.compute().item()}
metric_data = self.compute()
self.reset()
return metric_data
26 changes: 20 additions & 6 deletions metrics/L2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def reset(self):
"""Reset running squared-error accumulators."""

self.total_squared_error = torch.tensor(0.0, device=self.device)
self.total_squared_error_sq = torch.tensor(0.0, device=self.device)
self.total_examples = torch.tensor(0.0, device=self.device)
Comment on lines 33 to 35
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These variable names are accurate, but I don't think they are the clearest, especially given that the same variance shortcut (mean of squares minus squared mean, i.e. E[x^2] - E[x]^2) is repeated identically across multiple metric classes. I wonder whether it would be more beneficial to the repo to abstract the mean/std accumulation into a standalone, reusable component, so that redundant code is minimized and adding future metrics takes less work?

Given that the sample mean and sample variance are just the first raw moment and the second central moment (with the std being the square root of the latter), their rolling accumulation can naturally be handled together in the following generalized form:

class StreamingScalarStats:
    """
    Accumulates mean, variance, and std for a stream of per sample scalar values.
    By default, computes population variance:
        var = E[x^2] - E[x]^2
    Set ddof=1 for sample variance.
    """

    def __init__(
        self,
        device: torch.device | str = "cuda",
        dtype: torch.dtype = torch.float64,
        ddof: int = 0,
    ):
        if ddof < 0:
            raise ValueError("ddof must be non-negative")

        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.dtype = dtype
        self.ddof = ddof
        self.reset()

    def reset(self) -> None:
        self.sum_x = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        self.sum_x_sq = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        self.n = torch.tensor(0.0, device=self.device, dtype=self.dtype)

    @torch.no_grad()
    def update(self, values: torch.Tensor) -> None:
        """
        Update accumulator with a tensor of scalar observations.
        """
        values = values.detach().to(device=self.device, dtype=self.dtype).reshape(-1)

        if values.numel() == 0:
            return

        self.sum_x += values.sum()
        self.sum_x_sq += values.pow(2).sum()
        self.n += values.numel()

    def compute(self) -> dict[str, float]:
        if self.n.item() == 0:
            return {
                "mean": 0.0,
                "variance": 0.0,
                "std": 0.0,
                "n": 0.0,
            }

        mean = self.sum_x / self.n

        denom = self.n - self.ddof
        if denom <= 0:
            variance = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        else:
            # For ddof=0: var = sum_x_sq / n - mean^2
            # For ddof=1: unbiased sample variance
            correction = self.n / denom
            variance = correction * (self.sum_x_sq / self.n - mean.pow(2))

        variance = torch.clamp(variance, min=0.0)
        std = torch.sqrt(variance)

        return {
            "mean": mean.item(),
            "variance": variance.item(),
            "std": std.item(),
            "n": self.n.item(),
        }

This would reduce each of your metric classes to a thin wrapper around a forward function that computes the per-sample measure, with StreamingScalarStats performing the rolling accumulation:

def __init__(self, ...):
        ...
        self.stats = StreamingScalarStats(device=self.device, ddof=0)
        ...
def forward(self,...):
        # do something to get per sample measure tensor
def reset(self):
        self.stats.reset()
        ...
def compute(self) -> dict[str, float]:
        stats = self.stats.compute()

        return {
            self.metric_name: stats["mean"],
            f"{self.metric_name}_std": stats["std"],
        }
        ...


def forward(
Expand Down Expand Up @@ -58,6 +59,7 @@ def forward(
per_sample_l2 = sq_error.mean(dim=1)

self.total_squared_error += per_sample_l2.sum().detach().to(self.device)
self.total_squared_error_sq += per_sample_l2.pow(2).sum().detach().to(self.device)
self.total_examples += torch.tensor(
per_sample_l2.numel(),
dtype=torch.float32,
Expand All @@ -70,21 +72,33 @@ def update(self, generated_predictions: torch.Tensor, targets: torch.Tensor, **k

self.forward(generated_predictions=generated_predictions, targets=targets, **kwargs)

def compute(self) -> torch.Tensor:
"""Compute averaged L2 value for currently accumulated state.
def compute(self) -> dict[str, float]:
"""Compute averaged L2 and population std for current state.

Returns:
Scalar tensor with current L2 value.
Dictionary containing mean and std metric values.
"""

average_l2 = torch.where(
self.total_examples > 0,
self.total_squared_error / self.total_examples,
torch.tensor(0.0, device=self.device),
)
variance_l2 = torch.where(
self.total_examples > 0,
(self.total_squared_error_sq / self.total_examples) - average_l2.pow(2),
torch.tensor(0.0, device=self.device),
)
std_l2 = torch.sqrt(torch.clamp(variance_l2, min=0.0))
if not torch.isfinite(average_l2):
average_l2 = torch.tensor(0.0, device=self.device)
return average_l2
if not torch.isfinite(std_l2):
std_l2 = torch.tensor(0.0, device=self.device)

return {
self.metric_name: average_l2.item(),
f"{self.metric_name}_std": std_l2.item(),
}

@property
def metric_name(self) -> str:
Expand All @@ -93,8 +107,8 @@ def metric_name(self) -> str:
return "l2_total"

def get_metric_data(self) -> dict[str, float]:
"""Backward-compatible helper that computes and resets state."""
"""Compute metric stats and reset state."""

metric_data = {self.metric_name: self.compute().item()}
metric_data = self.compute()
self.reset()
return metric_data
15 changes: 9 additions & 6 deletions metrics/LPIPS.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,20 @@ def update(self, generated_predictions: torch.Tensor, targets: torch.Tensor, **k

self.forward(generated_predictions=generated_predictions, targets=targets, **kwargs)

def compute(self) -> torch.Tensor:
"""Compute averaged LPIPS for currently accumulated state.
def compute(self) -> dict[str, float]:
"""Compute averaged LPIPS and std for current state.

Returns:
Scalar tensor with current LPIPS value.
Dictionary containing mean and std metric values.
"""

average_lpips = self.lpips_metric.compute().to(self.device)
if not torch.isfinite(average_lpips):
average_lpips = torch.tensor(0.0, device=self.device)
return average_lpips
return {
self.metric_name: average_lpips.item(),
f"{self.metric_name}_std": 0.0,
}

@property
def metric_name(self) -> str:
Expand All @@ -106,8 +109,8 @@ def metric_name(self) -> str:
return "lpips_total"

def get_metric_data(self) -> dict[str, float]:
"""Backward-compatible helper that computes and resets state."""
"""Compute metric stats and reset state."""

metric_data = {self.metric_name: self.compute().item()}
metric_data = self.compute()
self.reset()
return metric_data
49 changes: 41 additions & 8 deletions metrics/PSNR.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
)
self.psnr_metric = PeakSignalNoiseRatio(
data_range=max_pixel_value,
reduction="elementwise_mean",
reduction="none",
dim=(1, 2, 3),
).to(self.device)
self.reset()
Expand All @@ -43,6 +43,9 @@ def reset(self):
"""Reset running PSNR accumulators."""

self.psnr_metric.reset()
self.total_psnr = torch.tensor(0.0, device=self.device)
self.total_psnr_sq = torch.tensor(0.0, device=self.device)
self.total_examples = torch.tensor(0.0, device=self.device)

def forward(
self,
Expand All @@ -65,24 +68,54 @@ def forward(
raise ValueError("The generated predictions and targets must be the same shape.")

self.psnr_metric.update(generated_predictions, targets)
per_sample_psnr = self.psnr_metric.compute().to(self.device).reshape(-1)
self.psnr_metric.reset()
finite_psnr = torch.where(
torch.isfinite(per_sample_psnr),
per_sample_psnr,
torch.tensor(self.nonfinite_cap, device=self.device),
)
self.total_psnr += finite_psnr.sum().detach()
self.total_psnr_sq += finite_psnr.pow(2).sum().detach()
self.total_examples += torch.tensor(
finite_psnr.numel(),
dtype=torch.float32,
device=self.device,
)
return None

def update(self, generated_predictions: torch.Tensor, targets: torch.Tensor, **kwargs) -> None:
"""Alias for state updates to align with TorchMetrics-like API."""

self.forward(generated_predictions=generated_predictions, targets=targets, **kwargs)

def compute(self) -> torch.Tensor:
"""Compute averaged PSNR for currently accumulated state.
def compute(self) -> dict[str, float]:
"""Compute averaged PSNR and population std for current state.

Returns:
Scalar tensor with current PSNR value.
Dictionary containing mean and std metric values.
"""

average_psnr = self.psnr_metric.compute().to(self.device)
average_psnr = torch.where(
self.total_examples > 0,
self.total_psnr / self.total_examples,
torch.tensor(0.0, device=self.device),
)
variance_psnr = torch.where(
self.total_examples > 0,
(self.total_psnr_sq / self.total_examples) - average_psnr.pow(2),
torch.tensor(0.0, device=self.device),
)
std_psnr = torch.sqrt(torch.clamp(variance_psnr, min=0.0))
if not torch.isfinite(average_psnr):
average_psnr = torch.tensor(self.nonfinite_cap, device=self.device)
return average_psnr
if not torch.isfinite(std_psnr):
std_psnr = torch.tensor(0.0, device=self.device)

return {
self.metric_name: average_psnr.item(),
f"{self.metric_name}_std": std_psnr.item(),
}

@property
def metric_name(self) -> str:
Expand All @@ -91,8 +124,8 @@ def metric_name(self) -> str:
return "psnr_total"

def get_metric_data(self) -> dict[str, float]:
"""Backward-compatible helper that computes and resets state."""
"""Compute metric stats and reset state."""

metric_data = {self.metric_name: self.compute().item()}
metric_data = self.compute()
self.reset()
return metric_data
26 changes: 20 additions & 6 deletions metrics/PearsonCorrelation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def reset(self):
"""Reset running Pearson correlation accumulators."""

self.total_pearson = torch.tensor(0.0, device=self.device)
self.total_pearson_sq = torch.tensor(0.0, device=self.device)
self.total_examples = torch.tensor(0.0, device=self.device)

def forward(
Expand Down Expand Up @@ -75,6 +76,7 @@ def forward(
)

self.total_pearson += per_sample_pearson.sum().detach()
self.total_pearson_sq += per_sample_pearson.pow(2).sum().detach()
self.total_examples += torch.tensor(
per_sample_pearson.shape[0],
dtype=torch.float32,
Expand All @@ -87,21 +89,33 @@ def update(self, generated_predictions: torch.Tensor, targets: torch.Tensor, **k

self.forward(generated_predictions=generated_predictions, targets=targets, **kwargs)

def compute(self) -> torch.Tensor:
"""Compute averaged Pearson correlation for currently accumulated state.
def compute(self) -> dict[str, float]:
"""Compute averaged Pearson and population std for current state.

Returns:
Scalar tensor with current Pearson correlation value.
Dictionary containing mean and std metric values.
"""

average_pearson = torch.where(
self.total_examples > 0,
self.total_pearson / self.total_examples,
torch.tensor(0.0, device=self.device),
)
variance_pearson = torch.where(
self.total_examples > 0,
(self.total_pearson_sq / self.total_examples) - average_pearson.pow(2),
torch.tensor(0.0, device=self.device),
)
std_pearson = torch.sqrt(torch.clamp(variance_pearson, min=0.0))
if not torch.isfinite(average_pearson):
average_pearson = torch.tensor(0.0, device=self.device)
return average_pearson
if not torch.isfinite(std_pearson):
std_pearson = torch.tensor(0.0, device=self.device)

return {
self.metric_name: average_pearson.item(),
f"{self.metric_name}_std": std_pearson.item(),
}

@property
def metric_name(self) -> str:
Expand All @@ -110,8 +124,8 @@ def metric_name(self) -> str:
return "pearson_total"

def get_metric_data(self) -> dict[str, float]:
"""Backward-compatible helper that computes and resets state."""
"""Compute metric stats and reset state."""

metric_data = {self.metric_name: self.compute().item()}
metric_data = self.compute()
self.reset()
return metric_data
Loading