-
Notifications
You must be signed in to change notification settings - Fork 1
Log Metric STD #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Log Metric STD #22
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ def reset(self): | |
| """Reset running squared-error accumulators.""" | ||
|
|
||
| self.total_squared_error = torch.tensor(0.0, device=self.device) | ||
| self.total_squared_error_sq = torch.tensor(0.0, device=self.device) | ||
| self.total_examples = torch.tensor(0.0, device=self.device) | ||
|
Comment on lines
33
to
35
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. These variable names make sense and are accurate, but I don't think they are the clearest or easiest to understand, especially since the same variance shortcut (mean of the squares minus the square of the mean, i.e. Var[x] = E[x^2] - E[x]^2) is repeated identically across multiple metric classes. I wonder whether it would be most beneficial to the repo to abstract the mean/std accumulation into a standalone, reusable helper, so that redundant code is minimized and adding future metrics takes less work. Given that the sample mean is just the first moment and the sample variance (whose square root is the std) is the second central moment, their rolling accumulation can naturally be handled together in the following generalized form: class StreamingScalarStats:
"""
Accumulates mean, variance, and std for a stream of per sample scalar values.
By default, computes population variance:
var = E[x^2] - E[x]^2
Set ddof=1 for sample variance.
"""
def __init__(
self,
device: torch.device | str = "cuda",
dtype: torch.dtype = torch.float64,
ddof: int = 0,
):
if ddof < 0:
raise ValueError("ddof must be non-negative")
self.device = device if isinstance(device, torch.device) else torch.device(device)
self.dtype = dtype
self.ddof = ddof
self.reset()
def reset(self) -> None:
self.sum_x = torch.tensor(0.0, device=self.device, dtype=self.dtype)
self.sum_x_sq = torch.tensor(0.0, device=self.device, dtype=self.dtype)
self.n = torch.tensor(0.0, device=self.device, dtype=self.dtype)
@torch.no_grad()
def update(self, values: torch.Tensor) -> None:
"""
Update accumulator with a tensor of scalar observations.
"""
values = values.detach().to(device=self.device, dtype=self.dtype).reshape(-1)
if values.numel() == 0:
return
self.sum_x += values.sum()
self.sum_x_sq += values.pow(2).sum()
self.n += values.numel()
def compute(self) -> dict[str, float]:
if self.n.item() == 0:
return {
"mean": 0.0,
"variance": 0.0,
"std": 0.0,
"n": 0.0,
}
mean = self.sum_x / self.n
denom = self.n - self.ddof
if denom <= 0:
variance = torch.tensor(0.0, device=self.device, dtype=self.dtype)
else:
# For ddof=0: var = sum_x_sq / n - mean^2
# For ddof=1: unbiased sample variance
correction = self.n / denom
variance = correction * (self.sum_x_sq / self.n - mean.pow(2))
variance = torch.clamp(variance, min=0.0)
std = torch.sqrt(variance)
return {
"mean": mean.item(),
"variance": variance.item(),
"std": std.item(),
"n": self.n.item(),
} This would reduce each of your metric classes to a thin wrapper around a forward function that computes the per-sample measure, with StreamingScalarStats performing the rolling accumulation: def __init__(self, ...):
...
self.stats = StreamingScalarStats(device=self.device, ddof=0)
...
def forward(self,...):
# do something to get per sample measure tensor
def reset(self):
self.stats.reset()
...
def compute(self) -> dict[str, float]:
stats = self.stats.compute()
return {
self.metric_name: stats["mean"],
f"{self.metric_name}_std": stats["std"],
}
... |
||
|
|
||
| def forward( | ||
|
|
@@ -58,6 +59,7 @@ def forward( | |
| per_sample_l2 = sq_error.mean(dim=1) | ||
|
|
||
| self.total_squared_error += per_sample_l2.sum().detach().to(self.device) | ||
| self.total_squared_error_sq += per_sample_l2.pow(2).sum().detach().to(self.device) | ||
| self.total_examples += torch.tensor( | ||
| per_sample_l2.numel(), | ||
| dtype=torch.float32, | ||
|
|
@@ -70,21 +72,33 @@ def update(self, generated_predictions: torch.Tensor, targets: torch.Tensor, **k | |
|
|
||
| self.forward(generated_predictions=generated_predictions, targets=targets, **kwargs) | ||
|
|
||
| def compute(self) -> torch.Tensor: | ||
| """Compute averaged L2 value for currently accumulated state. | ||
| def compute(self) -> dict[str, float]: | ||
| """Compute averaged L2 and population std for current state. | ||
|
|
||
| Returns: | ||
| Scalar tensor with current L2 value. | ||
| Dictionary containing mean and std metric values. | ||
| """ | ||
|
|
||
| average_l2 = torch.where( | ||
| self.total_examples > 0, | ||
| self.total_squared_error / self.total_examples, | ||
| torch.tensor(0.0, device=self.device), | ||
| ) | ||
| variance_l2 = torch.where( | ||
| self.total_examples > 0, | ||
| (self.total_squared_error_sq / self.total_examples) - average_l2.pow(2), | ||
| torch.tensor(0.0, device=self.device), | ||
| ) | ||
| std_l2 = torch.sqrt(torch.clamp(variance_l2, min=0.0)) | ||
| if not torch.isfinite(average_l2): | ||
| average_l2 = torch.tensor(0.0, device=self.device) | ||
| return average_l2 | ||
| if not torch.isfinite(std_l2): | ||
| std_l2 = torch.tensor(0.0, device=self.device) | ||
|
|
||
| return { | ||
| self.metric_name: average_l2.item(), | ||
| f"{self.metric_name}_std": std_l2.item(), | ||
| } | ||
|
|
||
| @property | ||
| def metric_name(self) -> str: | ||
|
|
@@ -93,8 +107,8 @@ def metric_name(self) -> str: | |
| return "l2_total" | ||
|
|
||
| def get_metric_data(self) -> dict[str, float]: | ||
| """Backward-compatible helper that computes and resets state.""" | ||
| """Compute metric stats and reset state.""" | ||
|
|
||
| metric_data = {self.metric_name: self.compute().item()} | ||
| metric_data = self.compute() | ||
| self.reset() | ||
| return metric_data | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe the docstring should note that, for this metric class, the metric std is not actually computed?