20 changes: 18 additions & 2 deletions climatem/config.py
@@ -160,7 +160,8 @@ def __init__(
num_layers: int = 2,
num_output: int = 2, # NOT SURE
position_embedding_dim: int = 100, # Dimension of positional embedding
- reduce_encoding_pos_dim: bool = False,
+ transition_param_sharing: bool = True,
+ position_embedding_transition: int = 100,
fixed: bool = False, # Do we fix the causal graph? Should be in gt_params maybe
fixed_output_fraction=None, # NOT SURE, Remove this?
tau_neigh: int = 0, # NOT SURE
@@ -177,7 +178,8 @@ def __init__(
self.num_hidden_mixing = num_hidden_mixing
self.num_layers_mixing = num_layers_mixing
self.position_embedding_dim = position_embedding_dim
- self.reduce_encoding_pos_dim = reduce_encoding_pos_dim
+ self.transition_param_sharing = transition_param_sharing
+ self.position_embedding_transition = position_embedding_transition
self.fixed = fixed
self.fixed_output_fraction = fixed_output_fraction
self.tau_neigh = tau_neigh
@@ -295,6 +297,13 @@ def __init__(
seasonality: bool = False, # Seasonality in synthetic data
overlap: bool = False, # Modes overlap
is_forced: bool = False, # Forcings in synthetic data
+ f_1: int = 1,
+ f_2: int = 2,
+ f_time_1: int = 4000,
+ f_time_2: int = 8000,
+ ramp_type: str = "linear",
+ linearity: str = "linear",
+ poly_degrees: List[int] = [2],
plot_original_data: bool = True,
use_correct_hyperparams: bool = True, # Override some of the model params to match those of savar data if true
):
@@ -306,6 +315,13 @@ def __init__(
self.seasonality = seasonality
self.overlap = overlap
self.is_forced = is_forced
+ self.f_1 = f_1
+ self.f_2 = f_2
+ self.f_time_1 = f_time_1
+ self.f_time_2 = f_time_2
+ self.ramp_type = ramp_type
+ self.linearity = linearity
+ self.poly_degrees = poly_degrees
self.plot_original_data = plot_original_data
self.use_correct_hyperparams = use_correct_hyperparams

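Review note on the new SAVAR knobs: `poly_degrees: List[int] = [2]` is a mutable default argument, which Python evaluates once at definition time and shares across calls — likely harmless here if it is never mutated, but a `None` sentinel (or a `default_factory` in a dataclass) is safer. A minimal sketch of the new options together, assuming `f_1`/`f_2` are the pre- and post-ramp forcing levels and `f_time_1`/`f_time_2` the timesteps bounding the ramp (this diff does not spell out their semantics):

```python
from dataclasses import dataclass, field
from typing import List


# Placeholder mirroring the fields added above; the real config class
# name is not visible in these hunks.
@dataclass
class SavarForcingParams:
    is_forced: bool = False
    f_1: int = 1  # assumed: forcing level before the ramp
    f_2: int = 2  # assumed: forcing level after the ramp
    f_time_1: int = 4000  # assumed: timestep where the ramp starts
    f_time_2: int = 8000  # assumed: timestep where the ramp ends
    ramp_type: str = "linear"
    linearity: str = "linear"
    # default_factory avoids the shared-mutable-default pitfall of "= [2]"
    poly_degrees: List[int] = field(default_factory=lambda: [2])


params = SavarForcingParams(is_forced=True, poly_degrees=[2, 3])
```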
11 changes: 9 additions & 2 deletions climatem/data_loader/causal_datamodule.py
@@ -11,8 +11,8 @@
# import relevant data loading modules
from climatem.data_loader.climate_datamodule import ClimateDataModule
from climatem.data_loader.cmip6_dataset import CMIP6Dataset
- from climatem.data_loader.input4mip_dataset import Input4MipsDataset
from climatem.data_loader.era5_dataset import ERA5Dataset
+ from climatem.data_loader.input4mip_dataset import Input4MipsDataset
from climatem.data_loader.savar_dataset import SavarDataset


@@ -99,6 +99,13 @@ def setup(self, stage: Optional[str] = None):
seasonality=self.hparams.seasonality,
overlap=self.hparams.overlap,
is_forced=self.hparams.is_forced,
+ f_1=self.hparams.f_1,
+ f_2=self.hparams.f_2,
+ f_time_1=self.hparams.f_time_1,
+ f_time_2=self.hparams.f_time_2,
+ ramp_type=self.hparams.ramp_type,
+ linearity=self.hparams.linearity,
+ poly_degrees=self.hparams.poly_degrees,
plot_original_data=self.hparams.plot_original_data,
)
elif (
@@ -213,4 +220,4 @@ def setup(self, stage: Optional[str] = None):
else OPENBURNING_MODEL_MAPPING["other"]
)
for test_model in self.hparams.test_models
- }
\ No newline at end of file
+ }
2 changes: 2 additions & 0 deletions climatem/data_loader/climate_datamodule.py
@@ -72,6 +72,8 @@ def __init__(
seasonality: bool = False,
overlap: bool = False,
is_forced: bool = False,
+ linearity: str = "linear",
+ poly_degrees: List[int] = [2],
plot_original_data: bool = True,
):
"""
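Worth checking: this file's entire diff is the two additions above, and one of them annotates `poly_degrees: List[int]` — so `List` must already be imported in `climate_datamodule.py` (unlike `savar_dataset.py` below, which updates its `typing` import in this PR), or this will raise a `NameError` at import time.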
28 changes: 26 additions & 2 deletions climatem/data_loader/savar_dataset.py
@@ -1,6 +1,6 @@
import os
from pathlib import Path
- from typing import Optional
+ from typing import List, Optional

import numpy as np
import torch
@@ -27,11 +27,21 @@ def __init__(
seasonality: bool = False,
overlap: bool = False,
is_forced: bool = False,
+ f_1: int = 1,
+ f_2: int = 2,
+ f_time_1: int = 4000,
+ f_time_2: int = 8000,
+ ramp_type: str = "linear",
+ linearity: str = "linear",
+ poly_degrees: List[int] = [2, 3],
plot_original_data: bool = True,
):
super().__init__()
self.output_save_dir = Path(output_save_dir)
self.savar_name = f"modes_{n_per_col**2}_tl_{time_len}_isforced_{is_forced}_difficulty_{difficulty}_noisestrength_{noise_val}_seasonality_{seasonality}_overlap_{overlap}"
savar_poly_deg = (
str(poly_degrees)[1:-1].translate({ord("'"): None}).translate({ord(","): None}).translate({ord(" "): None})
)
self.savar_name = f"modes_{n_per_col**2}_tl_{time_len}_forced_{is_forced}_dif_{difficulty}_noise_{noise_val}_season_{seasonality}_over_{overlap}_lin_{linearity}_poldeg_{savar_poly_deg}"
self.savar_path = self.output_save_dir / f"{self.savar_name}.npy"

self.global_normalization = global_normalization
@@ -51,6 +61,13 @@ def __init__(
self.seasonality = seasonality
self.overlap = overlap
self.is_forced = is_forced
+ self.f_1 = f_1
+ self.f_2 = f_2
+ self.f_time_1 = f_time_1
+ self.f_time_2 = f_time_2
+ self.ramp_type = ramp_type
+ self.linearity = linearity
+ self.poly_degrees = poly_degrees
self.plot_original_data = plot_original_data

if self.reload_climate_set_data:
@@ -173,6 +190,13 @@ def get_causal_data(
self.seasonality,
self.overlap,
self.is_forced,
+ self.f_1,
+ self.f_2,
+ self.f_time_1,
+ self.f_time_2,
+ self.ramp_type,
+ self.linearity,
+ self.poly_degrees,
self.plot_original_data,
)
time_steps = data.shape[1]
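The `savar_poly_deg` expression above is dense; here is a quick trace of what it produces for `[2, 3]`, with a one-line equivalent that may read better (behavior verified for lists of ints):

```python
poly_degrees = [2, 3]

s = str(poly_degrees)[1:-1]        # "2, 3"  (strip the brackets)
s = s.translate({ord("'"): None})  # "2, 3"  (drop quotes; a no-op for ints)
s = s.translate({ord(","): None})  # "2 3"
s = s.translate({ord(" "): None})  # "23"

assert s == "".join(str(d) for d in poly_degrees)  # simpler equivalent
```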
4 changes: 4 additions & 0 deletions climatem/model/train_model.py
@@ -22,12 +22,14 @@ def __init__(
self,
model,
datamodule,
+ data_params,
exp_params,
gt_params,
model_params,
train_params,
optim_params,
plot_params,
+ savar_params,
save_path,
plots_path,
best_metrics,
@@ -45,6 +47,7 @@ def __init__(
self.data_loader_train = iter(datamodule.train_dataloader(accelerator=accelerator))
self.data_loader_val = iter(datamodule.val_dataloader())
self.coordinates = datamodule.coordinates
+ self.data_params = data_params
self.exp_params = exp_params
self.train_params = train_params
self.optim_params = optim_params
@@ -55,6 +58,7 @@ def __init__(
)

self.plot_params = plot_params
+ self.savar_params = savar_params
self.best_metrics = best_metrics
self.save_path = save_path
self.plots_path = plots_path
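Note for callers: the trainer's `__init__` gains two positional parameters (`data_params` directly after `datamodule`, and `savar_params` directly after `plot_params`), so every existing call site must be updated in the same positional order.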
73 changes: 35 additions & 38 deletions climatem/model/tsdcd_latent.py
@@ -205,7 +205,8 @@ def __init__(
num_layers_mixing: int,
num_hidden_mixing: int,
position_embedding_dim: int,
- reduce_encoding_pos_dim: bool,
+ transition_param_sharing: bool,
+ position_embedding_transition: int,
coeff_kl: float,
distr_z0: str,
distr_encoder: str,
@@ -274,7 +275,8 @@ def __init__(
self.num_layers_mixing = num_layers_mixing
self.num_hidden_mixing = num_hidden_mixing
self.position_embedding_dim = position_embedding_dim
- self.reduce_encoding_pos_dim = reduce_encoding_pos_dim
+ self.transition_param_sharing = transition_param_sharing
+ self.position_embedding_transition = position_embedding_transition
self.coeff_kl = coeff_kl

self.d = d
@@ -352,7 +354,6 @@ def __init__(
use_gumbel_mask=False,
tied=tied_w,
embedding_dim=self.position_embedding_dim,
- reduce_encoding_pos_dim=self.reduce_encoding_pos_dim,
gt_w=None,
)
else:
@@ -363,16 +364,27 @@ def __init__(
if debug_gt_w:
self.decoder.w = gt_w

- self.transition_model = TransitionModelParamSharing(
-     self.d,
-     self.d_z,
-     self.total_tau,
-     self.nonlinear_dynamics,
-     self.num_layers,
-     self.num_hidden,
-     self.num_output,
-     self.position_embedding_dim,
- )
+ if self.transition_param_sharing:
+     self.transition_model = TransitionModelParamSharing(
+         self.d,
+         self.d_z,
+         self.total_tau,
+         self.nonlinear_dynamics,
+         self.num_layers,
+         self.num_hidden,
+         self.num_output,
+         self.position_embedding_transition,
+     )
+ else:
+     self.transition_model = TransitionModel(
+         self.d,
+         self.d_z,
+         self.total_tau,
+         self.nonlinear_dynamics,
+         self.num_layers,
+         self.num_hidden,
+         self.num_output,
+     )

# print("We are setting the Mask here.")
self.mask = Mask(
@@ -456,19 +468,13 @@ def transition(self, z, mask):

# TODO Can we remove this for loop
for i in range(self.d):
- pz_params = torch.zeros(b, self.d_z, 1)
- # print("This is pz_params shape, before we fill it up with a for loop, where the 2nd dimension is filled with the result of the transition model.", pz_params.shape)
- # for k in range(self.d_z):
- # print("Doing the transition, and this is the k at the moment.", k)
- # print("**************************************************")
- # print("What is the shape of the mask?", mask.shape)
- # print("What is the shape of mask[:, :, i * self.d_z + k]?", mask[:, :, i * self.d_z + k].shape)
- # print("THIS DEFINES THE MASK THAT IS USED TO PRODUCE A PARTICULAR LATENT, Z_k.")
- pz_params = self.transition_model(z, mask[:, :, i * self.d_z : (i + 1) * self.d_z], i)
-
- # print("Note here that mu[:, i] is the same as pz_params[:, :, 0], once we have filled up pz_params [:, k] wise, with each k being a forward pass.")
- # print("What is the shape of mu[:, i] and std[:, i]?", mu[:, i].shape, std[:, i].shape)
- # print("What is the shape of pz_params[:, :, 0]?", pz_params[:, :, 0].shape)
+
+ if self.transition_param_sharing:
+     pz_params = self.transition_model(z, mask[:, :, i * self.d_z : (i + 1) * self.d_z], i)
+ else:
+     pz_params = torch.zeros(b, self.d_z, 1)
+     for k in range(self.d_z):
+         pz_params[:, k] = self.transition_model(z, mask[:, :, i * self.d_z + k], i, k)
mu[:, i] = pz_params[:, :, 0]
std[:, i] = torch.exp(0.5 * self.transition_model.logvar[i])
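To make the new branch concrete: with parameter sharing, one call per variable `i` consumes the mask block covering all of its `d_z` latents; without it, the model is called once per latent `k` with a single mask column. A toy sketch of that slicing (the real mask shape is not visible in this hunk, so the middle dimension here is an assumption):

```python
import torch

b, d, d_z, tau = 4, 3, 5, 2
mask = torch.ones(b, tau * d * d_z, d * d_z)  # assumed layout

i, k = 1, 2
block = mask[:, :, i * d_z : (i + 1) * d_z]  # shared path: one call covers all d_z latents
col = mask[:, :, i * d_z + k]                # per-latent path: one call per k

assert torch.equal(block[:, :, k], col)  # the block's k-th column is exactly col
```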

@@ -954,8 +960,6 @@ def select_decoder_mask(self, mask, i, j):

class NonLinearAutoEncoderUniqueMLP_noloop(NonLinearAutoEncoder):

- # TODO: SURELY A BUG??? EMBEDDING DECODER/ENCODER not correctly used?
-
def __init__(
self,
d,
@@ -966,19 +970,11 @@ def __init__(
use_gumbel_mask,
tied,
embedding_dim,
- reduce_encoding_pos_dim,
gt_w=None,
):
super().__init__(d, d_x, d_z, num_hidden, num_layer, use_gumbel_mask, tied, gt_w)
- # embedding_dim_encoding = d_z // 10
- if not reduce_encoding_pos_dim:
-     self.embedding_encoder = nn.Embedding(d_z, embedding_dim)
-     self.encoder = MLP(num_layer, num_hidden, d_x + embedding_dim, 1) # embedding_dim_encoding
- else:
-     self.embedding_encoder = nn.Embedding(d_z, embedding_dim // 10)
-     self.encoder = MLP(num_layer, num_hidden, d_x + embedding_dim // 10, 1) # embedding_dim_encoding
-     # self.encoder = MLP(num_layer, num_hidden, d_x + embedding_dim, 1)
-     # self.embedding_encoder = nn.Embedding(d_z, embedding_dim)
+ self.embedding_encoder = nn.Embedding(d_z, embedding_dim)
+ self.encoder = MLP(num_layer, num_hidden, d_x + embedding_dim, 1) # embedding_dim_encoding

self.decoder = MLP(num_layer, num_hidden, d_z + embedding_dim, 1)
self.embedding_decoder = nn.Embedding(d_x, embedding_dim)
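Net effect of the simplification above: the encoder's positional embedding is always full-width now (matching the decoder) rather than optionally shrinking to `embedding_dim // 10`, so the encoder MLP's input width is fixed. A dimension check with illustrative values (a sketch, not the module's real sizes):

```python
import torch.nn as nn

d_x, d_z, embedding_dim = 250, 25, 100  # illustrative only
embedding_encoder = nn.Embedding(d_z, embedding_dim)
encoder_in_width = d_x + embedding_dim  # 350; the removed branch gave d_x + embedding_dim // 10 = 260
```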
@@ -1177,6 +1173,7 @@ def __init__(
num_hidden: number of hidden units
num_output: number of outputs
"""
+
super().__init__()
self.d = d # number of variables
self.d_z = d_z