Conversation
Hi @SarahAlidoost and @meiertgrootes, I created an example training process on a subset of the two-year data, and ran it on Levante. In this PR I included an example SLURM training process, with a README on how to configure the jobs on Levante. A copy of the example run can be found on
SarahAlidoost
left a comment
@rogerkuou thanks for the script. Since PR #29 fixed a few issues, we need to merge main into this branch. I also left some comments, mainly about the structure of example.py and the code that should be run with SLURM. If something is unclear, please let me know. In the meantime, I will work on issue #33.
    return ds[["ts"]].sel(lon=lon_subset, lat=lat_subset)

def main():
This function is currently doing a lot: creating the model, training it, making predictions, and saving results. We should split these responsibilities.
If this is a "training script", it should only handle reading the data, creating the model with the correct arguments, and passing both to a separate training function (that will be added in #33).
Then, in another script (e.g. "inference script"), we can load the saved model and make predictions (see #32). This separation is needed because training and inference require different computing resources.
Any plotting or result inspection can be done in a separate script if needed.
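As an illustration of the separation being proposed here, a minimal, framework-free sketch (the function names `train_model` and `run_inference` are hypothetical, not the project's API):

```python
# Hypothetical sketch of the proposed script split: the training script only
# fits and saves a model, the inference script only loads and predicts.

def train_model(model, batches, step_fn):
    """Training-script responsibility: iterate over data, update the model."""
    for batch in batches:
        step_fn(model, batch)  # one optimization step per batch
    return model

def run_inference(model, batches, predict_fn):
    """Inference-script responsibility: take a trained model, make predictions."""
    return [predict_fn(model, batch) for batch in batches]
```

Keeping the two entry points separate also makes it straightforward to request different SLURM resources (e.g. GPU for training, CPU-only for inference).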
I have split this into a training script and an inference script. The plotting part has been removed.
scripts/example.py
Outdated
lon_subset = slice(-10, 10)
lat_subset = slice(-5, 5)
Slicing should not be needed; we want to work with global data on HPC.
# Compute monthly climatology stats without persisting the full (time, lat, lon) monthly field
monthly_ts = daily_data["ts"].resample(time="MS").mean(skipna=True)
mean = monthly_ts.mean(dim=["lat", "lon"], skipna=True).compute().values
std = monthly_ts.std(dim=["lat", "lon"], skipna=True).compute().values
print(f"mean: {mean}, std: {std}")

# Make a dataset
dataset = STDataset(
    daily_da=daily_data["ts"],
    monthly_da=monthly_data["ts"],
    land_mask=lsm_mask["lsm"],
    patch_size=(patch_size_training, patch_size_training),
)
All these lines should be moved to the training script in #33.
scripts/example.py
Outdated
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
decoder = model.decoder
with torch.no_grad():
    decoder.bias.copy_(torch.from_numpy(mean))
    decoder.scale.copy_(torch.from_numpy(std) + 1e-6)

# Make a dataloader
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    pin_memory=False,
)

# Training process
best_loss = float("inf")
patience = 10
counter = 0
model.train()
for epoch in range(101):
    for batch in dataloader:
        optimizer.zero_grad()

        daily_batch = batch["daily_patch"]
        daily_mask = batch["daily_mask_patch"]
        monthly_target = batch["monthly_patch"]
        land_mask_patch = batch["land_mask_patch"][0, ...]
        padded_days_mask = batch["padded_days_mask"]

        pred = model(daily_batch, daily_mask, land_mask_patch, padded_days_mask)

        ocean = (~land_mask_patch).to(pred.device)
        ocean = ocean[None, None, :, :]

        loss = (
            torch.nn.functional.l1_loss(pred, monthly_target, reduction="none")
            * ocean
        )
        loss_per_month = loss.sum(dim=(-2, -1)) / ocean.sum(dim=(-2, -1))
        loss = loss_per_month.mean()

        loss.backward()
        optimizer.step()

    if loss.item() < best_loss:
        best_loss = loss.item()
        counter = 0

        if epoch % 20 == 0:
            print(f"The loss is {best_loss} at epoch {epoch}")
    else:
        counter += 1
        if counter >= patience:
            print(
                f"No improvement for {patience} epochs, stopping early at epoch {epoch}."
            )
            break

print("training done!")
print(f"Final loss: {loss.item()}")
All these lines should be moved to the training script in #33.
Agree. I will leave this to another PR.
scripts/example.py
Outdated
# Calculate prediction and error
dataset_pred = STDataset(
    daily_da=daily_data["ts"],
    monthly_da=monthly_data["ts"],
    land_mask=lsm_mask["lsm"],
    patch_size=(daily_data.sizes["lat"], daily_data.sizes["lon"]),
)
dataloader_pred = DataLoader(
    dataset_pred,
    batch_size=len(dataset_pred),
    pin_memory=False,
)
full_batch = next(iter(dataloader_pred))
daily_batch = full_batch["daily_patch"]
daily_mask = full_batch["daily_mask_patch"]
monthly_target = full_batch["monthly_patch"]
land_mask_patch = full_batch["land_mask_patch"][0, ...]
padded_days_mask = full_batch["padded_days_mask"]
model.eval()
with torch.no_grad():
    pred = model(daily_batch, daily_mask, land_mask_patch, padded_days_mask)
monthly_prediction = pred_to_numpy(pred, land_mask=land_mask_patch)[0]
monthly_data["ts_pred"] = (("time", "lat", "lon"), monthly_prediction)
Inference should be done in a separate script.
scripts/example.py
Outdated
# Save the trained model
model_save_path = Path("./models/spatio_temporal_model.pth")
model_save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_save_path)
If we want to load the model later, we need to know how to create the model instance. Therefore, it is better to save the model config (arguments and defaults) with the model.
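For illustration, one way this could look (a sketch, not the project's implementation; the `model.config` attribute and the checkpoint layout are assumptions, and a `torch.nn.Linear` stands in for the real model):

```python
# Sketch: bundle the constructor arguments with the weights, so the model
# can be re-created at load time without knowing the config out of band.
import tempfile
from pathlib import Path

import torch

model = torch.nn.Linear(4, 2)  # stand-in for SpatioTemporalModel
model.config = {"in_features": 4, "out_features": 2}  # hypothetical attribute

save_path = Path(tempfile.mkdtemp()) / "model.pth"
torch.save({"config": model.config, "state_dict": model.state_dict()}, save_path)

# Later, e.g. in the inference script:
checkpoint = torch.load(save_path)
restored = torch.nn.Linear(**checkpoint["config"])
restored.load_state_dict(checkpoint["state_dict"])
```

Keeping config and weights in one file avoids the two drifting apart between training and inference runs.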
scripts/example.py
Outdated
# Save the xr.Dataset with predictions
predictions_save_path = Path("./predicted_data/predictions.nc")
predictions_save_path.parent.mkdir(parents=True, exist_ok=True)
monthly_data.to_netcdf(predictions_save_path)
print(f"Saved model to: {model_save_path}")
print(f"Saved predictions to: {predictions_save_path}")
These should be moved to the inference script.
Please don't use print statements in a SLURM job. That information should probably be logged to a log file.
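As a sketch of what that could look like with Python's standard `logging` module (the file location here is hypothetical; note that in a SLURM job, stdout/stderr are already redirected to the file given by `#SBATCH --output`, so logging to stderr can also be enough):

```python
# Minimal file-logging setup (sketch); replaces print() calls in the script.
import logging
import tempfile
from pathlib import Path

log_path = Path(tempfile.mkdtemp()) / "train.log"  # hypothetical location
logging.basicConfig(
    filename=log_path,  # omit `filename` to log to stderr instead
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    force=True,  # replace any handlers configured earlier in the process
)
logger = logging.getLogger(__name__)
logger.info("Saved predictions to: predictions.nc")
```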
scripts/example.py
Outdated
# Plot and save inspections
plot_path = Path("./figures/")  # local
plot_path.mkdir(parents=True, exist_ok=True)
# 1) Prediction (t=0)
fig, ax = plt.subplots(figsize=(8, 4))
monthly_data["ts_pred"].isel(time=0).plot(ax=ax)
fig.savefig(plot_path / "ts_pred_t0.png", dpi=200, bbox_inches="tight")
plt.close(fig)

# 2) Target (t=0)
fig, ax = plt.subplots(figsize=(8, 4))
monthly_data["ts"].where(~lsm_mask["lsm"].values).isel(time=0).plot(ax=ax)
fig.savefig(plot_path / "ts_target_t0.png", dpi=200, bbox_inches="tight")
plt.close(fig)

# 3) Error (t=0)
fig, ax = plt.subplots(figsize=(8, 4))
err.isel(time=0).plot(ax=ax)
fig.savefig(plot_path / "err_t0.png", dpi=200, bbox_inches="tight")
plt.close(fig)

# 4) Error (t=1)
fig, ax = plt.subplots(figsize=(8, 4))
err.isel(time=1).plot(ax=ax)
fig.savefig(plot_path / "err_t1.png", dpi=200, bbox_inches="tight")
plt.close(fig)
We don't need these in a training script. They can be done later if we have the model and the predictions saved on disk.
Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com>
Hi @SarahAlidoost, thanks for the review! I implemented most of your comments.
I did not implement the training utility function and will leave it to #33. Can you give it another look?
# Initialize training
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SpatioTemporalModel(
Can you please use the same arguments for the model as those in the example notebook in the main branch?
decoder.scale.copy_(torch.from_numpy(std) + 1e-6)

# Make a dataloader
dataloader = DataLoader(
Can you please use the same arguments for the dataloader as those in the example notebook in the main branch?
patience = 10
counter = 0
model.train()
for epoch in range(101):
Can you please use the same training loop as in the example notebook in the main branch?
counter = 0

if epoch % 20 == 0:
    logger.info(f"The loss is {best_loss} at epoch {epoch}")
Where will the logger info be stored, in the SLURM log file? 🤔
self.config = {
    'in_chans': in_chans,
    'embed_dim': embed_dim,
    'patch_size': patch_size,
    'max_days': max_days,
    'max_months': max_months,
    'num_months': num_months,
    'hidden': hidden,
    'overlap': overlap,
    'max_H': max_H,
    'max_W': max_W,
    'spatial_depth': spatial_depth,
    'spatial_heads': spatial_heads,
}
SarahAlidoost
left a comment
@rogerkuou thanks for addressing the comments 👍 . Here are some more suggestions:
- I see that the example notebook has been changed in this PR. I cannot see exactly what is changed, but since this PR is about testing large data on HPC, let's not change the example notebook.
- No need to add an inference script in this PR; for now we can skip that one. Let's focus on the setup of the training on HPC in this PR. We can add the inference script later when fixing #32.
- After implementing these suggestions and re-running the SLURM job, can you please add the SLURM logfile to the PR as well? Also, can you perhaps give an indication of how many resources were used to complete the job?
If something is not clear, please let me know.
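One way to get that resource indication, assuming SLURM accounting is enabled on Levante, is `sacct` (a sketch; the job ID here is a placeholder, use the ID printed by `sbatch`):

```shell
# sacct reports per-job accounting once the job has finished; MaxRSS gives
# peak memory, Elapsed/TotalCPU give wall-clock and CPU time.
JOBID=123456  # placeholder job ID
CMD="sacct -j ${JOBID} --format=JobID,JobName,Elapsed,MaxRSS,TotalCPU,State"
echo "${CMD}"
```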
fix #25
Did not finish the train-validation-test split in this PR, but made a new issue #28.