Skip to content
37 changes: 37 additions & 0 deletions config/pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
stages:
- name: pretrain
stage: train
config_files:
- config/config_forecasting.yml
options:
- training_config.num_mini_epochs=32
chain_jobs: 4
nodes: 2
slurm_args:
- "--time=10:00:00"

- name: finetune
stage: train
from_run_id: STAGE.pretrain
config_files:
- config/config_forecasting_finetuning.yml
chain_jobs: 2
nodes: 2
slurm_args:
- "--time=10:00:00"

- name: inference
stage: inference
from_run_id: STAGE.pretrain
options:
- training_config.forecast.num_steps=120
- test_config.output.num_samples=10
- test_config.start_date=202310010000
- test_config.end_date=202312300000
- test_config.samples_per_mini_epoch=128
chain_jobs: 2
nodes: 1
slurm_args:
- "--time=10:00:00"


12 changes: 3 additions & 9 deletions src/weathergen/model/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
vs = self.proj_heads_v(x).reshape(s)

qs, ks = apply_rope(
qs, ks, coords, self.rope_mode, 1
)
qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1)
if self.rope_post_mod_qk_lnorm:
qs = self.post_rope_lnorm_q(qs).to(self.dtype)
ks = self.post_rope_lnorm_k(ks).to(self.dtype)
Expand Down Expand Up @@ -302,9 +300,7 @@ def forward(self, x, coords=None, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3])
vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3])

qs, ks = apply_rope(
qs, ks, coords, self.rope_mode, 1
)
qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1)
if self.rope_post_mod_qk_lnorm:
qs = self.post_rope_lnorm_q(qs).to(self.dtype)
ks = self.post_rope_lnorm_k(ks).to(self.dtype)
Expand Down Expand Up @@ -621,9 +617,7 @@ def forward(self, x, coords=None, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
vs = self.proj_heads_v(x).reshape(s).to(self.dtype)

qs, ks = apply_rope(
qs, ks, coords, self.rope_mode, 2
)
qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 2)
if self.rope_post_mod_qk_lnorm:
qs = self.post_rope_lnorm_q(qs).to(self.dtype)
ks = self.post_rope_lnorm_k(ks).to(self.dtype)
Expand Down
213 changes: 195 additions & 18 deletions src/weathergen/model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
import logging
import math

import numpy as np
import torch
from astropy_healpix import healpy
from torch.utils.checkpoint import checkpoint

from weathergen.common.config import Config
from weathergen.datasets.batch import ModelBatch
from weathergen.datasets.utils import healpix_verts_rots, r3tos2
from weathergen.model.engines import (
EmbeddingEngine,
GlobalAssimilationEngine,
Expand All @@ -24,7 +28,15 @@

# from weathergen.model.model import ModelParams
from weathergen.model.parametrised_prob_dist import LatentInterpolator
from weathergen.model.positional_encoding import positional_encoding_harmonic
from weathergen.model.positional_encoding import (
build_spherical_rope_coeff_tensors,
get_rope_mode,
get_rope_spherical_band,
positional_encoding_harmonic,
)
from weathergen.utils.utils import get_dtype

logger = logging.getLogger(__name__)


class EncoderModule(torch.nn.Module):
Expand All @@ -44,7 +56,76 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord
self.healpix_level = cf.healpix_level
self.num_healpix_cells = 12 * 4**self.healpix_level

self.cf = cf
self.dtype = get_dtype(cf.attention_dtype)

# Positional embeddings
self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64)
self.register_buffer(
"pe_embed",
torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype),
)

self.register_buffer(
"q_cells_lens", torch.ones(self.num_healpix_cells + 1, dtype=torch.int32)
)
self.q_cells_lens[0] = 0

self.register_buffer(
"pe_global",
torch.zeros(self.num_healpix_cells, cf.ae_local_num_queries, cf.ae_global_dim_embed, dtype=self.dtype),
)


