
25 test two year data #27

Open
rogerkuou wants to merge 13 commits into main from 25_test_two_year_data

Conversation

@rogerkuou
Collaborator

@rogerkuou rogerkuou commented Feb 27, 2026

fix #25

I did not finish the train-validation-test split in this PR, but made a new issue for it: #28

@rogerkuou rogerkuou mentioned this pull request Feb 27, 2026
@rogerkuou rogerkuou marked this pull request as ready for review February 27, 2026 14:00
@rogerkuou
Collaborator Author

Hi @SarahAlidoost and @meiertgrootes, I created an example training process on a subset of the two-year data and ran it on Levante. In this PR I included an example SLURM training setup, with a README on how to configure the jobs on Levante.

A copy of the example run can be found at /work/bd0854/b380854/eso4clima. I executed the SLURM task in my home dir and copied the entire experiment there.

Member

@SarahAlidoost SarahAlidoost left a comment


@rogerkuou thanks for the script. Since PR #29 fixed a few issues, we need to merge main into this branch. I also left some comments, mainly about the structure of example.py and the code that should be run with SLURM. If something is unclear, please let me know. In the meantime, I will work on issue #33.

return ds[["ts"]].sel(lon=lon_subset, lat=lat_subset)


def main():
Member

@SarahAlidoost SarahAlidoost Mar 17, 2026


This function is currently doing a lot: creating the model, training it, making predictions, and saving results. We should split these responsibilities.
If this is a "training script", it should only handle reading the data, creating the model with the correct arguments, and passing both to a separate training function (that will be added in #33).

Then, in another script (e.g. "inference script"), we can load the saved model and make predictions (see #32). This separation is needed because training and inference require different computing resources.

Any plotting or result inspection can be done in a separate script if needed.
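A minimal sketch of the separation asked for above; all function and module names here are illustrative placeholders, not the repository's actual API:

```python
# Hypothetical sketch of the split described above: the training entry
# point only loads data, builds the model, and hands both to a training
# function. Names are illustrative, not the repo's API.

def load_data():
    """Read the daily/monthly data and the land mask (details elided)."""
    return {"daily": None, "monthly": None, "land_mask": None}


def build_model(config):
    """Create the model with the correct arguments
    (stand-in for SpatioTemporalModel(**config))."""
    return {"config": config}


def train(model, data):
    """Training function, to be provided by #33."""
    return model


def main():
    # The training script only wires the pieces together; inference,
    # plotting, and result inspection live in separate scripts.
    data = load_data()
    model = build_model({"embed_dim": 64})
    return train(model, data)


if __name__ == "__main__":
    main()
```

An inference script would then re-create the model from the saved checkpoint and run predictions, so the two jobs can request different computing resources.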

Collaborator Author

@rogerkuou rogerkuou Mar 18, 2026


I have split this into a training script and an inference script. The plotting part has been removed.

Comment on lines +40 to +41
lon_subset = slice(-10, 10)
lat_subset = slice(-5, 5)
Member


Slicing should not be needed; we want to work with global data on HPC.

Comment on lines +70 to +82
# Compute monthly climatology stats without persisting the full (time, lat, lon) monthly field
monthly_ts = daily_data["ts"].resample(time="MS").mean(skipna=True)
mean = monthly_ts.mean(dim=["lat", "lon"], skipna=True).compute().values
std = monthly_ts.std(dim=["lat", "lon"], skipna=True).compute().values
print(f"mean: {mean}, std: {std}")

# Make a dataset
dataset = STDataset(
    daily_da=daily_data["ts"],
    monthly_da=monthly_data["ts"],
    land_mask=lsm_mask["lsm"],
    patch_size=(patch_size_training, patch_size_training),
)
Member


All these lines should be moved to training script in #33.

Comment on lines +92 to +151
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
decoder = model.decoder
with torch.no_grad():
    decoder.bias.copy_(torch.from_numpy(mean))
    decoder.scale.copy_(torch.from_numpy(std) + 1e-6)

# Make a dataloader
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    pin_memory=False,
)

# Training process
best_loss = float("inf")
patience = 10
counter = 0
model.train()
for epoch in range(101):
    for batch in dataloader:
        optimizer.zero_grad()

        daily_batch = batch["daily_patch"]
        daily_mask = batch["daily_mask_patch"]
        monthly_target = batch["monthly_patch"]
        land_mask_patch = batch["land_mask_patch"][0, ...]
        padded_days_mask = batch["padded_days_mask"]

        pred = model(daily_batch, daily_mask, land_mask_patch, padded_days_mask)

        ocean = (~land_mask_patch).to(pred.device)
        ocean = ocean[None, None, :, :]

        loss = (
            torch.nn.functional.l1_loss(pred, monthly_target, reduction="none")
            * ocean
        )
        loss_per_month = loss.sum(dim=(-2, -1)) / ocean.sum(dim=(-2, -1))
        loss = loss_per_month.mean()

        loss.backward()
        optimizer.step()

    if loss.item() < best_loss:
        best_loss = loss.item()
        counter = 0

        if epoch % 20 == 0:
            print(f"The loss is {best_loss} at epoch {epoch}")
    else:
        counter += 1
        if counter >= patience:
            print(
                f"No improvement for {patience} epochs, stopping early at epoch {epoch}."
            )
            break

print("training done!")
print(f"Final loss: {loss.item()}")
Member


All these lines should be moved to training script in #33.

Collaborator Author


Agree. I will leave this to another PR

Comment on lines +153 to +175
# Calculate prediction and error
dataset_pred = STDataset(
    daily_da=daily_data["ts"],
    monthly_da=monthly_data["ts"],
    land_mask=lsm_mask["lsm"],
    patch_size=(daily_data.sizes["lat"], daily_data.sizes["lon"]),
)
dataloader_pred = DataLoader(
    dataset_pred,
    batch_size=len(dataset_pred),
    pin_memory=False,
)
full_batch = next(iter(dataloader_pred))
daily_batch = full_batch["daily_patch"]
daily_mask = full_batch["daily_mask_patch"]
monthly_target = full_batch["monthly_patch"]
land_mask_patch = full_batch["land_mask_patch"][0, ...]
padded_days_mask = full_batch["padded_days_mask"]
model.eval()
with torch.no_grad():
    pred = model(daily_batch, daily_mask, land_mask_patch, padded_days_mask)
monthly_prediction = pred_to_numpy(pred, land_mask=land_mask_patch)[0]
monthly_data["ts_pred"] = (("time", "lat", "lon"), monthly_prediction)
Member


inference should be done in a separate script.

# Save the trained model
model_save_path = Path("./models/spatio_temporal_model.pth")
model_save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_save_path)
Member


If we want to load the model later, we need to know how to create the model instance. Therefore, it is better to save the model config (arguments and defaults) with the model.
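One way to do this is to bundle the constructor arguments with the weights in a single checkpoint, so the model can be re-instantiated at load time without hard-coded arguments. A minimal sketch of that pattern follows; in the real code this would use torch.save/torch.load, and the stub class here only stands in for SpatioTemporalModel to keep the sketch dependency-free:

```python
import os
import pickle
import tempfile

# Hypothetical sketch: store the model config next to the weights so the
# model can be rebuilt at load time. pickle stands in for
# torch.save/torch.load; ModelStub stands in for SpatioTemporalModel.

class ModelStub:
    """Stand-in for the real model class."""
    def __init__(self, embed_dim=64, patch_size=16):
        self.config = {"embed_dim": embed_dim, "patch_size": patch_size}

    def state_dict(self):
        return {"decoder.bias": [0.0] * 4}


def save_checkpoint(model, path):
    with open(path, "wb") as f:
        pickle.dump({"config": model.config, "state_dict": model.state_dict()}, f)


def load_checkpoint(path):
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    model = ModelStub(**ckpt["config"])  # rebuilt purely from the saved config
    return model, ckpt["state_dict"]


path = os.path.join(tempfile.gettempdir(), "ckpt.pkl")
save_checkpoint(ModelStub(embed_dim=32), path)
restored, state = load_checkpoint(path)
print(restored.config)  # the config round-trips together with the weights
```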

Comment on lines +183 to +188
# Save the xr.Dataset with predictions
predictions_save_path = Path("./predicted_data/predictions.nc")
predictions_save_path.parent.mkdir(parents=True, exist_ok=True)
monthly_data.to_netcdf(predictions_save_path)
print(f"Saved model to: {model_save_path}")
print(f"Saved predictions to: {predictions_save_path}")
Member


These should be moved to the inference script.

Please don't use print statements in a SLURM job. That information should probably be logged to a log file.
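A minimal sketch of replacing print() with Python's standard logging module; the log file name and format here are assumptions, not from this PR:

```python
import logging
import os
import tempfile

# Hypothetical sketch: send training messages to a log file instead of
# stdout. The path and format are placeholders.
log_path = os.path.join(tempfile.gettempdir(), "training.log")

logger = logging.getLogger("train")
logger.setLevel(logging.INFO)
handler = logging.FileHandler(log_path, mode="w")
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)

# Lazy %-style formatting avoids building the string when the level is off.
logger.info("Saved model to: %s", "./models/spatio_temporal_model.pth")
logger.info("Saved predictions to: %s", "./predicted_data/predictions.nc")
handler.flush()
```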

Comment on lines +190 to +215
# Plot and save inspections
plot_path = Path("./figures/") # local
plot_path.mkdir(parents=True, exist_ok=True)
# 1) Prediction (t=0)
fig, ax = plt.subplots(figsize=(8, 4))
monthly_data["ts_pred"].isel(time=0).plot(ax=ax)
fig.savefig(plot_path / "ts_pred_t0.png", dpi=200, bbox_inches="tight")
plt.close(fig)

# 2) Target (t=0)
fig, ax = plt.subplots(figsize=(8, 4))
monthly_data["ts"].where(~lsm_mask["lsm"].values).isel(time=0).plot(ax=ax)
fig.savefig(plot_path / "ts_target_t0.png", dpi=200, bbox_inches="tight")
plt.close(fig)

# 3) Error (t=0)
fig, ax = plt.subplots(figsize=(8, 4))
err.isel(time=0).plot(ax=ax)
fig.savefig(plot_path / "err_t0.png", dpi=200, bbox_inches="tight")
plt.close(fig)

# 4) Error (t=1)
fig, ax = plt.subplots(figsize=(8, 4))
err.isel(time=1).plot(ax=ax)
fig.savefig(plot_path / "err_t1.png", dpi=200, bbox_inches="tight")
plt.close(fig)
Member


We don't need these in a training script. They can be done later, once we have the model and the predictions saved on disk.

@rogerkuou
Collaborator Author

Hi @SarahAlidoost, thanks for the review! I implemented most of your comments:

  1. I separated the example script into two: training and inference
  2. The training script now exports models with a checkpoint. I slightly modified the class to make it return the config
  3. I used logging to replace the print statements
  4. The plotting part has been removed

I did not implement the training utility function and will leave it to #33.

Can you give it another look?


# Initialize training
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SpatioTemporalModel(
Member


Can you please use the same arguments for the model as those in the example notebook in the main branch?

decoder.scale.copy_(torch.from_numpy(std) + 1e-6)

# Make a dataloader
dataloader = DataLoader(
Member


Can you please use the same arguments for the dataloader as those in the example notebook in the main branch?

patience = 10
counter = 0
model.train()
for epoch in range(101):
Member


Can you please use the same training loop as in the example notebook in the main branch?

counter = 0

if epoch % 20 == 0:
logger.info(f"The loss is {best_loss} at epoch {epoch}")
Member


Where will the logger info be stored, in the SLURM log file? 🤔
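For context: SLURM captures a job's stdout and stderr in the job's output file (by default slurm-%j.out, where %j is the job ID), so anything printed or logged to the console lands there, while a logging FileHandler writes to its own file. An illustrative sbatch header, with placeholder names and paths:

```shell
#!/bin/bash
# Illustrative sbatch header; job name, paths, and script are placeholders.
#SBATCH --job-name=train_model
#SBATCH --output=logs/train_%j.out   # %j expands to the SLURM job ID
#SBATCH --error=logs/train_%j.err

# Console output (print/logging StreamHandler) ends up in the files above;
# a logging FileHandler writes to a separate log file of its own.
python train.py
```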

Comment on lines +554 to +567
self.config = {
    'in_chans': in_chans,
    'embed_dim': embed_dim,
    'patch_size': patch_size,
    'max_days': max_days,
    'max_months': max_months,
    'num_months': num_months,
    'hidden': hidden,
    'overlap': overlap,
    'max_H': max_H,
    'max_W': max_W,
    'spatial_depth': spatial_depth,
    'spatial_heads': spatial_heads,
}
Member


Suggested change
self.config = {
    'in_chans': in_chans,
    'embed_dim': embed_dim,
    'patch_size': patch_size,
    'max_days': max_days,
    'max_months': max_months,
    'num_months': num_months,
    'hidden': hidden,
    'overlap': overlap,
    'max_H': max_H,
    'max_W': max_W,
    'spatial_depth': spatial_depth,
    'spatial_heads': spatial_heads,
}

Member

@SarahAlidoost SarahAlidoost left a comment


@rogerkuou thanks for addressing the comments 👍 . Here are some more suggestions:

  • I see that the example notebook has been changed in this PR. I cannot see exactly what changed, but since this PR is about testing large data on HPC, let's not change the example notebook.
  • No need to add an inference script in this PR; for now we can skip it and focus on the setup of the training on HPC here. We can add the inference script later when fixing #32.
  • After implementing these suggestions and re-running the SLURM job, can you please add the SLURM logfile to the PR as well? Also, can you give an indication of how much resources were used to complete the job?

If something is not clear, please let me know.



Development

Successfully merging this pull request may close these issues.

Test training on the two year dataset

2 participants