Log Metric STD #22
Refactor metric aggregation to report both mean and population standard deviation (ddof=0) for training/evaluation metrics. Update callback logging flow so all returned metric stats are logged per split.
- Change AbstractMetric.compute() to return metric stats dicts
- Update L1/L2/Pearson to accumulate count/sum/sum_sq and compute std
- Update PSNR/SSIM to aggregate per-sample values and compute std
- Keep LPIPS/DISTS compatible with new stats interface
- Update epoch evaluator to log all metric stat keys with split suffixes
- Preserve training behavior while expanding MLflow metric visibility
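As a rough illustration of the per-split logging described in the last two bullets, the sketch below appends a split name to every key of a stats dict. The helper name `with_split_suffix` and the `{key}_{split}` naming scheme are assumptions made for this example; the key format the PR actually uses is not shown in this thread.

```python
def with_split_suffix(stats: dict[str, float], split: str) -> dict[str, float]:
    """Append a split name to every metric stat key (hypothetical naming scheme)."""
    return {f"{key}_{split}": value for key, value in stats.items()}


# Example stats dict in the shape compute() now returns: mean plus a "_std" entry.
# The numbers are made up for illustration.
psnr_stats = {"psnr": 27.5, "psnr_std": 1.25}

train_logged = with_split_suffix(psnr_stats, "train")
val_logged = with_split_suffix(psnr_stats, "val")
```

Each resulting dict could then be handed to a logger in one call, for example `mlflow.log_metrics(train_logged, step=epoch)`, so that every stat key returned by a metric is logged without per-key handling.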
wli51
left a comment
LGTM! Good move on obtaining richer logging.
The finer-grained metric computation you are aiming for here is starting to add a substantial amount of code duplication, so I suggest a quick refactor that abstracts out the shared rolling-accumulation logic. This would benefit both the current metrics and future additions. Please see my comments for details.
return average_dists
return {
    self.metric_name: average_dists.item(),
    f"{self.metric_name}_std": 0.0,
Maybe the docstring should note that this metric class does not actually compute the std, and that the reported `_std` value is always 0.0?
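A docstring note along these lines might be enough. The class below is a hypothetical, pure-Python stand-in (not the actual LPIPS/DISTS implementation) that only shows where such a note could go and how the placeholder value reaches the stats dict:

```python
class PlaceholderStdMetric:
    """Hypothetical stand-in for a metric that reports only a mean.

    Note:
        The "<metric_name>_std" entry returned by compute() is a placeholder
        fixed at 0.0 to satisfy the stats interface. No standard deviation
        is computed for this metric.
    """

    def __init__(self, metric_name: str):
        self.metric_name = metric_name
        self.total = 0.0
        self.count = 0

    def update(self, batch_mean: float, batch_size: int) -> None:
        # Accumulate a size-weighted sum so compute() returns the overall mean.
        self.total += batch_mean * batch_size
        self.count += batch_size

    def compute(self) -> dict[str, float]:
        mean = self.total / self.count if self.count else 0.0
        # The std entry is a constant placeholder, as documented above.
        return {self.metric_name: mean, f"{self.metric_name}_std": 0.0}
```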
self.total_squared_error = torch.tensor(0.0, device=self.device)
self.total_squared_error_sq = torch.tensor(0.0, device=self.device)
self.total_examples = torch.tensor(0.0, device=self.device)
These variable names are accurate, but I don't think they are the clearest, especially since the same variance shortcut (the mean of the squares minus the square of the mean, E[x^2] - E[x]^2) is repeated across multiple metric classes. Would it be worth abstracting the std accumulation into a standalone, reusable component, so that redundant code is minimized and adding future metrics takes less work?
Since the mean is the first raw moment and the std is the square root of the second central moment, their rolling accumulation can naturally be handled together in the following generalized form:
import torch


class StreamingScalarStats:
    """
    Accumulates mean, variance, and std for a stream of per sample scalar values.
    By default, computes population variance:
        var = E[x^2] - E[x]^2
    Set ddof=1 for sample variance.
    """

    def __init__(
        self,
        device: torch.device | str = "cuda",
        dtype: torch.dtype = torch.float64,
        ddof: int = 0,
    ):
        if ddof < 0:
            raise ValueError("ddof must be non-negative")
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.dtype = dtype
        self.ddof = ddof
        self.reset()

    def reset(self) -> None:
        # Running sums of x and x^2, plus the observation count.
        self.sum_x = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        self.sum_x_sq = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        self.n = torch.tensor(0.0, device=self.device, dtype=self.dtype)

    @torch.no_grad()
    def update(self, values: torch.Tensor) -> None:
        """
        Update accumulator with a tensor of scalar observations.
        """
        values = values.detach().to(device=self.device, dtype=self.dtype).reshape(-1)
        if values.numel() == 0:
            return
        self.sum_x += values.sum()
        self.sum_x_sq += values.pow(2).sum()
        self.n += values.numel()

    def compute(self) -> dict[str, float]:
        # Nothing accumulated yet: report zeros instead of dividing by zero.
        if self.n.item() == 0:
            return {
                "mean": 0.0,
                "variance": 0.0,
                "std": 0.0,
                "n": 0.0,
            }
        mean = self.sum_x / self.n
        denom = self.n - self.ddof
        if denom <= 0:
            variance = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        else:
            # For ddof=0: var = sum_x_sq / n - mean^2
            # For ddof=1: unbiased sample variance
            correction = self.n / denom
            variance = correction * (self.sum_x_sq / self.n - mean.pow(2))
        # Clamp small negative values caused by floating-point cancellation.
        variance = torch.clamp(variance, min=0.0)
        std = torch.sqrt(variance)
        return {
            "mean": mean.item(),
            "variance": variance.item(),
            "std": std.item(),
            "n": self.n.item(),
        }

This would reduce each of your metric classes to a thin wrapper around a forward function that computes the per-sample measure, with StreamingScalarStats performing the rolling accumulation:
    def __init__(self, ...):
        ...
        self.stats = StreamingScalarStats(device=self.device, ddof=0)
        ...

    def forward(self,...):
        # do something to get per sample measure tensor
        # then pass that tensor to self.stats.update(...)

    def reset(self):
        self.stats.reset()
        ...

    def compute(self) -> dict[str, float]:
        stats = self.stats.compute()
        return {
            self.metric_name: stats["mean"],
            f"{self.metric_name}_std": stats["std"],
        }
    ...
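To sanity-check the accumulation identity and the ddof correction without needing a GPU or torch, here is a minimal pure-Python analogue of the accumulator. The class name `StreamingStats` and the sample values are made up for this sketch. It compares the streamed result against the standard library's `statistics.pstdev` and `statistics.stdev`:

```python
import math
import statistics


class StreamingStats:
    """Pure-Python analogue of the accumulator above (illustrative only)."""

    def __init__(self, ddof: int = 0):
        self.ddof = ddof
        self.n = 0
        self.sum_x = 0.0
        self.sum_x_sq = 0.0

    def update(self, values: list[float]) -> None:
        # Only three running totals are kept, never the raw values.
        self.n += len(values)
        self.sum_x += sum(values)
        self.sum_x_sq += sum(v * v for v in values)

    def std(self) -> float:
        denom = self.n - self.ddof
        if denom <= 0:
            return 0.0
        mean = self.sum_x / self.n
        # Population variance E[x^2] - E[x]^2, rescaled by n / (n - ddof).
        variance = (self.n / denom) * (self.sum_x_sq / self.n - mean * mean)
        return math.sqrt(max(variance, 0.0))


# Feed the values batch by batch, as an evaluation loop would.
batches = [[2.0, 4.0, 4.0], [4.0, 5.0], [5.0, 7.0, 9.0]]
all_values = [v for batch in batches for v in batch]

population = StreamingStats(ddof=0)
sample = StreamingStats(ddof=1)
for batch in batches:
    population.update(batch)
    sample.update(batch)

# The streamed results agree with the reference computed on all values at once.
assert math.isclose(population.std(), statistics.pstdev(all_values))
assert math.isclose(sample.std(), statistics.stdev(all_values))
```

One caveat of the sum and sum-of-squares shortcut is that it can lose precision when the mean is large relative to the spread, which is presumably why the proposed class defaults to float64 and clamps the variance at zero.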
This PR logs the metric standard deviation alongside the mean in epoch evaluation.