
Spin/Charge System Conditioning #1080

Open

JonathanSchmidt1 wants to merge 40 commits into metatensor:main from JonathanSchmidt1:only-system-conditioning-rebased

Conversation

Contributor

@JonathanSchmidt1 JonathanSchmidt1 commented Mar 22, 2026

System conditioning for PET (charge & spin)

Adds per-system charge and spin multiplicity conditioning to the PET architecture,
allowing a single model to be trained and evaluated across multiple charge and spin
states. The feature is activated through architecture.model.system_conditioning: true.
System conditioning is implemented in its own SystemConditioningEmbedding module
(pet/modules/conditioning.py). The resulting embedding is added to PET's node
features via a zero-initialised gated projection, so the model starts as the
unconditioned baseline and learns to use charge/spin information only as needed.
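
A minimal sketch of the zero-initialised gate idea (module and attribute names here are illustrative, not the actual conditioning.py code):

import torch
from torch import nn

class GatedConditioningSketch(nn.Module):
    """Illustrative only: a zero-initialised gate on the conditioning path."""

    def __init__(self, cond_dim: int, node_dim: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, node_dim)
        # The gate starts at zero, so the conditioning term contributes nothing
        # at initialisation and the model behaves like the unconditioned baseline.
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, node_features: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
        return node_features + self.gate * self.proj(cond_embedding)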

Charge and spin are supplied as mtt::charge (integer, elementary charges) and
mtt::spin (integer, spin multiplicity 2S+1) in the extra_data section of the
dataset config or via atoms.info in ASE (requires merging of a PR into metatrain).

Changes

New files

  • src/metatrain/pet/modules/conditioning.py — SystemConditioningEmbedding module
    and get_system_conditioning_transform (re-exported from utils/system_data.py)
  • src/metatrain/utils/system_data.py — generic get_system_data_transform callable
    for attaching per-system scalar TensorMaps to System objects in a CollateFn
    (a rough sketch of the idea follows after this list)
  • src/metatrain/pet/tests/test_conditioning.py — test suite for the feature
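
As a rough sketch of the transform idea (the real get_system_data_transform signature and the exact metatensor/metatomic calls may differ; in particular, System.add_data accepting a TensorMap is an assumption here):

from typing import List

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import System


def attach_per_system_scalar(systems: List[System], name: str, values: torch.Tensor) -> List[System]:
    """Hypothetical helper: attach one scalar per system as a single-block TensorMap."""
    for i, system in enumerate(systems):
        block = TensorBlock(
            values=values[i].reshape(1, 1).to(torch.float64),
            samples=Labels("system", torch.tensor([[i]], dtype=torch.int32)),
            components=[],
            properties=Labels(name, torch.tensor([[0]], dtype=torch.int32)),
        )
        # Assumes System.add_data accepts a TensorMap for custom per-system data.
        system.add_data(name, TensorMap(Labels("_", torch.tensor([[0]], dtype=torch.int32)), [block]))
    return systems

A CollateFn callable would then call something like attach_per_system_scalar(systems, "charge", charges) for each requested key before the systems reach the model.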

PET model (pet/model.py)

  • Adds system_conditioning hyper; if enabled, builds SystemConditioningEmbedding
    and injects the embedding into node features during both initial featurisation and
    residual updates
  • Declares mtt::charge / mtt::spin in requested_inputs() so the exported model
    communicates its requirements to downstream tools (ASE calculator, eval pipeline)

Training (pet/trainer.py)

  • Reads model.system_conditioning.required_data_keys and registers
    get_system_conditioning_transform as a CollateFn callable so charge/spin are
    attached to System objects during training

Checkpoint upgrade (pet/checkpoints.py)

  • v11 → v12 upgrade detects the presence of system_conditioning.* weights in the
    state dict to auto-enable the hyper for checkpoints trained with conditioning
    (avoids silent neutral-singlet fallback when loading old muon-branch checkpoints)
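
In spirit, the detection works roughly like the following sketch (the "model_state_dict" key is an assumption about the checkpoint layout; the hyper and weight names come from this PR):

def enable_conditioning_if_weights_present(checkpoint: dict) -> None:
    # Hypothetical helper illustrating the v11 -> v12 detection idea only.
    state_dict = checkpoint["model_state_dict"]  # assumed key name
    hypers = checkpoint["model_data"]["model_hypers"]
    if any(key.startswith("system_conditioning.") for key in state_dict):
        # Weights from a conditioned (muon-branch) model: turn the hyper on
        # instead of silently dropping the weights.
        hypers["system_conditioning"] = True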

Eval (cli/eval.py + utils/system_data.py)

  • mtt eval now reads extra_data from the dataset config and routes any keys
    present in the model's requested_inputs() through get_system_data_transform,
    so charge/spin reach the model during evaluation
  • The transform raises early if a TensorMap is per-atom rather than per-system,
    preventing silent index errors on mixed datasets
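
The per-system check amounts to something like this sketch (hypothetical function name; only the samples layout is taken from the PR):

from metatensor.torch import TensorMap


def ensure_per_system(key: str, tensor_map: TensorMap) -> None:
    # Hypothetical guard: every block must be indexed by "system" only,
    # otherwise indexing into the systems list would silently misalign.
    for block in tensor_map.blocks():
        if list(block.samples.names) != ["system"]:
            raise ValueError(
                f"extra data '{key}' must be per-system (samples ['system']), "
                f"got samples {list(block.samples.names)}"
            )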

Hypers (pet/documentation.py, share/base_hypers.py)

  • New system_conditioning block in ModelHypers:
    system_conditioning: bool, max_charge: int = 10, max_spin: int = 10

Training config example

architecture:
  model:
    system_conditioning: true
    max_charge: 5   # embeds charges in [-5, +5]
    max_spin: 5     # embeds multiplicities in [1, 5]

training_set:
  - path: dataset.mtt
    extra_data:
      mtt::charge:
        field: charge
      mtt::spin:
        field: spin
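
For the ASE route, the idea is to set the values in atoms.info before writing the dataset; the exact info keys depend on the companion metatrain/metatomic PRs, so the ones below are an assumption:

from ase.build import molecule

atoms = molecule("O2")
atoms.info["charge"] = 0  # total charge in elementary charges (assumed key name)
atoms.info["spin"] = 3    # spin multiplicity 2S+1, triplet O2 (assumed key name)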

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Maintainer/Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?
  • GPU tests passed (maintainer comment: "cscs-ci run")?

📚 Documentation preview 📚: https://metatrain--1080.org.readthedocs.build/en/1080/

JonathanSchmidt1 and others added 26 commits March 19, 2026 13:36
Adds an `extra_data_options` parameter to `MemmapDataset` that loads
per-system scalar arrays from `.bin` files alongside the training targets.
Each key (e.g. `mtt::charge`) maps to a `TensorMap` in the sample namedtuple
and is forwarded to the `extra` argument of `CollateFn` callables.

Also adds `get_extra_data_info()` to expose `TargetInfo` metadata for the
extra_data keys, mirroring `get_target_info()`.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
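
As a rough illustration of the per-system .bin layout described above (the actual file format, dtype, and option names used by MemmapDataset may differ):

import numpy as np

# Hypothetical layout: one integer per system, stored in system order.
charges = np.memmap("charge.bin", dtype=np.int32, mode="r")
charge_of_system_7 = int(charges[7])
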
- Drop stale system.add_data() calls for charge/spin from MemmapDataset.__getitem__
  (data now flows through extra_data_options + get_system_conditioning_transform)
- Remove charge/spin fields from SystemsHypers in base_hypers.py
  (config now lives under extra_data: {mtt::charge: {key: ...}})

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…rade

Old muon2 checkpoints already contain system_conditioning.* weights.
Setting system_conditioning=False was dropping them silently. Now the
upgrade checks for the presence of those weights and enables the hyper
automatically, so converted checkpoints use the embedding they were
trained with.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
  or per-atom extra data is never passed to the system transform
- system_data: raise if a TensorMap is per-atom (samples != ["system"])
  rather than silently misindexing into the systems list
- model: validate charge/spin are integer-valued before .long()
  conversion; log.debug when a system falls back to default 0/1
- conditioning: document zero-init gate design intent
…add test that confirms eval is working with spin/charge
@JonathanSchmidt1
Contributor Author

What's the policy now on how to update the classifier/llpr checkpoints?

@JonathanSchmidt1
Contributor Author

I updated the checkpoints, so the only test that is still failing is the one that requires the metatomic part to be merged (we could also move that test to metatomic, but I think it's good to have it here so we notice right away if we mess up the interaction with metatomic).

Contributor

@pfebrer pfebrer left a comment


The failing test, isn't it a test that should go to metatomic?

Comment thread src/metatrain/cli/eval.py Outdated
Collaborator

@sofiia-chorna sofiia-chorna left a comment


very interesting 😊 thanks a lot for the implementation!

I left some comments. It would be nice to fix the mixed use of "charge"/"spin" and "mtt::charge"/"mtt::spin" (those metatomic PRs should be merged first, I suppose? metatensor/metatomic#183, metatensor/metatomic#189), and to make sure the tests verify the actual inputs rather than the default values.

Comment thread src/metatrain/pet/checkpoints.py Outdated

:param checkpoint: The checkpoint to update.
"""
import logging
Collaborator


would be nice to have all imports at the top

Comment thread src/metatrain/share/base_hypers.py Outdated


@with_config(ConfigDict(extra="forbid", strict=True))
class SystemDataKeyHypers(TypedDict):
Collaborator


seems to be not used anywhere

the range ``[1, max_spin]``.
"""

required_data_keys: List[str] = ["mtt::charge", "mtt::spin"]
Collaborator


why mtt:: here? in all other places you use simply "charge" and "spin". that's why the CI job fails, I guess

objects describing each per-system scalar array.
"""
extra_data_info_dict: Dict[str, TargetInfo] = {}
if not self.extra_data_config:
Collaborator


it seems we don't initialize self.extra_data_config anywhere in the DiskDataset

checkpoint["model_data"]["model_hypers"]["max_charge"] = 10
if "max_spin" not in checkpoint["model_data"]["model_hypers"]:
checkpoint["model_data"]["model_hypers"]["max_spin"] = 10
# Rename edge_linear -> edge_embedder (muon2 branch used edge_linear)
Collaborator


is the muon2 branch merged? or why do we need to rename it?

Comment thread src/metatrain/pet/model.py Outdated
Comment on lines +684 to +688
self.system_conditioning(
    inputs["charge"],
    inputs["spin"],
    inputs["system_indices"],
)
Collaborator


we currently call self.system_conditioning(...) with the same inputs at every gnn layer => output is the same => we can move it outside of the loop over gnn layers and then simply do:

output_node_embeddings = output_node_embeddings + cond_embedding

same for _residual_featurization_impl

Comment thread src/metatrain/utils/data/dataset.py Outdated
# the `extra` argument of CollateFn callables
extra_data_dict = {}
for key, arr in self.extra_data_arrays.items():
    is_per_atom = arr.shape[0] == self.na[-1]
Collaborator


looks fragile, since it is possible to have a case where the total number of atoms equals the number of systems... we should check ["per_atom"] from self.extra_data_config here; it seems to be available

Comment thread src/metatrain/pet/model.py Outdated
Comment on lines +466 to +467
# Extract per-system charge and spin for conditioning
if self.system_conditioning is not None:
Collaborator


I would move the extraction out into a separate function to keep the forward function readable!

def validate(self, charge: torch.Tensor, spin: torch.Tensor) -> None:
    """Check that charge and spin values are within the supported range.

    Call this outside of ``torch.compile`` regions to get descriptive errors.
Collaborator


validate is called inside the scripted forward; have you tried to call the exported (mtt export) model? I am just wondering, to make sure the validate call survives export.
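
As a generic sanity check of the TorchScript side of this question (a toy module, unrelated to PET's actual validate), a raise inside a scripted forward does survive scripting:

import torch
from torch import nn


class ValidatingToy(nn.Module):
    def forward(self, charge: torch.Tensor) -> torch.Tensor:
        # Range check inside forward; kept simple so TorchScript can compile it.
        if bool(torch.any(torch.abs(charge) > 10)):
            raise ValueError("charge out of supported range")
        return charge


scripted = torch.jit.script(ValidatingToy())
scripted(torch.tensor([3]))      # passes
# scripted(torch.tensor([42]))   # the raise fires from the scripted code

Whether the same holds through mtt export and MetatomicCalculator is exactly what the regression test added later in this PR checks.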

Comment on lines 235 to 255
collate_fn_train = CollateFn(
    target_keys=list(train_targets.keys()),
    callables=[
        rotational_augmenter.apply_random_augmentations,
        get_system_with_neighbor_lists_transform(requested_neighbor_lists),
        *conditioning_callables,
        get_remove_additive_transform(additive_models, train_targets),
        get_remove_scale_transform(scaler),
    ],
    batch_atom_bounds=self.hypers["batch_atom_bounds"],
)
collate_fn_val = CollateFn(
    target_keys=list(train_targets.keys()),
    callables=[  # no augmentation for validation
        get_system_with_neighbor_lists_transform(requested_neighbor_lists),
        *conditioning_callables,
        get_remove_additive_transform(additive_models, train_targets),
        get_remove_scale_transform(scaler),
    ],
    batch_atom_bounds=self.hypers["batch_atom_bounds"],
)
Collaborator


sorry, it gives me the urge for a small refactor 😁

Suggested change
base_callables = [
    get_system_with_neighbor_lists_transform(requested_neighbor_lists),
    *conditioning_callables,
    get_remove_additive_transform(additive_models, train_targets),
    get_remove_scale_transform(scaler),
]
collate_fn_train = CollateFn(
    target_keys=list(train_targets.keys()),
    callables=[rotational_augmenter.apply_random_augmentations, *base_callables],
    batch_atom_bounds=self.hypers["batch_atom_bounds"],
)
collate_fn_val = CollateFn(
    target_keys=list(train_targets.keys()),
    callables=base_callables,
    batch_atom_bounds=self.hypers["batch_atom_bounds"],
)

JonathanSchmidt1 and others added 2 commits April 28, 2026 11:37
…oning-rebased

# Conflicts:
#	pyproject.toml
#	src/metatrain/pet/trainer.py
#	tests/utils/data/test_dataset.py
…ntity

Metatomic's standard per-system input name is `spin_multiplicity` (`spin`
is only the short ASE info key on the calculator side). To make exported
PET models pluggable into MetatomicCalculator without an extra prefix,
rename throughout: `required_data_keys`, `system.get_data` reads, the
hyperparameter `max_spin` → `max_spin_multiplicity`, the internal embedding
attribute `spin_embedding` → `spin_multiplicity_embedding`, validate/forward
parameter names, the test fixtures, and the MemmapDataset extra_data
example. Also restore the rename of `required_data_keys` from
`mtt::charge`/`mtt::spin` (which had been accidentally reverted) to the
unprefixed standard names.

Bump the model checkpoint version to 13 and add `model_update_v12_v13`
that renames the `max_spin` hyperparameter and the
`system_conditioning.spin_embedding.*` state-dict keys so existing
v12 checkpoints continue to load.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
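
The rename described above boils down to something like this sketch (the "model_state_dict" key is an assumed checkpoint field; the hyper and state-dict names are taken from the commit message):

def rename_spin_to_spin_multiplicity(checkpoint: dict) -> None:
    # Hypothetical helper illustrating the v12 -> v13 upgrade idea only.
    hypers = checkpoint["model_data"]["model_hypers"]
    if "max_spin" in hypers:
        hypers["max_spin_multiplicity"] = hypers.pop("max_spin")

    state_dict = checkpoint["model_state_dict"]  # assumed key name
    for old_key in list(state_dict):
        if "system_conditioning.spin_embedding." in old_key:
            new_key = old_key.replace(
                "system_conditioning.spin_embedding.",
                "system_conditioning.spin_multiplicity_embedding.",
            )
            state_dict[new_key] = state_dict.pop(old_key)
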
@JonathanSchmidt1
Contributor Author

Let's take a look at this only once charge/spin are fully merged into metatomic

JonathanSchmidt1 and others added 2 commits April 28, 2026 12:08
- checkpoints.py: hoist `import logging` to module top.
- share/base_hypers.py: drop unused `SystemDataKeyHypers` TypedDict.
- utils/data/dataset.py:
    * initialize `self.extra_data_config = {}` in `DiskDataset.__init__`
      so `get_extra_data_info()` no longer raises AttributeError when called
      on a disk dataset that did not set it externally.
    * inside `MemmapDataset.__getitem__`, replace the fragile
      `arr.shape[0] == self.na[-1]` per-atom heuristic with the explicit
      `self.extra_data_config[key]["per_atom"]` flag, which is unambiguous
      when n_atoms == n_systems.
- pet/model.py: hoist `system_conditioning(...)` out of both featurization
  GNN loops. Inputs (charge, spin_multiplicity, system_indices) are
  loop-invariant, so compute the embedding once and add it inside the loop.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Round 2 of metatensor#1080 review (sofiia-chorna):

- pet/model.py: extract per-system charge/spin_multiplicity reading from
  `forward()` into a module-level helper `_extract_charge_spin_multiplicity`
  so the core forward stays readable.
- pet/trainer.py: factor out `base_callables` shared between the train and
  validation `CollateFn`s. Train just prepends `rotational_augmenter` to the
  base list.
- pet/tests/test_conditioning.py: add `test_export_with_conditioning_preserves_validate`
  to regression-test that the in-forward `validate(...)` call survives
  TorchScript compilation in `model.export()`. Drives the saved model end-to-end
  via `MetatomicCalculator` and checks both the valid path and the
  out-of-range error path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>