
Log Metric STD #22

Open
MattsonCam wants to merge 1 commit into main from metric_std_logging

@MattsonCam
Member

This PR logs the metric standard deviation alongside the mean in epoch evaluation.

Refactor metric aggregation to report both mean and population standard
deviation (ddof=0) for training/evaluation metrics. Update callback
logging flow so all returned metric stats are logged per split.

- Change AbstractMetric.compute() to return metric stats dicts
- Update L1/L2/Pearson to accumulate count/sum/sum_sq and compute std
- Update PSNR/SSIM to aggregate per-sample values and compute std
- Keep LPIPS/DISTS compatible with new stats interface
- Update epoch evaluator to log all metric stat keys with split suffixes
- Preserve training behavior while expanding MLflow metric visibility
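
As a rough illustration of the "split suffixes" bullet, the sketch below shows one way the stats dict returned by a metric could be expanded into per-split keys before logging. The helper name `with_split_suffix`, the `<key>_<split>` naming scheme, and the sample values are assumptions made for illustration, not taken from the PR diff.

```python
def with_split_suffix(stats: dict[str, float], split: str) -> dict[str, float]:
    """Append a split suffix to every metric stat key (naming scheme is assumed)."""
    return {f"{key}_{split}": value for key, value in stats.items()}


# Hypothetical compute() output for a single metric.
l1_stats = {"L1": 0.12, "L1_std": 0.03}

for split in ("train", "val"):
    for key, value in with_split_suffix(l1_stats, split).items():
        # A real callback would presumably call something like
        # mlflow.log_metric(key, value, step=epoch) here instead of printing.
        print(key, value)
```

Because every key in the dict is logged, adding a new stat (for example a variance) to `compute()` would not require touching the logging loop.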
@MattsonCam MattsonCam requested a review from wli51 May 28, 2026 20:16

@wli51 wli51 left a comment


LGTM! Good move on obtaining richer logging.
Given that the metric granularity you are aiming for here is starting to add a substantial amount of code duplication, I suggest a quick refactor that abstracts out the shared rolling-accumulation logic, which would benefit both the current code and future expansion. Please see my comments for details.

Comment thread metrics/DISTS.py
-        return average_dists
+        return {
+            self.metric_name: average_dists.item(),
+            f"{self.metric_name}_std": 0.0,

Maybe the docstring should note that this metric class does not actually compute the std and always reports it as 0.0?

Comment thread metrics/L2.py
Comment on lines 33 to 35
self.total_squared_error = torch.tensor(0.0, device=self.device)
self.total_squared_error_sq = torch.tensor(0.0, device=self.device)
self.total_examples = torch.tensor(0.0, device=self.device)

These variable names are accurate, but I don't think they are the clearest, especially since the same variance shortcut (expectation of squares minus squared expectation) is repeated across multiple metric classes. Would it be more beneficial to the repo to abstract the std accumulation into a standalone, reusable component, so that redundant code is minimized and adding future metrics takes less work?

Given that the sample mean and variance are just the first raw moment and the second central moment, their rolling accumulation can naturally be handled together in the following generalized form:

import torch


class StreamingScalarStats:
    """
    Accumulates mean, variance, and std for a stream of per-sample scalar values.
    By default, computes population variance:
        var = E[x^2] - E[x]^2
    Set ddof=1 for sample variance.
    """

    def __init__(
        self,
        device: torch.device | str = "cuda",
        dtype: torch.dtype = torch.float64,
        ddof: int = 0,
    ):
        if ddof < 0:
            raise ValueError("ddof must be non-negative")

        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.dtype = dtype
        self.ddof = ddof
        self.reset()

    def reset(self) -> None:
        self.sum_x = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        self.sum_x_sq = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        self.n = torch.tensor(0.0, device=self.device, dtype=self.dtype)

    @torch.no_grad()
    def update(self, values: torch.Tensor) -> None:
        """
        Update accumulator with a tensor of scalar observations.
        """
        values = values.detach().to(device=self.device, dtype=self.dtype).reshape(-1)

        if values.numel() == 0:
            return

        self.sum_x += values.sum()
        self.sum_x_sq += values.pow(2).sum()
        self.n += values.numel()

    def compute(self) -> dict[str, float]:
        if self.n.item() == 0:
            return {
                "mean": 0.0,
                "variance": 0.0,
                "std": 0.0,
                "n": 0.0,
            }

        mean = self.sum_x / self.n

        denom = self.n - self.ddof
        if denom <= 0:
            variance = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        else:
            # For ddof=0: var = sum_x_sq / n - mean^2
            # For ddof=1: unbiased sample variance
            correction = self.n / denom
            variance = correction * (self.sum_x_sq / self.n - mean.pow(2))

        variance = torch.clamp(variance, min=0.0)
        std = torch.sqrt(variance)

        return {
            "mean": mean.item(),
            "variance": variance.item(),
            "std": std.item(),
            "n": self.n.item(),
        }
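
For reference, the `correction` factor in `compute()` can be recovered by expanding the sum of squared deviations. The short derivation below is a sketch that writes $d$ for `ddof`, $n$ for the number of accumulated values, and $\bar{x}$ for their mean; the symbols are introduced here for illustration only.

```latex
s_d^2
  = \frac{1}{n-d}\sum_{i=1}^{n}\left(x_i-\bar{x}\right)^2
  = \frac{1}{n-d}\left(\sum_{i=1}^{n}x_i^2 - n\,\bar{x}^2\right)
  = \frac{n}{n-d}\left(\frac{1}{n}\sum_{i=1}^{n}x_i^2 - \bar{x}^2\right)
```

With $d=0$ this reduces to $E[x^2]-E[x]^2$, and with $d=1$ it gives the usual Bessel-corrected sample variance. The subtraction can lose precision when the mean is large relative to the spread, which is one reason the `float64` default and the clamp at zero are helpful.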

This would reduce each of your metric classes to a thin wrapper around a forward function that computes the per-sample measure, with StreamingScalarStats performing the rolling accumulation.

def __init__(self, ...):
    ...
    self.stats = StreamingScalarStats(device=self.device, ddof=0)
    ...

def forward(self, ...):
    # compute the per-sample measure tensor and pass it to self.stats.update(...)
    ...

def reset(self):
    self.stats.reset()
    ...

def compute(self) -> dict[str, float]:
    stats = self.stats.compute()

    return {
        self.metric_name: stats["mean"],
        f"{self.metric_name}_std": stats["std"],
    }

...
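
To make the wrapper pattern concrete, here is a minimal torch-free sketch. `ScalarStats` is a simplified pure-Python stand-in for `StreamingScalarStats` (population std only), and `L1Metric` with its per-sample mean-absolute-error definition is a hypothetical example, not the repo's actual L1 implementation.

```python
import math
import statistics


class ScalarStats:
    """Torch-free stand-in for StreamingScalarStats (population std, ddof=0)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.n = 0
        self.sum_x = 0.0
        self.sum_x_sq = 0.0

    def update(self, values: list[float]) -> None:
        self.n += len(values)
        self.sum_x += sum(values)
        self.sum_x_sq += sum(v * v for v in values)

    def compute(self) -> dict[str, float]:
        if self.n == 0:
            return {"mean": 0.0, "std": 0.0}
        mean = self.sum_x / self.n
        # Clamp at zero to absorb small negative values caused by rounding.
        variance = max(self.sum_x_sq / self.n - mean * mean, 0.0)
        return {"mean": mean, "std": math.sqrt(variance)}


class L1Metric:
    """Hypothetical metric: per-sample mean absolute error plus shared stats."""

    metric_name = "L1"

    def __init__(self) -> None:
        self.stats = ScalarStats()

    def forward(self, preds: list[list[float]], targets: list[list[float]]) -> None:
        # One scalar per sample: mean absolute error over that sample's elements.
        per_sample = [
            sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)
            for pred, target in zip(preds, targets)
        ]
        self.stats.update(per_sample)

    def reset(self) -> None:
        self.stats.reset()

    def compute(self) -> dict[str, float]:
        stats = self.stats.compute()
        return {
            self.metric_name: stats["mean"],
            f"{self.metric_name}_std": stats["std"],
        }


metric = L1Metric()
# Two batches of two samples each; the per-sample errors are 1.0, 2.0, 3.0, 4.0.
metric.forward([[1.0, 1.0], [2.0, 2.0]], [[0.0, 0.0], [0.0, 0.0]])
metric.forward([[3.0, 3.0], [4.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]])
result = metric.compute()

# The streamed result should match a one-shot population std over all samples.
expected_std = statistics.pstdev([1.0, 2.0, 3.0, 4.0])
assert math.isclose(result["L1_std"], expected_std)
```

The metric class itself holds no accumulation state beyond `self.stats`, so a new metric would only need its own `forward`.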


2 participants