# RoPE coordinates
self.rope_mode = get_rope_mode(cf, logger)
if self.rope_mode != "none":
self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens
total_tokens = (
self.num_healpix_cells + self.num_extra_tokens
) * cf.ae_local_num_queries
self.register_buffer(
"rope_coords",
torch.zeros(
1,
total_tokens,
2,
dtype=self.dtype,
),
)
self.register_buffer(
"rope_cell_coords",
torch.zeros(
self.num_healpix_cells,
2,
dtype=self.dtype,
),
)
if self.rope_mode == "spherical":
rope_spherical_band = get_rope_spherical_band(cf)
num_modes = 2 * int(rope_spherical_band) + 1
self.register_buffer(
"rope_spherical_coeffs",
torch.zeros(1, total_tokens, num_modes, 2, dtype=self.dtype),
)
self.register_buffer(
"rope_spherical_cell_coeffs",
torch.zeros(self.num_healpix_cells, num_modes, 2, dtype=self.dtype),
)
self.register_buffer(
"rope_spherical_extra_coeffs",
torch.zeros(self.num_extra_tokens, num_modes, 2, dtype=self.dtype),
)
else:
self.rope_spherical_coeffs = None
self.rope_spherical_cell_coeffs = None
self.rope_spherical_extra_coeffs = None
else:
self.rope_coords = None
self.rope_cell_coords = None
self.rope_spherical_coeffs = None
self.rope_spherical_cell_coeffs = None
self.rope_spherical_extra_coeffs = None

self.sources_size = sources_size
self.targets_num_channels = targets_num_channels
self.targets_coords_size = targets_coords_size
Expand Down Expand Up @@ -117,33 +198,131 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord
# global assimilation engine
self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells)

def forward(self, model_params, batch):
def reset_parameters(self) -> None:
"""Creates positional embedding for each grid point for each stream used after stream
embedding, positional embedding for all stream assimilated cell-level local embedding,
initializing queries for local-to-global adapters, HEALPix neighbourhood based parameter
initializing for target prediction.

Sinusoidal positional encoding: Harmonic positional encoding based upon sine and cosine for
both per stream after stream embedding and per cell level for local assimilation.

Query len based parameter creation: Calculate parameters for the calculated token length at
each cell after local assimilation."""

cf = self.cf

dim_embed = cf.ae_local_dim_embed
token_idx_bias = 16
freq_bias = 8
self.pe_embed.data.fill_(0.0)
position = torch.arange(
token_idx_bias,
token_idx_bias + self.max_tokens_local_per_cell,
device=self.pe_embed.device,
).unsqueeze(1)
div = torch.exp(
torch.arange(freq_bias, freq_bias + dim_embed, 2, device=self.pe_embed.device)
* -(math.log(self.max_tokens_local_per_cell) / dim_embed),
)
self.pe_embed.data[:, 0::2] = torch.sin(position * div[: self.pe_embed[:, 0::2].shape[1]])
self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]])

dim_embed = cf.ae_global_dim_embed

if self.rope_mode != "none":
verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5)
coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype)
self.rope_cell_coords.data.copy_(coords)
coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1)
coords_flat = coords.flatten(0, 1).unsqueeze(0)
offset = self.num_extra_tokens * cf.ae_local_num_queries
self.rope_coords.data.fill_(0.0)
self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat)

if self.rope_mode == "spherical":
band = int(get_rope_spherical_band(cf))
(
(cell_real, cell_imag),
(extra_real, extra_imag),
(packed_extra_real, packed_extra_imag),
(packed_real, packed_imag),
) = build_spherical_rope_coeff_tensors(
nside=2**self.healpix_level,
band=band,
num_local_queries=cf.ae_local_num_queries,
num_extra_tokens=self.num_extra_tokens,
device=self.rope_spherical_coeffs.device,
dtype=self.rope_spherical_coeffs.dtype,
)
self.rope_spherical_cell_coeffs.data[..., 0].copy_(cell_real)
self.rope_spherical_cell_coeffs.data[..., 1].copy_(cell_imag)
self.rope_spherical_extra_coeffs.data[..., 0].copy_(extra_real)
self.rope_spherical_extra_coeffs.data[..., 1].copy_(extra_imag)

