Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/weathergen/model/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def forward(self, batch, pe_embed):

# if the assert is hit, ae_local_max_tokens_per_cell in config needs to be increased
max_tokens = self.cf.get("ae_local_max_tokens_per_cell", 64)
assert (
batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens
), "max number of tokens per cell for positional encoding exceeded."
assert batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens, (
"max number of tokens per cell for positional encoding exceeded."
)
" Increase ae_local_max_tokens_per_cell in config."

if batch.tokens_lens.shape[2] == 1:
Expand Down
36 changes: 33 additions & 3 deletions src/weathergen/train/loss_modules/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,39 @@ def stats_normalized_erf(target, ens, mu, stddev):
return torch.mean(d * d) # + torch.mean( torch.sqrt( stddev) )


def mse_ens(target, ens, mu, stddev):
    """
    Average of the per-member MSE between `target` and each entry of `ens`.

    `ens` is iterated along its first dimension; every entry is compared to
    `target` with torch.nn.functional.mse_loss and the resulting scalars are
    averaged. `mu` and `stddev` are accepted but not used by this function.
    """
    member_losses = []
    for member in ens:
        member_losses.append(torch.nn.functional.mse_loss(target, member))
    # Stack the scalar losses along a new leading axis, then reduce to one scalar.
    return torch.stack(member_losses, 0).mean()
def mse_ens(
    target: torch.Tensor,
    pred: torch.Tensor,
    weights_channels: torch.Tensor | None,
    weights_points: torch.Tensor | None,
    use_ensemble_mean: bool = False,
):
    """
    Mean squared error between a target and an ensemble of predictions.

    Two ways of reducing over the ensemble are supported:

    use_ensemble_mean=False (default):
        `mse` is evaluated once per ensemble member and the results are averaged
        over members, so every member is compared against the target on its own.

    use_ensemble_mean=True:
        The whole ensemble is handed to `mse` in a single call and its result is
        returned unchanged. The reduction over members is then left to `mse`
        (expected to average over the leading ensemble dimension first; `mse` is
        defined outside this block -- confirm there).

    Args:
        target : shape (num_data_points, num_channels)
        pred : shape (ens_dim, num_data_points, num_channels)
        weights_channels : shape (num_channels,) or None
        weights_points : shape (num_data_points,) or None
        use_ensemble_mean : selects the reduction mode described above

    Returns:
        Per-member mode: tuple of (first `mse` output averaged over members,
        second `mse` output averaged over the member axis).
        Ensemble-mean mode: the return value of `mse`.
    """
    if not use_ensemble_mean:
        # Each member keeps a leading singleton dimension so that `mse` still
        # receives a tensor with an ensemble axis.
        per_member = [
            mse(target, member.unsqueeze(0), weights_channels, weights_points) for member in pred
        ]
        # Transpose [(loss, loss_chs), ...] into (losses, losses_chs).
        losses, losses_chs = zip(*per_member, strict=False)
        loss_mean = torch.stack(list(losses)).mean()
        loss_chs_mean = torch.stack(list(losses_chs)).mean(0)
        return loss_mean, loss_chs_mean

    return mse(target, pred, weights_channels, weights_points)


def kernel_crps(
Expand Down
Loading