
Conversation

@sophie-xhonneux (Contributor)

Description

See PR #1590

Issue Number

Closes #1587

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and written the run_id(s) in a comment: launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@github-actions bot added the model label (Related to model training or definition (not generic infra)) on Jan 17, 2026
@clessig (Collaborator) left a comment:

trainer.py and trainer_base.py need to be cleaned up, please.

# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0

sslpred_num_blocks: 12
Collaborator:

I don't think these params should go into the model block; they should be part of the SSL loss term. This is really also JEPA-specific.

Contributor:

Sophie took care of this

self.pred_blocks = nn.ModuleList()

# first map to intermediate_dim to introduce a bottleneck
self.pred_blocks.append(nn.Linear(in_dim, intermediate_dim, bias=False))
Collaborator:

We should call this blocks in all modules.

Contributor:

done
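For context, a minimal self-contained sketch of the bottleneck pattern described in the snippet above: project down to intermediate_dim, run a stack of blocks at that width, then project back out. This is my reading of the intended shape, not the repo's actual code; nn.TransformerEncoderLayer stands in for the project's own block type, and the dimensions are placeholders.

import torch
import torch.nn as nn

in_dim, intermediate_dim, out_dim, num_blocks = 1024, 384, 1024, 12

blocks = nn.ModuleList()
# project down to the bottleneck dimension
blocks.append(nn.Linear(in_dim, intermediate_dim, bias=False))
# stack of blocks operating at the bottleneck width
for _ in range(num_blocks):
    blocks.append(
        nn.TransformerEncoderLayer(d_model=intermediate_dim, nhead=6, batch_first=True)
    )
# project back up to the output dimension
blocks.append(nn.Linear(intermediate_dim, out_dim, bias=False))

x = torch.randn(2, 16, in_dim)  # [Batch, Tokens, Dim]
for block in blocks:
    x = block(x)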

# we concatenate the patch and class tokens to process them together
# We concatenate in the token dimension [Batch, Tokens, Dim]
patch_class_tokens = []
if self.class_token:
Collaborator:

We should use x.class_token and x.patch_token here.

Contributor:

I renamed these to

self.use_class_token = use_class_token
self.use_patch_token = use_patch_token

to clarify that these are booleans. Now it looks like this in forward():

if self.use_class_token:
   patch_class_tokens.append(x.class_token)
if self.use_patch_token:
   patch_class_tokens.append(x.patch_tokens)
patch_class_tokens = torch.cat(patch_class_tokens, dim=1)

Collaborator:

Better, although I think it's still duplicated, and using x.class_token and x.patch_token is more robust.
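For illustration, a minimal sketch of that suggestion, assuming x exposes class_token and patch_tokens attributes that are None when absent (TokenSet here is a hypothetical container, not necessarily the repo's actual API):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TokenSet:
    # hypothetical stand-in for the x passed into forward()
    class_token: Optional[torch.Tensor] = None
    patch_tokens: Optional[torch.Tensor] = None


x = TokenSet(class_token=torch.randn(2, 1, 64), patch_tokens=torch.randn(2, 16, 64))

patch_class_tokens = []
if x.class_token is not None:
    patch_class_tokens.append(x.class_token)
if x.patch_tokens is not None:
    patch_class_tokens.append(x.patch_tokens)
# concatenate along the token dimension [Batch, Tokens, Dim]
patch_class_tokens = torch.cat(patch_class_tokens, dim=1)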

if isinstance(block, torch.nn.modules.normalization.LayerNorm):
    patch_class_tokens = block(patch_class_tokens)
else:
    patch_class_tokens = checkpoint(block, patch_class_tokens, use_reentrant=False)
Collaborator:

The checkpoint should be on a coarser level if possible.

Contributor:

I did not understand this comment :)

Collaborator:

The checkpoint should wrap the call to this forward function, not be in the forward function
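For illustration, a minimal self-contained sketch of that distinction: the caller wraps the whole forward call in checkpoint, while the module's forward() stays a plain loop over its blocks. Predictor here is a hypothetical stand-in, not the code under review.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Predictor(nn.Module):
    # hypothetical stand-in for the predictor module under review
    def __init__(self, dim: int, num_blocks: int):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # no checkpointing inside forward(); just run the blocks
        for block in self.blocks:
            x = block(x)
        return self.norm(x)


predictor = Predictor(dim=64, num_blocks=4)
tokens = torch.randn(2, 16, 64, requires_grad=True)  # [Batch, Tokens, Dim]
# the caller decides whether to checkpoint and wraps the whole call
out = checkpoint(predictor, tokens, use_reentrant=False)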


self.patch_token = patch_token
# For now this is a Linear Layer TBD what this architecture should be
self.layer = nn.Linear(in_dim, out_dim, bias=False)
self.layer = MLP(in_dim, out_dim, num_layers, hidden_factor)
Collaborator:

self.layer -> self.blocks

Contributor:

done

print("Happy to be here")
batch.to_device(self.device)

print("Batch to device")
Collaborator:

Remove

self.training_cfg.window_offset_prediction,
)

print("Model predictions")
Collaborator:

Remove

Contributor:

done

self.model,
self.training_cfg.window_offset_prediction,
)
print("target predictions")
Collaborator:

Remove

Contributor:

done

targets_and_aux=targets_and_auxs,
metadata=extract_batch_metadata(batch),
)
print("loss calcuclation")
Collaborator:

Remove

Contributor:

done.

Removed more print() statements.

return torch.device("cpu")

local_id_node = os.environ.get("SLURM_LOCALID", "-1")
local_id_node = os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "-1"))
Collaborator:

This should be torch.distributed.get_local_rank(), unless someone can explain why it's not suitable here.

Contributor:

I am not sure why this even changed...

Using dist.get_node_local_rank(fallback_rank=-1) instead.
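For reference, a minimal sketch of that change, assuming a PyTorch version where torch.distributed.get_node_local_rank is available:

import torch.distributed as dist

# prefer the torch.distributed helper over reading env vars directly;
# fall back to -1, as before, when no local-rank information is set
local_id_node = dist.get_node_local_rank(fallback_rank=-1)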


Development

Successfully merging this pull request may close these issues.

Implement transformer based predictors for JEPA
