
[DRAFT] Mh/jk/diffusion full pipeline forecast #2396

Draft
moritzhauschulz wants to merge 29 commits into ecmwf:jk/develop/diffusion-full-pipeline from moritzhauschulz:mh/jk/diffusion-full-pipeline-forecast

Conversation

@moritzhauschulz
Contributor

Description

DRAFT PR to assess diff between my conditioning branch and the current main diffusion branch.

Issue Number

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hedgedoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

github-actions (bot) added the "model" label (Related to model training or definition, not generic infra) on May 20, 2026
Contributor

@MatKbauer MatKbauer left a comment


Two small modifications, but looks good. Currently training some models to see convergence.

tokens, posteriors = self.encoder.encoder(model_params=model_params, batch=batch)
shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:])
tokens_multi = tokens.reshape(shape)
tokens = tokens_multi[:, -1]
Contributor

Let's revert this back again

Comment thread src/weathergen/model/model.py Outdated
# Reshape tokens to [B, T, ...]
tokens = tokens.reshape(shape)

if self.cf.get("fe_diffusion_model", False):
Contributor

To allow unconditional and non-forecast conditional training, this check should be

if self.cf.get("fe_diffusion_model_conditioning", None) == "forecast":
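A minimal sketch of what the suggested check changes, using a plain dict as a hypothetical stand-in for the run config (the real `cf` object is not shown in this thread, and the helper name and example configs are assumptions for illustration). With the old check, every diffusion run takes the forecast branch; with the suggested one, only runs that explicitly request "forecast" conditioning do.

```python
# Hypothetical stand-in: plain dicts instead of the project's config object.
def uses_forecast_conditioning(cf: dict) -> bool:
    """True only when the conditioning mode is explicitly set to "forecast"."""
    return cf.get("fe_diffusion_model_conditioning", None) == "forecast"


# Three illustrative (assumed) configurations of a diffusion run.
configs = {
    "unconditional": {"fe_diffusion_model": True},
    "date_time": {
        "fe_diffusion_model": True,
        "fe_diffusion_model_conditioning": "date_time",
    },
    "forecast": {
        "fe_diffusion_model": True,
        "fe_diffusion_model_conditioning": "forecast",
    },
}

# Old check: gates on the diffusion flag alone, so all three runs match.
old_check = {name: bool(cf.get("fe_diffusion_model", False)) for name, cf in configs.items()}

# Suggested check: only the explicit forecast configuration matches.
new_check = {name: uses_forecast_conditioning(cf) for name, cf in configs.items()}

print(old_check)
print(new_check)
```

Because `.get` falls back to `None` when the key is absent, a config that never mentions conditioning is simply treated as unconditional in this sketch.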

Contributor

@MatKbauer MatKbauer left a comment


Some more suggestions to make the code more robust.

self.streams = cf.streams
self.rank = cf.rank
self.world_size = cf.world_size
self.diffusion_model_conditioning = cf.fe_diffusion_model_conditioning
Contributor

cf.get("fe...", None)

embedding_dim=self.embedding_dim, frequency_embedding_dim=self.frequency_embedding_dim
)
self.datetime_embedder = DateTimeEncoder()
self.conditioning = self.cf.fe_diffusion_model_conditioning
Contributor

self.cf.get("fe_diffusion_model_conditioning", None)

if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]:
c = meta_info["ERA5"].params["timestamp"]
elif self.cf.fe_diffusion_model_conditioning == "forecast":
c = meta_info["ERA5"].params["conditioning_tokens"] # X_{t-1} as conditioning (model.py extracts last step as target, passes second-to-last here)
Contributor

self.conditioning in both places

# Extract conditioning from meta_info (same as training_forward)
# Extract conditioning (mirrors training_forward).
c = None
if self.cf.fe_diffusion_model_conditioning in ["date_time", "date", "time"]:
Contributor

self.conditioning
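The suggestions above amount to reading the conditioning mode once, with a safe default, and reusing the cached attribute wherever conditioning is extracted. A rough sketch of that pattern follows; the class name, the `select` method, and the plain-dict `cf` and `params` are hypothetical stand-ins, and only the config key, the mode strings, and the `timestamp` / `conditioning_tokens` keys come from the diff context.

```python
class ConditioningSelector:
    """Hypothetical helper: caches the conditioning mode from the config."""

    def __init__(self, cf: dict):
        # Read once with a default so configs without the key still work.
        self.conditioning = cf.get("fe_diffusion_model_conditioning", None)

    def select(self, params: dict):
        """Return the conditioning input for the cached mode, or None."""
        if self.conditioning in ("date_time", "date", "time"):
            return params["timestamp"]
        if self.conditioning == "forecast":
            # Previous-step tokens used as conditioning.
            return params["conditioning_tokens"]
        return None  # unconditional


# Assumed example payload standing in for meta_info["ERA5"].params.
params = {"timestamp": "2020-01-01T00:00", "conditioning_tokens": [0.1, 0.2]}

forecast_c = ConditioningSelector({"fe_diffusion_model_conditioning": "forecast"}).select(params)
time_c = ConditioningSelector({"fe_diffusion_model_conditioning": "time"}).select(params)
uncond_c = ConditioningSelector({}).select(params)
```

Routing both the training and the inference path through one method like `select` would keep the two branches from drifting apart, which is the concern behind using `self.conditioning` in both places.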

qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type),
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
is_dit=self.cf.fe_diffusion_model,
Contributor

Let's also make this a self.cf.get("fe_diffusion_model", False)


Labels

model Related to model training or definition (not generic infra)

2 participants