Spin/Charge System Conditioning #1080
JonathanSchmidt1 wants to merge 40 commits into metatensor:main
Conversation
Adds an `extra_data_options` parameter to `MemmapDataset` that loads per-system scalar arrays from `.bin` files alongside the training targets. Each key (e.g. `mtt::charge`) maps to a `TensorMap` in the sample namedtuple and is forwarded to the `extra` argument of `CollateFn` callables. Also adds `get_extra_data_info()` to expose `TargetInfo` metadata for the extra_data keys, mirroring `get_target_info()`. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Drop stale system.add_data() calls for charge/spin from MemmapDataset.__getitem__
(data now flows through extra_data_options + get_system_conditioning_transform)
- Remove charge/spin fields from SystemsHypers in base_hypers.py
(config now lives under extra_data: {mtt::charge: {key: ...}})
…rade Old muon2 checkpoints already contain system_conditioning.* weights. Setting system_conditioning=False was dropping them silently. Now the upgrade checks for the presence of those weights and enables the hyper automatically, so converted checkpoints use the embedding they were trained with. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
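The auto-enable logic described in this commit can be sketched as follows. This is a minimal stand-in, not the real upgrade function: the checkpoint layout (`"model_state_dict"`, `"model_data"`, `"model_hypers"`) and the `system_conditioning.` key prefix are assumptions based on the surrounding discussion.

```python
# Hedged sketch: enable the system_conditioning hyper whenever the
# checkpoint's state dict already contains conditioning weights, so
# converted checkpoints keep the embedding they were trained with.
def upgrade_enables_conditioning(checkpoint: dict) -> dict:
    state_dict = checkpoint.get("model_state_dict", {})
    # Conditioning weights present => model was trained with the feature.
    if any(key.startswith("system_conditioning.") for key in state_dict):
        checkpoint["model_data"]["model_hypers"]["system_conditioning"] = True
    return checkpoint
```

The point of the check is that it keys off the weights actually stored, rather than trusting a hyper flag that older branches never set.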
…or per-atom extra data is never passed to the system transform
- system_data: raise if a TensorMap is per-atom (samples != ["system"]) rather than silently misindexing into the systems list
- model: validate that charge/spin are integer-valued before the .long() conversion; log.debug when a system falls back to the default 0/1
- conditioning: document the zero-init gate design intent
…add test that confirms eval is working with spin/charge
What's the policy now on how to update the classifier/llpr checkpoints?
I updated the checkpoints, so the only test that is still failing will require the metatomic part to be merged (we could also move that test to metatomic, but I think it's good to have it here so we notice directly if we mess up the interaction with metatomic)
pfebrer left a comment
The failing test, isn't it a test that should go to metatomic?
Very interesting 😊 thanks a lot for the implementation!
I left some comments. It would be nice to fix the use of "charge"/"spin" vs. "mtt::charge"/"mtt::spin" (those metatomic PRs should be merged first, I suppose? metatensor/metatomic#183 metatensor/metatomic#189) to make sure the tests verify actual inputs rather than the default values
    :param checkpoint: The checkpoint to update.
    """
    import logging
would be nice to have all imports at the top
@with_config(ConfigDict(extra="forbid", strict=True))
class SystemDataKeyHypers(TypedDict):
this doesn't seem to be used anywhere
    the range ``[1, max_spin]``.
    """

    required_data_keys: List[str] = ["mtt::charge", "mtt::spin"]
why mtt:: here? in all other places you use simply "charge" and "spin". that's why the CI job fails, I guess
    objects describing each per-system scalar array.
    """
    extra_data_info_dict: Dict[str, TargetInfo] = {}
    if not self.extra_data_config:
it seems we don't initialize self.extra_data_config anywhere in the DiskDataset
checkpoint["model_data"]["model_hypers"]["max_charge"] = 10
if "max_spin" not in checkpoint["model_data"]["model_hypers"]:
    checkpoint["model_data"]["model_hypers"]["max_spin"] = 10
# Rename edge_linear -> edge_embedder (muon2 branch used edge_linear)
is the muon2 branch merged? or why do we need to rename it?
self.system_conditioning(
    inputs["charge"],
    inputs["spin"],
    inputs["system_indices"],
)
we currently call self.system_conditioning(...) with the same inputs at every GNN layer => the output is the same => we can move it outside of the loop over GNN layers and then simply do:
output_node_embeddings = output_node_embeddings + cond_embedding
same for _residual_featurization_impl
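The hoisting being suggested can be illustrated with a plain-Python sketch. The layer and embedding functions below are scalar stand-ins, not the real PET modules: the conditioning embedding depends only on loop-invariant inputs, so it can be computed once before the layer loop.

```python
# Sketch of loop-invariant code motion: compute the conditioning
# embedding once, then add it to the node embeddings at each layer.
def run_gnn(layers, node_embeddings, cond_embedding_fn, cond_inputs):
    cond = cond_embedding_fn(cond_inputs)  # hoisted out of the layer loop
    for layer in layers:
        node_embeddings = layer(node_embeddings) + cond
    return node_embeddings

# Example with scalar stand-ins for embeddings:
layers = [lambda x: 2 * x, lambda x: x + 1]
result = run_gnn(layers, 1.0, lambda inp: inp["charge"] * 0.5, {"charge": 2})
```

Behaviour is unchanged relative to recomputing the embedding inside the loop; only the redundant work is removed.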
# the `extra` argument of CollateFn callables
extra_data_dict = {}
for key, arr in self.extra_data_arrays.items():
    is_per_atom = arr.shape[0] == self.na[-1]
looks fragile: it is possible to have a case where the total number of atoms equals the number of systems... we should check ["per_atom"] from self.extra_data_config here, which seems to be available
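A toy illustration of the reviewer's point (the config layout in the last two lines is an assumed schema, not verified against the PR): with three single-atom systems, the total atom count equals the system count, so array length alone cannot tell per-atom from per-system data.

```python
# Why shape-based per-atom detection is ambiguous:
atoms_per_system = [1, 1, 1]
n_systems = len(atoms_per_system)
total_atoms = sum(atoms_per_system)
ambiguous = total_atoms == n_systems  # the heuristic breaks here

# An explicit flag resolves it (assumed config schema):
extra_data_config = {"mtt::charge": {"per_atom": False}}
is_per_atom = extra_data_config["mtt::charge"]["per_atom"]
```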
# Extract per-system charge and spin for conditioning
if self.system_conditioning is not None:
i would move the extraction out into a separate function to keep the forward function readable!
def validate(self, charge: torch.Tensor, spin: torch.Tensor) -> None:
    """Check that charge and spin values are within the supported range.

    Call this outside of ``torch.compile`` regions to get descriptive errors.
validate is called inside the scripted forward; have you tried to call the exported (mtt export) model? i am just wondering whether the validate call survives export
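The kind of check under discussion can be sketched in plain Python (a stand-in for the tensor version, not the real method; bounds follow the ``[1, max_spin]`` convention quoted above, and the default maxima are illustrative). Whether such raises survive `mtt export` depends on TorchScript keeping the exception paths, which is what the reviewer is asking to verify.

```python
# Hedged stand-in for a validate() that raises descriptive errors
# for out-of-range conditioning inputs.
def validate(charge: int, spin: int,
             max_charge: int = 10, max_spin: int = 10) -> None:
    if abs(charge) > max_charge:
        raise ValueError(
            f"charge {charge} outside [-{max_charge}, {max_charge}]"
        )
    if not 1 <= spin <= max_spin:
        raise ValueError(f"spin {spin} outside [1, {max_spin}]")
```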
collate_fn_train = CollateFn(
    target_keys=list(train_targets.keys()),
    callables=[
        rotational_augmenter.apply_random_augmentations,
        get_system_with_neighbor_lists_transform(requested_neighbor_lists),
        *conditioning_callables,
        get_remove_additive_transform(additive_models, train_targets),
        get_remove_scale_transform(scaler),
    ],
    batch_atom_bounds=self.hypers["batch_atom_bounds"],
)
collate_fn_val = CollateFn(
    target_keys=list(train_targets.keys()),
    callables=[  # no augmentation for validation
        get_system_with_neighbor_lists_transform(requested_neighbor_lists),
        *conditioning_callables,
        get_remove_additive_transform(additive_models, train_targets),
        get_remove_scale_transform(scaler),
    ],
    batch_atom_bounds=self.hypers["batch_atom_bounds"],
)
sorry, it gives me the urge for a small refactor 😁
base_callables = [
    get_system_with_neighbor_lists_transform(requested_neighbor_lists),
    *conditioning_callables,
    get_remove_additive_transform(additive_models, train_targets),
    get_remove_scale_transform(scaler),
]
collate_fn_train = CollateFn(
    target_keys=list(train_targets.keys()),
    callables=[rotational_augmenter.apply_random_augmentations, *base_callables],
    batch_atom_bounds=self.hypers["batch_atom_bounds"],
)
collate_fn_val = CollateFn(
    target_keys=list(train_targets.keys()),
    callables=base_callables,
    batch_atom_bounds=self.hypers["batch_atom_bounds"],
)
…oning-rebased
# Conflicts:
#	pyproject.toml
#	src/metatrain/pet/trainer.py
#	tests/utils/data/test_dataset.py
…ntity Metatomic's standard per-system input name is `spin_multiplicity` (`spin` is only the short ASE info key on the calculator side). To make exported PET models pluggable into MetatomicCalculator without an extra prefix, rename throughout: `required_data_keys`, `system.get_data` reads, the hyperparameter `max_spin` → `max_spin_multiplicity`, the internal embedding attribute `spin_embedding` → `spin_multiplicity_embedding`, validate/forward parameter names, the test fixtures, and the MemmapDataset extra_data example. Also restore the rename of `required_data_keys` from `mtt::charge`/`mtt::spin` (which had been accidentally reverted) to the unprefixed standard names. Bump the model checkpoint version to 13 and add `model_update_v12_v13` that renames the `max_spin` hyperparameter and the `system_conditioning.spin_embedding.*` state-dict keys so existing v12 checkpoints continue to load. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
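The v12 → v13 upgrade this commit describes can be sketched as below. This is a hedged approximation: the checkpoint layout (`"model_data"`, `"model_state_dict"`) and the version field name are assumptions, and only the two renames named in the commit message are modeled.

```python
# Sketch of model_update_v12_v13: rename the max_spin hyper and the
# spin_embedding state-dict keys to their spin_multiplicity equivalents.
def model_update_v12_v13(checkpoint: dict) -> dict:
    hypers = checkpoint["model_data"]["model_hypers"]
    if "max_spin" in hypers:
        hypers["max_spin_multiplicity"] = hypers.pop("max_spin")
    state_dict = checkpoint["model_state_dict"]
    for old_key in list(state_dict):
        if ".spin_embedding." in old_key:
            new_key = old_key.replace(
                ".spin_embedding.", ".spin_multiplicity_embedding."
            )
            state_dict[new_key] = state_dict.pop(old_key)
    checkpoint["model_ckpt_version"] = 13  # assumed field name
    return checkpoint
```

Renaming keys in place (rather than rebuilding the state dict) keeps all unrelated weights untouched, which is what lets existing v12 checkpoints continue to load.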
Let's take a look at this only once charge/spin are fully merged into metatomic
- checkpoints.py: hoist `import logging` to module top.
- share/base_hypers.py: drop unused `SystemDataKeyHypers` TypedDict.
- utils/data/dataset.py:
  * initialize `self.extra_data_config = {}` in `DiskDataset.__init__`
    so `get_extra_data_info()` no longer raises an AttributeError when
    called on a disk dataset that did not set it externally.
* inside `MemmapDataset.__getitem__`, replace the fragile
`arr.shape[0] == self.na[-1]` per-atom heuristic with the explicit
`self.extra_data_config[key]["per_atom"]` flag, which is unambiguous
when n_atoms == n_systems.
- pet/model.py: hoist `system_conditioning(...)` out of both featurization
GNN loops. Inputs (charge, spin_multiplicity, system_indices) are
loop-invariant, so compute the embedding once and add it inside the loop.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Round 2 of metatensor#1080 review (sofiia-chorna): - pet/model.py: extract per-system charge/spin_multiplicity reading from `forward()` into a module-level helper `_extract_charge_spin_multiplicity` so the core forward stays readable. - pet/trainer.py: factor out `base_callables` shared between the train and validation `CollateFn`s. Train just prepends `rotational_augmenter` to the base list. - pet/tests/test_conditioning.py: add `test_export_with_conditioning_preserves_validate` to regression-test that the in-forward `validate(...)` call survives TorchScript compilation in `model.export()`. Drives the saved model end-to-end via `MetatomicCalculator` and checks both the valid path and the out-of-range error path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
System conditioning for PET (charge & spin)
Adds per-system charge and spin multiplicity conditioning to the PET architecture, allowing a single model to be trained and evaluated across multiple charge and spin states. The feature is activated through `architecture.model.system_conditioning: true`. System conditioning is separated into its own `SystemConditioningEmbedding` module (`pet/modules/conditioning.py`). The resulting embedding is added to PET's node features via a zero-initialised gated projection, so the model starts as the unconditioned baseline and learns to use charge/spin information only as needed.

Charge and spin are supplied as `mtt::charge` (integer, elementary charges) and `mtt::spin` (integer, spin multiplicity 2S+1) in the `extra_data` section of the dataset config, or via `atoms.info` in ASE (requires merging of a PR into metatrain).

Changes

New files
- `src/metatrain/pet/modules/conditioning.py` — `SystemConditioningEmbedding` module and `get_system_conditioning_transform` (re-exported from `utils/system_data.py`)
- `src/metatrain/utils/system_data.py` — generic `get_system_data_transform` callable for attaching per-system scalar TensorMaps to `System` objects in a `CollateFn`
- `src/metatrain/pet/tests/test_conditioning.py` — test suite for the feature

PET model (`pet/model.py`)
- new `system_conditioning` hyper; if enabled, builds `SystemConditioningEmbedding` and injects the embedding into node features during both initial featurisation and residual updates
- lists `mtt::charge`/`mtt::spin` in `requested_inputs()` so the exported model communicates its requirements to downstream tools (ASE calculator, eval pipeline)

Training (`pet/trainer.py`)
- reads `model.system_conditioning.required_data_keys` and registers `get_system_conditioning_transform` as a `CollateFn` callable so charge/spin are attached to `System` objects during training

Checkpoint upgrade (`pet/checkpoints.py`)
- detects `system_conditioning.*` weights in the state dict to auto-enable the hyper for checkpoints trained with conditioning (avoids the silent neutral-singlet fallback when loading old muon-branch checkpoints)

Eval (`cli/eval.py` + `utils/system_data.py`)
- `mtt eval` now reads `extra_data` from the dataset config and routes any keys present in the model's `requested_inputs()` through `get_system_data_transform`, so charge/spin reach the model during evaluation, preventing silent index errors on mixed datasets

Hypers (`pet/documentation.py`, `share/base_hypers.py`)
- new `system_conditioning` block in `ModelHypers`: `system_conditioning: bool`, `max_charge: int = 10`, `max_spin: int = 10`

Training config example
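The example itself did not survive this scrape; a minimal sketch of what the `extra_data` section might look like, assuming the key names from this PR (`mtt::charge`, `mtt::spin`) and the `extra_data: {mtt::charge: {key: ...}}` layout mentioned in the commit messages (other fields and the overall layout are illustrative, not verified):

```yaml
training_set:
  systems: "train.xyz"
  targets:
    energy:
      key: "energy"
  extra_data:
    mtt::charge:
      key: "charge"  # integer total charge per system
    mtt::spin:
      key: "spin"    # integer spin multiplicity (2S + 1)
```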
Contributor (creator of pull-request) checklist
Maintainer/Reviewer checklist
📚 Documentation preview 📚: https://metatrain--1080.org.readthedocs.build/en/1080/