self.rope_spherical_coeffs.data.fill_(0.0)
self.rope_spherical_coeffs.data[:, :offset, :, 0].copy_(packed_extra_real)
self.rope_spherical_coeffs.data[:, :offset, :, 1].copy_(packed_extra_imag)
self.rope_spherical_coeffs.data[
:, offset : offset + packed_real.shape[1], :, 0
].copy_(packed_real)
self.rope_spherical_coeffs.data[
:, offset : offset + packed_imag.shape[1], :, 1
].copy_(packed_imag)

self.pe_global.data.fill_(0.0)
xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed
self.pe_global.data[..., 0::2] = 0.5 * torch.sin(
torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs)
)
self.pe_global.data[..., 0::2] += (
torch.sin(
torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs)
)
.unsqueeze(1)
.repeat((1, cf.ae_local_num_queries, 1))
)
self.pe_global.data[..., 1::2] = 0.5 * torch.cos(
torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs)
)
self.pe_global.data[..., 1::2] += (
torch.cos(
torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs)
)
.unsqueeze(1)
.repeat((1, cf.ae_local_num_queries, 1))
)

self.q_cells_lens.data.fill_(1)
self.q_cells_lens.data[0] = 0

def forward(self, batch):
"""
Encoder forward
"""

stream_cell_tokens = checkpoint(
self.embed_engine, batch, model_params.pe_embed, use_reentrant=False
self.embed_engine, batch, self.pe_embed, use_reentrant=False
)

tokens_global, posteriors = checkpoint(
self.assimilate_local, model_params, stream_cell_tokens, batch, use_reentrant=False
self.assimilate_local, stream_cell_tokens, batch, use_reentrant=False
)

tokens_global = checkpoint(
self.ae_global_engine,
tokens_global,
coords=(
model_params.rope_spherical_coeffs.unbind(dim=-1)
if model_params.rope_spherical_coeffs is not None
else model_params.rope_coords
self.rope_spherical_coeffs.unbind(dim=-1)
if self.rope_spherical_coeffs is not None
else self.rope_coords
),
use_reentrant=False,
)

return tokens_global, posteriors

def interpolate_latents(self, tokens: torch.Tensor) -> (torch.Tensor, torch.Tensor):
def interpolate_latents(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
""" "
TODO
"""
Expand Down Expand Up @@ -289,9 +468,7 @@ def aggregation_engine_unmasked(

return tokens_global_unmasked

def assimilate_local(
self, model_params, tokens: torch.Tensor, batch: ModelBatch
) -> torch.Tensor:
def assimilate_local(self, tokens: torch.Tensor, batch: ModelBatch) -> torch.Tensor:
"""
Processes embedded tokens locally and prepares them for the global assimilation

Expand All @@ -316,25 +493,25 @@ def assimilate_local(

# TODO: re-enable or remove ae_local_queries_per_cell
if self.cf.ae_local_queries_per_cell:
tokens_global = (self.q_cells + model_params.pe_global).repeat(rs, 1, 1)
tokens_global = (self.q_cells + self.pe_global).repeat(rs, 1, 1)
else:
num_tokens = self.num_healpix_cells
tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + model_params.pe_global
tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + self.pe_global
tokens_global = tokens_global.repeat(rs, 1, 1)

# apply local assimilation engine and project onto global latent vectors
tokens_global_unmasked, posteriors = self.assimilate_local_project_chunked(
tokens, tokens_global, cell_lens, model_params.q_cells_lens
tokens, tokens_global, cell_lens, self.q_cells_lens
)

# apply aggregation engine on unmasked tokens
tokens_global_unmasked = self.aggregation_engine_unmasked(
tokens_global_unmasked,
tokens_global_register_class,
batch.tokens_lens,
rope_cell_coords=model_params.rope_cell_coords,
rope_cell_coeffs=model_params.rope_spherical_cell_coeffs,
rope_extra_coeffs=model_params.rope_spherical_extra_coeffs,
rope_cell_coords=self.rope_cell_coords,
rope_cell_coeffs=self.rope_spherical_cell_coeffs,
rope_extra_coeffs=self.rope_spherical_extra_coeffs,
)

# final processing
Expand Down
Loading
Loading