7 changes: 7 additions & 0 deletions .gitignore
@@ -73,3 +73,10 @@ frozen_model.*

# Test system directories
system/

temp/
test_mptraj/
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Ignoring test_mptraj/ conflicts with tracked files in this PR.

This PR commits test_mptraj/lmdb_baseline.json and test_mptraj/lmdb_mixed_batch.json under test_mptraj/, but the new test_mptraj/ ignore rule will hide any future additions or edits in that directory and force git add -f. Either drop the directory-level ignore or add explicit negations for the tracked configs.

🛠️ Suggested fix
 temp/
-test_mptraj/
+test_mptraj/*
+!test_mptraj/*.json
+!test_mptraj/.gitignore
 pkl/
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @.gitignore at line 78, the entry "test_mptraj/" conflicts with tracked files committed in this PR (test_mptraj/lmdb_baseline.json and test_mptraj/lmdb_mixed_batch.json). Fix this by either removing the directory-level ignore line "test_mptraj/" or adding explicit negation patterns for the tracked files (e.g., "!test_mptraj/lmdb_baseline.json" and "!test_mptraj/lmdb_mixed_batch.json") so future edits to those config files are not hidden and git add does not require -f; update .gitignore accordingly and ensure the two committed files remain tracked.
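
To verify which ignore rule actually applies after editing .gitignore, a quick, hypothetical check is to run git check-ignore -v (a standard git subcommand) against the two tracked JSON paths named in this comment; the minimal script below is only a sketch of that check, not part of the PR.

```python
import subprocess

# Hypothetical check: ask git which ignore rule, if any, matches each tracked file.
# `git check-ignore -v` prints the matching pattern and its source line for ignored paths.
for path in ("test_mptraj/lmdb_baseline.json", "test_mptraj/lmdb_mixed_batch.json"):
    result = subprocess.run(
        ["git", "check-ignore", "-v", path],
        capture_output=True,
        text=True,
    )
    rule = result.stdout.strip() or "no ignore rule reported"
    print(f"{path}: {rule}")
```

With the directory-level test_mptraj/ rule in place, both paths are reported as matching it; after applying the negation patterns from the suggested fix, the output changes accordingly.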

pkl/
history/
deepmd-kit/
*.hdf5
4 changes: 4 additions & 0 deletions deepmd/pt/entrypoints/main.py
@@ -167,21 +167,25 @@ def _make_dp_loader_set(
# LMDB path: single string → LmdbDataset
if isinstance(training_systems, str) and is_lmdb(training_systems):
auto_prob = training_dataset_params.get("auto_prob", None)
mixed_batch = training_dataset_params.get("mixed_batch", False)
train_data_single = LmdbDataset(
training_systems,
model_params_single["type_map"],
training_dataset_params["batch_size"],
mixed_batch=mixed_batch,
auto_prob_style=auto_prob,
)
if (
validation_systems is not None
and isinstance(validation_systems, str)
and is_lmdb(validation_systems)
):
val_mixed_batch = validation_dataset_params.get("mixed_batch", False)
validation_data_single = LmdbDataset(
validation_systems,
model_params_single["type_map"],
validation_dataset_params["batch_size"],
mixed_batch=val_mixed_batch,
)
elif validation_systems is not None:
validation_data_single = _make_dp_loader_set(
15 changes: 14 additions & 1 deletion deepmd/pt/loss/ener.py
@@ -224,7 +224,20 @@ def forward(
more_loss = {}
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
atom_norm = 1.0 / natoms

# Detect mixed batch format
is_mixed_batch = "ptr" in input_dict and input_dict["ptr"] is not None

# For mixed batch, compute per-frame atom_norm and average
if is_mixed_batch:
ptr = input_dict["ptr"]
nframes = ptr.numel() - 1
# Compute natoms for each frame
natoms_per_frame = ptr[1:] - ptr[:-1] # [nframes]
# Average atom_norm across frames
atom_norm = torch.mean(1.0 / natoms_per_frame.float())
else:
atom_norm = 1.0 / natoms
Comment on lines +228 to +240
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

This mixed-batch normalization misweights frames with different atom counts.

l2_ener_loss/l2_virial_loss are reduced over frames first, then multiplied by mean(1 / N_i). That is not the same as applying each frame’s own 1 / N_i normalization before the reduction, so mixed batches with small and large frames get the wrong loss weighting.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/loss/ener.py` around lines 228-240, the mixed-batch normalization currently computes a single scalar atom_norm via atom_norm = mean(1.0 / natoms_per_frame), which misweights frames with different atom counts. Instead, compute per-frame normalizers and apply them before reducing l2_ener_loss and l2_virial_loss: when is_mixed_batch (detected via ptr, with natoms_per_frame derived from ptr), build per-frame atom_norms = 1.0 / natoms_per_frame.float() and multiply each frame's loss by its corresponding atom_norm before taking the final mean/sum (or perform a weighted reduction using these per-frame weights), so that l2_ener_loss and l2_virial_loss are normalized per frame rather than scaled by a single averaged scalar.
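
To make the weighting difference concrete, here is a minimal, hypothetical sketch (not code from this PR): the names ptr and natoms_per_frame mirror the diff above, while the frame sizes and energy values are invented.

```python
import torch

# Toy mixed batch: frame 0 has 2 atoms, frame 1 has 10 atoms (hypothetical values).
ptr = torch.tensor([0, 2, 12])                    # frame boundaries [nframes + 1]
natoms_per_frame = (ptr[1:] - ptr[:-1]).float()   # tensor([2., 10.])

# Hypothetical per-frame squared energy errors.
energy_pred = torch.tensor([1.0, 5.0])
energy_label = torch.tensor([0.0, 0.0])
sq_err = (energy_pred - energy_label) ** 2        # tensor([1., 25.])

# Current approach: reduce over frames first, then scale by mean(1 / N_i).
atom_norm_scalar = torch.mean(1.0 / natoms_per_frame)
loss_scalar = atom_norm_scalar * sq_err.mean()    # 0.3 * 13.0 = 3.9

# Per-frame approach: normalize each frame by its own atom count, then reduce.
loss_per_frame = (sq_err / natoms_per_frame).mean()  # (0.5 + 2.5) / 2 = 1.5

print(loss_scalar.item(), loss_per_frame.item())
```

In the per-frame form, each frame's error is scaled by its own 1 / N_i before the reduction, which is the weighting this comment argues for; the scalar form couples every frame's error to the average normalizer.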

if self.has_e and "energy" in model_pred and "energy" in label:
energy_pred = model_pred["energy"]
energy_label = label["energy"]
131 changes: 131 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -318,6 +318,137 @@ def forward_atomic(
)
return fit_ret

def forward_common_atomic_flat(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_batch: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor,
batch: torch.Tensor,
ptr: torch.Tensor,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
extended_ptr: torch.Tensor | None = None,
central_ext_index: torch.Tensor | None = None,
nlist_ext: torch.Tensor | None = None,
a_nlist: torch.Tensor | None = None,
a_nlist_ext: torch.Tensor | None = None,
nlist_mask: torch.Tensor | None = None,
a_nlist_mask: torch.Tensor | None = None,
edge_index: torch.Tensor | None = None,
angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Forward pass with flat batch format.

Parameters
----------
extended_coord : torch.Tensor
Extended coordinates [total_extended_atoms, 3].
extended_atype : torch.Tensor
Extended atom types [total_extended_atoms].
extended_batch : torch.Tensor
Frame assignment for extended atoms [total_extended_atoms].
nlist : torch.Tensor
Neighbor list [total_atoms, nnei].
mapping : torch.Tensor
Extended atom -> local flat index mapping [total_extended_atoms].
batch : torch.Tensor
Frame assignment for local atoms [total_atoms].
ptr : torch.Tensor
Frame boundaries [nframes + 1].
fparam : torch.Tensor | None
Frame parameters [nframes, ndf].
aparam : torch.Tensor | None
Atomic parameters [total_atoms, nda].
central_ext_index : torch.Tensor | None
Extended-atom indices corresponding to local atoms.
nlist_ext, a_nlist_ext : torch.Tensor | None
Edge and angle neighbor lists indexing concatenated extended atoms.
nlist_mask, a_nlist_mask : torch.Tensor | None
Valid-neighbor masks for flat edge and angle neighbor lists.
edge_index, angle_index : torch.Tensor | None
Dynamic graph indices produced by the flat graph preprocessor.

Returns
-------
result_dict : dict[str, torch.Tensor]
Model predictions in flat format.
"""
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)

# Descriptor and fitting both consume the flat atom layout.
descriptor_out = self.descriptor.forward_flat(
extended_coord,
extended_atype,
extended_batch,
nlist,
mapping,
batch,
ptr,
fparam=fparam if self.add_chg_spin_ebd else None,
Comment on lines +378 to +390
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Mirror the default-fparam fallback in the flat path.

forward_atomic() synthesizes get_default_fparam() when the fitting net expects frame parameters and the caller omits fparam. This flat path passes None straight into descriptor.forward_flat(), so mixed-batch runs lose that fallback and will diverge from the non-flat path for models using default charge/spin frame params.

Suggested fix
         if self.do_grad_r() or self.do_grad_c():
             extended_coord.requires_grad_(True)
 
+        if (
+            hasattr(self.fitting_net, "get_dim_fparam")
+            and self.fitting_net.get_dim_fparam() > 0
+            and fparam is None
+        ):
+            default_fparam_tensor = self.fitting_net.get_default_fparam()
+            assert default_fparam_tensor is not None
+            fparam_input_for_des = torch.tile(
+                default_fparam_tensor.unsqueeze(0), [ptr.numel() - 1, 1]
+            )
+        else:
+            fparam_input_for_des = fparam
+
         # Descriptor and fitting both consume the flat atom layout.
         descriptor_out = self.descriptor.forward_flat(
             extended_coord,
@@
-            fparam=fparam if self.add_chg_spin_ebd else None,
+            fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
             central_ext_index=central_ext_index,
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/atomic_model/dp_atomic_model.py` around lines 378-390, the flat-path call to descriptor.forward_flat currently passes fparam unmodified and thus skips the fallback logic used in forward_atomic. Update the flat path before the descriptor.forward_flat call to mirror forward_atomic's behavior: if self.add_chg_spin_ebd (or the equivalent flag) is True and fparam is None, compute the default via self.get_default_fparam(...) using the same input/shape logic as forward_atomic, assign it to fparam, and pass that fparam into descriptor.forward_flat so mixed-batch runs use the same default frame parameters.
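
As a standalone illustration of the fallback pattern in the suggested fix, here is a minimal, hypothetical sketch: the method names get_dim_fparam and get_default_fparam come from the suggestion above, while the fitting-net stub and values are invented.

```python
import torch

class FittingStub:
    """Hypothetical stand-in for a fitting net that expects two frame parameters."""

    def get_dim_fparam(self) -> int:
        return 2

    def get_default_fparam(self) -> torch.Tensor:
        return torch.tensor([0.0, 0.0])  # e.g. default charge/spin values

def resolve_fparam(fitting_net, fparam, ptr):
    """Tile the fitting net's default fparam to [nframes, ndf] when the caller omits it."""
    if (
        hasattr(fitting_net, "get_dim_fparam")
        and fitting_net.get_dim_fparam() > 0
        and fparam is None
    ):
        default = fitting_net.get_default_fparam()
        nframes = ptr.numel() - 1
        return torch.tile(default.unsqueeze(0), [nframes, 1])
    return fparam

ptr = torch.tensor([0, 3, 8])  # two frames with 3 and 5 atoms
print(resolve_fparam(FittingStub(), None, ptr).shape)  # torch.Size([2, 2])
```

The same tiled tensor would then be passed as fparam into descriptor.forward_flat, keeping the flat path consistent with forward_atomic.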

central_ext_index=central_ext_index,
nlist_ext=nlist_ext,
a_nlist=a_nlist,
a_nlist_ext=a_nlist_ext,
nlist_mask=nlist_mask,
a_nlist_mask=a_nlist_mask,
edge_index=edge_index,
angle_index=angle_index,
)

descriptor = descriptor_out.get("descriptor")
rot_mat = descriptor_out.get("rot_mat")
g2 = descriptor_out.get("g2")
h2 = descriptor_out.get("h2")

if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor.detach())

if central_ext_index is None:
from deepmd.pt.utils.nlist import get_central_ext_index

central_ext_index = get_central_ext_index(extended_batch, ptr)
atype = extended_atype[central_ext_index]
else:
atype = extended_atype[central_ext_index]

fit_ret = self.fitting_net.forward_flat(
descriptor,
atype,
batch,
ptr,
gr=rot_mat,
g2=g2,
h2=h2,
fparam=fparam,
aparam=aparam,
)
fit_ret = self.apply_out_stat(fit_ret, atype)

atom_mask = self.make_atom_mask(atype).to(torch.int32)
if self.atom_excl is not None:
atom_mask *= self.atom_excl(atype.unsqueeze(0)).squeeze(0)

for kk in fit_ret.keys():
out_shape = fit_ret[kk].shape
out_shape2 = 1
for ss in out_shape[1:]:
out_shape2 *= ss
fit_ret[kk] = (
fit_ret[kk].reshape([out_shape[0], out_shape2]) * atom_mask[:, None]
).view(out_shape)
fit_ret["mask"] = atom_mask

if self.enable_eval_fitting_last_layer_hook:
if "middle_output" in fit_ret:
self.eval_fitting_last_layer_list.append(
fit_ret.pop("middle_output").detach()
)

return fit_ret

def compute_or_load_stat(
self,
sampled_func: Callable[[], list[dict]],
126 changes: 126 additions & 0 deletions deepmd/pt/model/descriptor/dpa3.py
@@ -591,6 +591,132 @@ def forward(
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if sw is not None else None,
)

def forward_flat(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_batch: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor,
batch: torch.Tensor,
ptr: torch.Tensor,
fparam: torch.Tensor | None = None,
central_ext_index: torch.Tensor | None = None,
nlist_ext: torch.Tensor | None = None,
a_nlist: torch.Tensor | None = None,
a_nlist_ext: torch.Tensor | None = None,
nlist_mask: torch.Tensor | None = None,
a_nlist_mask: torch.Tensor | None = None,
edge_index: torch.Tensor | None = None,
angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Compute the descriptor with flat batch format.

Parameters
----------
extended_coord : torch.Tensor
Extended coordinates [total_extended_atoms, 3].
extended_atype : torch.Tensor
Extended atom types [total_extended_atoms].
extended_batch : torch.Tensor
Frame assignment for extended atoms [total_extended_atoms].
nlist : torch.Tensor
Neighbor list [total_atoms, nnei].
mapping : torch.Tensor
Extended atom -> local flat index mapping [total_extended_atoms].
batch : torch.Tensor
Frame assignment for local atoms [total_atoms].
ptr : torch.Tensor
Frame boundaries [nframes + 1].
fparam : torch.Tensor | None
Frame parameters [nframes, ndf].
central_ext_index : torch.Tensor | None
Extended-atom indices corresponding to local atoms.
nlist_ext, a_nlist_ext : torch.Tensor | None
Edge and angle neighbor lists indexing concatenated extended atoms.
nlist_mask, a_nlist_mask : torch.Tensor | None
Valid-neighbor masks for flat edge and angle neighbor lists.
edge_index, angle_index : torch.Tensor | None
Dynamic graph indices produced by the flat graph preprocessor.

Returns
-------
result : dict[str, torch.Tensor]
Dictionary containing:
- 'descriptor': [total_atoms, descriptor_dim]
- 'rot_mat': [total_atoms, e_dim, 3] or None
- 'g2': edge embedding or None
- 'h2': pair representation or None
"""
extended_coord = extended_coord.to(dtype=self.prec)

# Flat batches embed all extended atoms, then gather central atoms.
node_ebd_ext = self.type_embedding(
extended_atype
) # [total_extended_atoms, tebd_dim]

if self.add_chg_spin_ebd:
assert fparam is not None
assert self.chg_embedding is not None
assert self.spin_embedding is not None

# Expand frame-level charge/spin parameters to extended atoms.
charge = fparam[extended_batch, 0].to(dtype=torch.int64) + 100
spin = fparam[extended_batch, 1].to(dtype=torch.int64)
chg_ebd = self.chg_embedding(charge)
spin_ebd = self.spin_embedding(spin)
sys_cs_embd = self.act(
self.mix_cs_mlp(torch.cat((chg_ebd, spin_ebd), dim=-1))
)
node_ebd_ext = node_ebd_ext + sys_cs_embd

if central_ext_index is None:
from deepmd.pt.utils.nlist import get_central_ext_index

central_ext_index = get_central_ext_index(extended_batch, ptr)
node_ebd_inp = node_ebd_ext[central_ext_index]

node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows.forward_flat(
nlist,
extended_coord,
extended_atype,
extended_batch,
node_ebd_ext,
mapping,
batch,
ptr,
central_ext_index=central_ext_index,
nlist_ext=nlist_ext,
a_nlist=a_nlist,
a_nlist_ext=a_nlist_ext,
nlist_mask=nlist_mask,
a_nlist_mask=a_nlist_mask,
edge_index=edge_index,
angle_index=angle_index,
)

if self.concat_output_tebd:
node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)

return {
"descriptor": node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
"rot_mat": (
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if rot_mat is not None
else None
),
"g2": (
edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if edge_ebd is not None
else None
),
"h2": (
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if h2 is not None
else None
),
}

@classmethod
def update_sel(
cls,