Sophiex/kerem/pr/transformer head #1649
base: develop
Conversation
clessig left a comment
trainer.py and trainer_base.py need to be cleaned up, please.
| # currently fixed to 1.0 (due to limitations with flex_attention and triton)
| forecast_att_dense_rate: 1.0
| sslpred_num_blocks: 12
I don't think these params should go into the model block; they should be part of the SSL loss term. This is really also JEPA-specific.
Sophie took care of this
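Concretely, one possible arrangement (a sketch with hypothetical names, not the actual config layout) keeps the predictor hyper-parameters with the JEPA/SSL loss term instead of the shared model block:

from dataclasses import dataclass

@dataclass
class SSLPredictorConfig:
    # formerly model.sslpred_num_blocks; illustrative values only
    num_blocks: int = 12
    intermediate_dim: int = 512

class JEPALossTerm:
    def __init__(self, pred_cfg: SSLPredictorConfig):
        # the predictor is built from the loss-term config, keeping
        # JEPA-specific settings out of the generic model block
        self.pred_cfg = pred_cfg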
| self.pred_blocks = nn.ModuleList()
| # first map to intermediate_dim to introduce a bottleneck
| self.pred_blocks.append(nn.Linear(in_dim, intermediate_dim, bias=False))
We should call this blocks in all modules.
done
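For reference, a minimal sketch of the renamed container, assuming plain Linear blocks and illustrative names and dimensions (the real block type may differ):

import torch.nn as nn

class SSLPredictionHead(nn.Module):
    # sketch only: self.blocks replaces self.pred_blocks, with the same
    # bottleneck structure as in the diff above
    def __init__(self, in_dim, intermediate_dim, out_dim, num_blocks):
        super().__init__()
        self.blocks = nn.ModuleList()
        # first map to intermediate_dim to introduce a bottleneck
        self.blocks.append(nn.Linear(in_dim, intermediate_dim, bias=False))
        for _ in range(num_blocks):
            self.blocks.append(nn.Linear(intermediate_dim, intermediate_dim, bias=False))
        self.blocks.append(nn.Linear(intermediate_dim, out_dim, bias=False))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x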
| # we concatenate the patch and class tokens to process them together
| # We concatenate in the token dimension [Batch, Tokens, Dim]
| patch_class_tokens = []
| if self.class_token:
We should use x.class_token and x.patch_token here.
I renamed these to
self.use_class_token = use_class_token
self.use_patch_token = use_patch_token
to clarify that these are booleans. Now it looks like this in the forward():
if self.use_class_token:
    patch_class_tokens.append(x.class_token)
if self.use_patch_token:
    patch_class_tokens.append(x.patch_tokens)
patch_class_tokens = torch.cat(patch_class_tokens, dim=1)
Better, although I think it's still duplicated, and using x.class_token and x.patch_token is more robust.
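A hedged sketch of that alternative, assuming x exposes optional class_token / patch_tokens attributes, so the concatenation is driven by what is actually present rather than by separate flags:

import torch

def gather_tokens(x):
    tokens = []
    if getattr(x, "class_token", None) is not None:
        tokens.append(x.class_token)   # [Batch, 1, Dim]
    if getattr(x, "patch_tokens", None) is not None:
        tokens.append(x.patch_tokens)  # [Batch, Tokens, Dim]
    # concatenate in the token dimension [Batch, Tokens, Dim]
    return torch.cat(tokens, dim=1)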
| if isinstance(block, torch.nn.modules.normalization.LayerNorm):
|     patch_class_tokens = block(patch_class_tokens)
| else:
|     patch_class_tokens = checkpoint(block, patch_class_tokens, use_reentrant=False)
The checkpoint should be on a coarser level if possible.
I did not understand this comment :)
The checkpoint should wrap the call to this forward function, not be in the forward function
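What "coarser level" means here, as a sketch with illustrative names: the caller wraps the whole head in a single checkpoint, and the head's own forward stays checkpoint-free.

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ModelWithHead(nn.Module):
    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head  # e.g. the transformer head; no checkpoint calls inside

    def forward(self, x):
        x = self.backbone(x)
        # one checkpoint around the entire head forward instead of
        # per-block checkpoints inside the head itself
        return checkpoint(self.head, x, use_reentrant=False)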
| self.patch_token = patch_token
| # For now this is a Linear Layer TBD what this architecture should be
| self.layer = nn.Linear(in_dim, out_dim, bias=False)
| self.layer = MLP(in_dim, out_dim, num_layers, hidden_factor)
self.layer -> self.blocks
done
| print("Happy to be here") | ||
| batch.to_device(self.device) | ||
|
|
||
| print("Batch to device") |
Remove
| self.training_cfg.window_offset_prediction,
| )
| print("Model predictions")
Remove
done
| self.model,
| self.training_cfg.window_offset_prediction,
| )
| print("target predictions")
Remove
done
| targets_and_aux=targets_and_auxs,
| metadata=extract_batch_metadata(batch),
| )
| print("loss calcuclation")
Remove
done.
removed more print() statements
| return torch.device("cpu")
| local_id_node = os.environ.get("SLURM_LOCALID", "-1")
| local_id_node = os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "-1"))
This should be torch.distributed.get_local_rank(), unless someone can explain why it's not suitable here.
I am not sure why this even changed...
Using dist.get_node_local_rank(fallback_rank=-1) instead.
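For context, a sketch of the two variants discussed, assuming a recent PyTorch that provides torch.distributed.get_node_local_rank (function and wrapper names here are illustrative):

import os
import torch.distributed as dist

def local_rank() -> int:
    if hasattr(dist, "get_node_local_rank"):
        # preferred: ask torch.distributed, falling back to -1 when no
        # launcher has set the environment
        return dist.get_node_local_rank(fallback_rank=-1)
    # older PyTorch: read the launcher environment directly
    return int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "-1")))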
Description
See PR #1590
Issue Number
Closes #1587
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60