Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 84 additions & 16 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np
from .utils import add_month_day_dims, calc_stats
from .geo_embedding_utils import calculate_SH_geo_pos_embeddings, compute_patch_geo_pos_embedding
from .geo_embedding_utils import compute_patch_scale_features
import xarray as xr
import torch
from torch.utils.data import Dataset
Expand All @@ -20,13 +22,21 @@ def __init__(
spatial_dims: Tuple[str, str] = ("lat", "lon"),
patch_size: Tuple[int, int] = (16, 16), # (lat, lon)
stride: Tuple[int, int] = None,
sh_pos_table: str = None,
sh_embed_dim: int = 96, # sh_embed_dim should <= (sh_order_L + 1)**2
sh_order_L: int = 10,
):
self.spatial_dims = spatial_dims
self.patch_size = patch_size
self.daily_da = daily_da
self.monthly_da = monthly_da
self.stride = stride if stride is not None else patch_size

self.sh_embed_dim = sh_embed_dim
self.sh_order_L = sh_order_L



# Check that the input data has the expected dimensions
if time_dim not in daily_da.dims or time_dim not in monthly_da.dims:
raise ValueError(f"Time dimension '{time_dim}' not found in input data")
Expand Down Expand Up @@ -84,6 +94,21 @@ def __init__(
H, W = self.daily_np.shape[2], self.daily_np.shape[3]
self.patch_indices = self._compute_patch_indices(H, W)

# Precompute geoposition and scale embeddings for patches
self.geo_pos_table = self._set_geo_pos_table(sh_pos_table)
self.patch_geo_embeddings, self.patch_scale_features = self._compute_geoscalepatch_embeddings()



def _set_geo_pos_table(self, sh_pos_table: str):
""" Calculate or retrieve spherical harmonics based geo position embeddings."""
if sh_pos_table is None:
self.sh_geo_pos = calculate_SH_geo_pos_embeddings(self.lat_coords,
self.lon_coords, self.sh_order_L, self.sh_embed_dim)
else:
#load then set embed dim and sh order L from here
raise(RuntimeError('load method not implemented'))

def _compute_patch_indices(self, H: int, W: int) -> list:
"""Generate patch start indices with coverage warning (overlap support)."""
ph, pw = self.patch_size
Expand Down Expand Up @@ -126,6 +151,27 @@ def _compute_patch_indices(self, H: int, W: int) -> list:
print(f"Overlap: {overlap_h} pixels (height), {overlap_w} pixels (width)")

return [(i, j) for i in i_starts for j in j_starts]

def _compute_geoscalepatch_embeddings(self):
patch_geo_embeddings = []
patch_scale_features = []

for i, j in self.patch_indices:
ph, pw = self.patch_size
geo_pos_tensor = self.sh_geo_pos[i:i+ph, j:j+pw,]
lat_patch = self.lat_coords[i:i+ph]
lon_patch = self.lon_coords[j:j+pw]

geo_emb = compute_patch_geo_pos_embedding(geo_pos_tensor,lat_patch,)
scale_feat = compute_patch_scale_features( lat_patch, lon_patch,)

patch_geo_embeddings.append(geo_emb)
patch_scale_features.append(scale_feat)

patch_geo_embeddings = torch.stack(patch_geo_embeddings)
patch_scale_features = torch.stack(patch_scale_features )

return patch_geo_embeddings, patch_scale_features

def __len__(self):
return len(self.patch_indices)
Expand All @@ -140,49 +186,71 @@ def __getitem__(self, idx):
ph, pw = self.patch_size

# Extract spatial patch via numpy slicing — faster than xarray indexing
daily_patch = self.daily_np[:, :, i : i + ph, j : j + pw] # (M, T, H, W)
monthly_patch = self.monthly_np[:, i : i + ph, j : j + pw] # (M, H, W)
daily_patch = self.daily_np[:, :, i : i + ph, j : j + pw] # (M, T, H, W) -> (M,T,pH, pW)
monthly_patch = self.monthly_np[:, i : i + ph, j : j + pw] # (M, H, W) -> (M, pH, pW)
daily_nan_mask = self.daily_nan_mask[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W)
] # (M, T, H, W) -> (M, T, pH, pW)

if self.land_mask_np is not None:
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W)
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W) -> (pH,pW)
land_tensor = torch.from_numpy(land_patch.copy()).bool()
else:
land_tensor = torch.zeros(ph, pw, dtype=torch.bool)

#geo_pos_tensor = self.sh_geo_pos[i: i + ph, j: j + pw] # (H,W, sh_emb_dim) -> (pH, pW, sh_embed_dim)


# Convert to tensors (from_numpy is zero-copy on contiguous arrays)
# (1, M, T, H, W)
# (1, M, T, pH, pW)
daily_tensor = torch.from_numpy(daily_patch).float().unsqueeze(0)
# (M, H, W)
# (M, pH, pW)
monthly_tensor = torch.from_numpy(monthly_patch).float()
# (1, M, T, H, W)
# (1, M, T, pH, pW)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
# ( M, T, 2)
daily_timef_tensor = torch.from_numpy(self.daily_timef_np).float()

# daily_mask: NaN locations that are NOT land
# Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W)
# Reshape land_tensor for broadcasting: (pH, pW) → (1, 1, 1, pH, pW)
daily_mask_tensor = daily_nan_mask & (
~land_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0)
)

# Extract lat/lon coordinates for this patch
lat_patch = self.lat_coords[i : i + ph]
lon_patch = self.lon_coords[j : j + pw]
lat_patch = self.lat_coords[i : i + ph] # (H,) -> (pH,)
lon_patch = self.lon_coords[j : j + pw] # (W,) -> (pW,)

#get patch geo pos embedding
#geo_pos_embedding_tensor = compute_patch_geo_pos_embedding(geo_pos_tensor, lat_patch)
geo_pos_embedding_tensor = self.patch_geo_embeddings[idx]

#get scale feature for patch
#scale_feature_tensor = compute_patch_scale_features(lat_patch, lon_patch) # -> (10,)
scale_feature_tensor = self.patch_scale_features[idx]

#create tensors to pass sh embedding dimension, harmonic order, and scale feature dim
sh_embed_dim = torch.tensor(self.sh_embed_dim)
harmonic_order = torch.tensor(self.sh_order_L)
scale_f_dim = torch.tensor(len(scale_feature_tensor))

# Convert to tensors
return {
"daily_patch": daily_tensor, # (C=1, M, T=31, H, W)
"monthly_patch": monthly_tensor, # (M, H, W)
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W)
"land_mask_patch": land_tensor, # (H,W) True=Land
"daily_patch": daily_tensor, # (C=1, M, T=31, pH, pW)
"monthly_patch": monthly_tensor, # (M, pH, pW)
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, pH, pW)
"land_mask_patch": land_tensor, # (pH,pW) True=Land
"daily_timef_patch": daily_timef_tensor, #(M, T=31, 2)
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
#"sh_geo_pos_patch": geo_pos_tensor, # (pH, pW, sh_embed_dim)
"scale_feature_patch": scale_feature_tensor, #(10,)
"geo_pos_embedding_patch": geo_pos_embedding_tensor, #(sh_embed_dim,)
"sh_embed_dim": sh_embed_dim,
"harmonic_order": harmonic_order,
"scale_f_dim":scale_f_dim,
"coords": (i, j),
"lat_patch": lat_patch, # (H,)
"lon_patch": lon_patch, # (W,)
"lat_patch": lat_patch, # (pH,)
"lon_patch": lon_patch, # (pW,)
}

def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]:
Expand Down
Loading
Loading