[2370][model] Make the EncoderModule self-contained#2372
Conversation
clessig
left a comment
There was a problem hiding this comment.
Thanks for the implementation. I had a very quick look and I think we need separate ROPE coords also for the forecasting now
| if without_grad: | ||
| # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): | ||
| tokens = self.forecast_engine(tokens, step, model_params.rope_coords) | ||
| tokens = self.forecast_engine(tokens, step, self.encoder.rope_coords) |
There was a problem hiding this comment.
If we have multiple healpix levels on multiple decoders, this won't work. I think we need to have separate rope coords for the forecast engine in it.
There was a problem hiding this comment.
Thanks for the review. Can you explain why we need seperate rope coords for the forecast engine. At the end, same long/lat points should have the same rope_coords or am I mixing something up?
There was a problem hiding this comment.
We could have healpix level 5 and healpix level 7 for the input and combine them on a healpix level 6 latent grid. The forecast engine would operate on this one.
12e124c to
a95bc9e
Compare
e5fddd5 to
585fe84
Compare
|
@Tewson1 : thanks for the hard work. Could you summarize which cases you have tested? |
…n vectors/buffers like pe_embed
I haven't really started testing. I have launched a job (32 epochs training, 16 epochs finetuning and also inference) to see if the training and model performance behaves the same as before (I also created this hedgedoc: https://gitlab.jsc.fz-juelich.de/hedgedoc/kSSjIW6gQ9OWL_c8ByWTJQ). The job failed last evening because of a bug in trainer.py. We have now buffers/plain vectors in What are the other approaches to testing ? Thanks! |
|
We should test with ERA5 and multiple datasets. We also need to check with SSL and latent diffusion but maybe thinking about these cases is enough. |
@clessig Thanks for the response! Do you think that 32 epochs of training and 16 epochs of finetuning is enough.
I was looking through the code and it seems that training_mode doesn't affect the encoder. So I am also not sure either. Otherwise I am considering a run where I use |
32+16 is definitely enough. I think if it works mechanically then also convergence shouldn't work. But it would be important to test with multiple datasets and pretraining + finetuning (mainly to check model loading) and inference. Also run |
Description
Moves all HealPix-level dependent postitional encoding parameters out of
ModelParamsand intoEncoderModuledirectly, making theEncoderModuleself-contained. Previouslype_global,q_cells_lens,rope_coords, andrope_cell_coordslived in a sharedModelParamsobject that had to be passed intoEncoderModule.forwardat every call. This made it impossible to instantiate multiple encoders at different HEALPix resolutionsIssue Number
Closes #2370
CC: @clessig
Checklist before asking for review
./scripts/actions.sh lint./scripts/actions.sh unit-test./scripts/actions.sh integration-testlaunch-slurm.py --time 60