Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ This link points you to additional references for setting up your environment co

4. Install `pre-commit` hooks using `pre-commit install`.

5. INstall netcdf4 with pip. After activating your environment, run `pip install netcdf4`. This package cannot be installed with poetry because of dependencies.

### 3. Downloading input data

For running the model on real cliamte data, please download monthly climate model data and regrid it to an icosahedral grid using ClimateSet https://github.com/RolnickLab/ClimateSet.
Expand Down
162 changes: 145 additions & 17 deletions climatem/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@


class expParams:
"""Experiment setup: paths, dimensions, random seed, and hardware config."""

def __init__(
self,
exp_path, # Path to where the output will be saved i.e. model runs, plots
Expand Down Expand Up @@ -38,6 +40,8 @@ def __init__(


class dataParams:
"""Data loading: paths, scenarios, variables, batch size, and preprocessing options."""

def __init__(
self,
data_dir, # The processed (normalized, deseasonalized, numpy...) data will be stored here
Expand All @@ -64,6 +68,7 @@ def __init__(
channels_last: bool = False, # last dimension of data is the channel
ishdf5: bool = False, # numpy vs hdf5. for now only numpy is supported. Redundant with next param
data_format: str = "numpy", # numpy vs hdf5. for now only numpy is supported
forcing_conditioning: str = "raw", # how to condition on forcings: raw | template | mode | region (SAVAR)
seq_to_seq: bool = True, # predicting a sequence from a sequence?
train_val_interval_length: int = 11,
load_train_into_mem: bool = True,
Expand Down Expand Up @@ -100,6 +105,7 @@ def __init__(
self.channels_last = channels_last
self.ishdf5 = ishdf5
self.data_format = data_format
self.forcing_conditioning = forcing_conditioning
self.seq_to_seq = seq_to_seq
self.train_val_interval_length = train_val_interval_length
self.load_train_into_mem = load_train_into_mem
Expand All @@ -123,6 +129,8 @@ def __init__(


class trainParams:
"""Training loop: learning rate, iterations, patience for phase transition, and validation frequency."""

def __init__(
self,
ratio_train: float = 0.9,
Expand All @@ -147,6 +155,8 @@ def __init__(


class modelParams:
"""Model architecture: latent dynamics type, MLP sizes, embedding, and causal mask options."""

def __init__(
self,
instantaneous: bool = False, # Allow instantaneous connections?
Expand All @@ -160,11 +170,21 @@ def __init__(
num_layers: int = 2,
num_output: int = 2, # NOT SURE
position_embedding_dim: int = 100, # Dimension of positional embedding
reduce_encoding_pos_dim: bool = False, # Reduce encoder positional embedding dimension by x10
tau_neigh: int = 0, # Legacy neighborhood radius used in older configs
hard_gumbel: bool = False, # Legacy mask sampling flag used in analysis scripts
transition_param_sharing: bool = True,
position_embedding_transition: int = 100,
fixed: bool = False, # Do we fix the causal graph? Should be in gt_params maybe
fixed_output_fraction=None, # This is used if we fix the mask, and want to get a fix number of 0 and 1
constraint_func: str = "trace", # This is used for the constraint - trace is the correct one here
use_exogenous: bool = False, # NEW: Enable conditioning on exogenous forcings (CO2 + aerosols)
d_y_co2: int = 1, # NEW: Dimension of CO2 forcing (typically 1 for global, or spatial_dim for local)
d_y_aerosol: int = 900, # NEW: Dimension of aerosol forcing (typically spatial_dim for local effects)
use_forced_latents: bool = False, # NEW: Map forcings directly to dedicated latent dimensions
n_forced_latents_co2: int = 1, # NEW: Number of latent dimensions for CO2 forcing
n_forced_latents_aerosol: int = 2, # NEW: Number of latent dimensions for aerosol forcing
forcing_arch: str = "baseline", # NEW: baseline | transitioned | predefined
):
self.instantaneous = instantaneous
self.no_w_constraint = no_w_constraint
Expand All @@ -177,14 +197,26 @@ def __init__(
self.num_hidden_mixing = num_hidden_mixing
self.num_layers_mixing = num_layers_mixing
self.position_embedding_dim = position_embedding_dim
self.reduce_encoding_pos_dim = reduce_encoding_pos_dim
self.tau_neigh = tau_neigh
self.hard_gumbel = hard_gumbel
self.transition_param_sharing = transition_param_sharing
self.position_embedding_transition = position_embedding_transition
self.fixed = fixed
self.fixed_output_fraction = fixed_output_fraction
self.constraint_func = constraint_func
self.use_exogenous = use_exogenous
self.d_y_co2 = d_y_co2
self.d_y_aerosol = d_y_aerosol
self.use_forced_latents = use_forced_latents
self.n_forced_latents_co2 = n_forced_latents_co2
self.n_forced_latents_aerosol = n_forced_latents_aerosol
self.forcing_arch = forcing_arch


class optimParams:
"""Optimization: loss coefficients, ALM penalty parameters, and constraint schedules."""

def __init__(
self,
optimizer: str = "rmsprop",
Expand Down Expand Up @@ -227,6 +259,11 @@ def __init__(
acyclic_min_iter_convergence: float = 1_000,
mu_acyclic_init: float = 0,
h_acyclic_threshold: float = 0,
forcing_co2_coeff: float = 10.0, # Weight for CO2 forcing reconstruction loss
forcing_aerosol_coeff: float = 10.0, # Weight for aerosol forcing reconstruction loss
forcing_latent_supervision_coeff: float = 10.0, # Weight for direct forcing latent supervision loss
decoder_utilization_coeff: float = 0.1, # Penalty coefficient for underutilized forcing latent decoder weights
min_forcing_decoder_norm: float = 1.5, # Target minimum L2 norm for forcing latent decoder weights
udpate_ALM_using_valid: bool = True, # If False use training loss convergence if True uses valid loss convergence
udpate_ALM_using_nll: bool = True, # If False use augmented loss convergence if True uses NLL convergence
):
Expand Down Expand Up @@ -275,11 +312,19 @@ def __init__(
self.mu_acyclic_init = mu_acyclic_init
self.h_acyclic_threshold = h_acyclic_threshold

self.forcing_co2_coeff = forcing_co2_coeff
self.forcing_aerosol_coeff = forcing_aerosol_coeff
self.forcing_latent_supervision_coeff = forcing_latent_supervision_coeff
self.decoder_utilization_coeff = decoder_utilization_coeff
self.min_forcing_decoder_norm = min_forcing_decoder_norm

self.udpate_ALM_using_valid = udpate_ALM_using_valid
self.udpate_ALM_using_nll = udpate_ALM_using_nll


class plotParams:
"""Plotting frequency and toggle options for training diagnostics."""

def __init__(
self, plot_freq: int = 500, plot_through_time: bool = True, print_freq: int = 500, savar: bool = False
):
Expand All @@ -290,25 +335,80 @@ def __init__(


class savarParams:
# Params for generating synthetic data
"""
Configuration for SAVAR synthetic data generation.

Controls all aspects of the Seasonal Vector Auto-Regressive data generator:
spatial grid, temporal length, causal graph structure, seasonality, external
forcing (CO2 + aerosol), noise characteristics, and background state.
See ``climatem/synthetic_data/savar.py`` for the generator implementation.
"""

def __init__(
self,
time_len: int = 10_000, # Time length of the data
comp_size: int = 10, # Each component size
noise_val: float = 0.2, # Noise variance relative to signal
n_per_col: int = 2, # square grid, equivalent of lat/lon
difficulty: str = "easy", # easy, med_easy, med_hard, hard: difficulty of the graph
seasonality: bool = False, # Seasonality in synthetic data
overlap: float = 0, # Modes overlap between 0 and 1
is_forced: bool = False, # Forcings in synthetic data
f_1: int = 1,
f_2: int = 2,
f_time_1: int = 4000,
f_time_2: int = 8000,
ramp_type: str = "linear",
linearity: str = "linear",
poly_degrees: List[int] = [2],
plot_original_data: bool = True,
# Basic data generation parameters
time_len: int = 10_000, # Total number of timesteps to generate (longer = more data for training)
comp_size: int = 10, # Size of each spatial component/mode
noise_val: float = 0.02, # Noise strength relative to signal (higher = noisier data)
n_per_col: int = 2, # Number of grid points per row/column in square spatial grid (total spatial size = n_per_col^2 * comp_size)
# Causal graph structure
difficulty: str = "easy", # Complexity of causal graph: "easy" (sparse), "med_easy", "med_hard", "hard" (dense/complex)
# Seasonality parameters
seasonality: bool = False, # Whether to add seasonal variations (e.g., annual cycles like climate data)
periods: List[float] = [
365,
182.5,
60,
], # Seasonal periods in days (e.g., annual=365, semi-annual=182.5, bi-monthly=60)
amplitudes: List[float] = [0.06, 0.02, 0.01], # Amplitude of each seasonal component (matched to periods list)
phases: List[float] = [
0.0,
0.7853981634,
1.5707963268,
], # Phase shifts for seasonality in radians (0, π/4, π/2)
yearly_jitter_amp: float = 0.05, # Year-to-year random variation in seasonal amplitude (adds realism)
yearly_jitter_phase: float = 0.10, # Year-to-year random variation in seasonal phase (adds realism)
# Spatial structure
overlap: float = 0, # Whether spatial modes can overlap between 0 and 1 (True = modes share spatial regions)
# External forcing parameters
is_forced: bool = False, # Whether to include external forcings like CO2 and aerosols (mimics climate change)
f_1: int = 0, # Initial forcing value at start of ramp (baseline level). NOTE: used as float downstream
f_2: int = 1, # Final forcing value at end of ramp (target level). NOTE: used as float downstream
f_time_1: int = 4000, # Timestep when forcing ramp begins (relative to start after transient)
f_time_2: int = 8000, # Timestep when forcing ramp ends and forcing becomes constant at f_2
ramp_type: str = "linear", # Temporal evolution of forcing: "linear", "quadratic", "exponential", "sigmoid", "sinusoidal"
# Dynamics type
linearity: str = "linear", # Type of dynamics: "linear" (VAR model), "polynomial", or "nonlinear" (neural net)
poly_degrees: List[int] = [
2
], # Polynomial degrees to use if linearity="polynomial" (e.g., [2] for quadratic, [2,3] for quad+cubic)
# Visualization
plot_original_data: bool = True, # Whether to generate plots during data generation
# Separate forcing fields (more realistic than single forcing)
use_separate_forcings: bool = False, # Use distinct CO2 and aerosol forcing fields with different dynamics
forcing_amplification: float = 1.2, # Overall scaling factor for forcing magnitudes
# Aerosol forcing parameters
aerosol_scale: float = 0.02, # Strength of aerosol forcing (typically negative for cooling effect, positive here for magnitude)
aerosol_spatial_contrast: float = 1.05, # Regional variability of aerosol effects (>1 increases heterogeneity across space)
aerosol_ramp_up_time: int = 2000, # When aerosol forcing starts increasing (default: 20% of time_len)
aerosol_peak_time: int = 5000, # When aerosol forcing reaches maximum (default: 50% of time_len)
aerosol_decline_time: int = 8000, # When aerosol forcing finishes declining to baseline (default: 80% of time_len)
aerosol_timing_stagger: float = 0.3, # Fraction of timeline to stagger aerosol latents (creates distinct temporal patterns per latent)
# Forcing causal structure parameters
n_co2_latents: int = 1, # Number of latent variables representing CO2 forcing in causal graph (typically 1 for global)
n_aerosol_latents: int = 2, # Number of latent variables representing aerosol forcing (multiple for regional effects)
co2_effect_strength: float = 0.25, # Causal coefficient strength for CO2 → climate mode links (larger = stronger influence)
aerosol_effect_strength: float = 0.20, # Causal coefficient strength for aerosol → climate mode links (larger = stronger influence)
# Noise temporal correlation (AR(1) / Ornstein-Uhlenbeck)
noise_ar1_rho: float = 0.95, # AR(1) persistence parameter ρ (0=white noise, 0.95=realistic red noise). Can also be "decay" for mode-dependent ρₖ = exp(-k/K)
noise_ar1: bool = True, # Use AR(1) (red) noise instead of white noise for realistic temporal correlations
# Background state parameters
enable_background: bool = False, # Whether to add low-frequency background state (slow climate mean state drift)
background_strength: float = 0.3, # Strength relative to mode std (if < 1 and mode="relative") or absolute magnitude
background_strength_mode: str = "relative", # "relative" to mode std or "absolute"
background_smoothness: float = 0.15, # Controls spatial frequency (higher = smoother spatial patterns)
background_timescale_rho: float = 0.995, # AR(1) persistence (higher = slower temporal evolution, 0.995 ≈ 200 step timescale)
background_n_modes: int = 3, # Number of low-frequency Fourier components for spatial smoothness
use_correct_hyperparams: bool = True, # Override some of the model params to match those of savar data if true
):
self.time_len = time_len
Expand All @@ -317,6 +417,11 @@ def __init__(
self.n_per_col = n_per_col
self.difficulty = difficulty
self.seasonality = seasonality
self.periods = periods
self.amplitudes = amplitudes
self.phases = phases
self.yearly_jitter_amp = yearly_jitter_amp
self.yearly_jitter_phase = yearly_jitter_phase
self.overlap = overlap
self.is_forced = is_forced
self.f_1 = f_1
Expand All @@ -327,6 +432,29 @@ def __init__(
self.linearity = linearity
self.poly_degrees = poly_degrees
self.plot_original_data = plot_original_data
self.use_separate_forcings = use_separate_forcings
self.forcing_amplification = forcing_amplification
self.aerosol_scale = aerosol_scale
self.aerosol_spatial_contrast = aerosol_spatial_contrast
self.aerosol_ramp_up_time = aerosol_ramp_up_time
self.aerosol_peak_time = aerosol_peak_time
self.aerosol_decline_time = aerosol_decline_time
self.aerosol_timing_stagger = aerosol_timing_stagger
# Forcing causal structure
self.n_co2_latents = n_co2_latents
self.n_aerosol_latents = n_aerosol_latents
self.co2_effect_strength = co2_effect_strength
self.aerosol_effect_strength = aerosol_effect_strength
# Noise temporal correlation
self.noise_ar1_rho = noise_ar1_rho
self.noise_ar1 = noise_ar1
# Background state parameters
self.enable_background = enable_background
self.background_strength = background_strength
self.background_strength_mode = background_strength_mode
self.background_smoothness = background_smoothness
self.background_timescale_rho = background_timescale_rho
self.background_n_modes = background_n_modes
self.use_correct_hyperparams = use_correct_hyperparams


Expand Down
Loading