Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/backend/pt_expt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class PyTorchExportableBackend(Backend):
| Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".pte"]
suffixes: ClassVar[list[str]] = [".pte", ".pt2"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down
28 changes: 18 additions & 10 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from deepmd.dpmodel.array_api import (
Array,
xp_take_along_axis,
xp_take_first_n,
)
from deepmd.dpmodel.common import (
cast_precision,
Expand Down Expand Up @@ -534,7 +535,7 @@
(nf, nall, self.tebd_dim),
)
# nfnl x tebd_dim
atype_embd = atype_embd_ext[:, :nloc, :]
atype_embd = xp_take_first_n(atype_embd_ext, 1, nloc)
grrg, g2, h2, rot_mat, sw = self.se_atten(
nlist,
coord_ext,
Expand Down Expand Up @@ -1056,7 +1057,8 @@
self.stddev[...],
)
nf, nloc, nnei, _ = dmatrix.shape
atype = atype_ext[:, :nloc]
nall = atype_ext.shape[1]

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable nall is not used.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Drop the dead nall assignment.

Ruff already flags this as F841, so the Python lint job will keep failing until this local is removed.

💡 Suggested fix
-        nall = atype_ext.shape[1]
         atype = xp_take_first_n(atype_ext, 1, nloc)
As per coding guidelines, `**/*.py`: Always run `ruff check .` and `ruff format .` before committing changes or CI will fail.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
nall = atype_ext.shape[1]
atype = xp_take_first_n(atype_ext, 1, nloc)
🧰 Tools
🪛 Ruff (0.15.4)

[error] 1060-1060: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/dpmodel/descriptor/dpa1.py` at line 1060, Remove the unused local
assignment "nall = atype_ext.shape[1]" from dpa1.py (the dead variable 'nall'
referenced in the review) so the F841 lint error is resolved; locate the
assignment near the code that references 'atype_ext' and delete that single
line, then run "ruff check ." and "ruff format ." to ensure no other lint errors
remain.

atype = xp_take_first_n(atype_ext, 1, nloc)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# nfnl x nnei
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
Expand All @@ -1075,6 +1077,12 @@
nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
ng = self.neuron[-1]
nt = self.tebd_dim

# Gather neighbor info using xp_take_along_axis along axis=1.
# This avoids flat (nf*nall,) indexing that creates Ne(nall, nloc)
# constraints in torch.export, breaking NoPbc (nall == nloc).
nlist_2d = xp.reshape(nlist_masked, (nf, nloc * nnei)) # (nf, nloc*nnei)

# nfnl x nnei x 4
rr = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype)
Expand All @@ -1083,15 +1091,16 @@
if self.tebd_input_mode in ["concat"]:
# nfnl x tebd_dim
atype_embd = xp.reshape(
atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim)
xp_take_first_n(atype_embd_ext, 1, nloc), (nf * nloc, self.tebd_dim)
)
# nfnl x nnei x tebd_dim
atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1))
index = xp.tile(
xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)
# Gather neighbor type embeddings: (nf, nall, tebd_dim) -> (nf, nloc*nnei, tebd_dim)
nlist_idx_tebd = xp.tile(nlist_2d[:, :, xp.newaxis], (1, 1, self.tebd_dim))
atype_embd_nlist = xp_take_along_axis(
atype_embd_ext, nlist_idx_tebd, axis=1
)
# nfnl x nnei x tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
Expand All @@ -1110,10 +1119,9 @@
assert self.embeddings_strip is not None
assert type_embedding is not None
ntypes_with_padding = type_embedding.shape[0]
# nf x (nl x nnei)
nlist_index = xp.reshape(nlist_masked, (nf, nloc * nnei))
# nf x (nl x nnei)
nei_type = xp_take_along_axis(atype_ext, nlist_index, axis=1)
# Gather neighbor types: (nf, nall) -> (nf, nloc*nnei)
nei_type = xp_take_along_axis(atype_ext, nlist_2d, axis=1)
nei_type = xp.reshape(nei_type, (-1,)) # (nf * nloc * nnei,)
# (nf x nl x nnei) x ng
nei_type_index = xp.tile(xp.reshape(nei_type, (-1, 1)), (1, ng))
if self.type_one_side:
Expand Down
7 changes: 3 additions & 4 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from deepmd.dpmodel.array_api import (
Array,
xp_take_along_axis,
xp_take_first_n,
)
from deepmd.dpmodel.common import (
cast_precision,
Expand Down Expand Up @@ -876,7 +877,7 @@ def call(
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
(nframes, nall, self.tebd_dim),
)
g1_inp = g1_ext[:, :nloc, :]
g1_inp = xp_take_first_n(g1_ext, 1, nloc)
g1, _, _, _, _ = self.repinit(
nlist_dict[
get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel())
Expand Down Expand Up @@ -910,9 +911,7 @@ def call(
g1 = g1 + self.tebd_transform(g1_inp)
# mapping g1
assert mapping is not None
mapping_ext = xp.tile(
xp.reshape(mapping, (nframes, nall, 1)), (1, 1, g1.shape[-1])
)
mapping_ext = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1]))
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
Expand Down
8 changes: 4 additions & 4 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from deepmd.dpmodel.array_api import (
Array,
xp_take_along_axis,
xp_take_first_n,
)
from deepmd.dpmodel.common import (
to_numpy_array,
Expand Down Expand Up @@ -499,7 +500,7 @@ def call(
sw = xp.reshape(sw, (nf, nloc, nnei))
sw = xp.where(nlist_mask, sw, xp.zeros_like(sw))
# nf x nloc x tebd_dim
atype_embd = atype_embd_ext[:, :nloc, :]
atype_embd = xp_take_first_n(atype_embd_ext, 1, nloc)
assert list(atype_embd.shape) == [nf, nloc, self.g1_dim]

g1 = self.act(atype_embd)
Expand All @@ -516,7 +517,7 @@ def call(
# if a neighbor is real or not is indicated by nlist_mask
nlist = xp.where(nlist == -1, xp.zeros_like(nlist), nlist)
# nf x nall x ng1
mapping = xp.tile(xp.reshape(mapping, (nf, -1, 1)), (1, 1, self.g1_dim))
mapping = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.g1_dim))
for idx, ll in enumerate(self.layers):
# g1: nf x nloc x ng1
# g1_ext: nf x nall x ng1
Expand Down Expand Up @@ -1765,9 +1766,8 @@ def call(
)

nf, nloc, nnei, _ = g2.shape
nall = g1_ext.shape[1]
# g1, _ = xp.split(g1_ext, [nloc], axis=1)
g1 = g1_ext[:, :nloc, :]
g1 = xp_take_first_n(g1_ext, 1, nloc)
assert (nf, nloc) == g1.shape[:2]
assert (nf, nloc, nnei) == h2.shape[:3]

Expand Down
9 changes: 5 additions & 4 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from deepmd.dpmodel.array_api import (
Array,
xp_take_along_axis,
xp_take_first_n,
)
from deepmd.dpmodel.atomic_model.base_atomic_model import (
BaseAtomicModel,
Expand Down Expand Up @@ -558,7 +560,6 @@ def _format_nlist(
xp = array_api_compat.array_namespace(extended_coord, nlist)
n_nf, n_nloc, n_nnei = nlist.shape
extended_coord = extended_coord.reshape([n_nf, -1, 3])
nall = extended_coord.shape[1]
rcut = self.get_rcut()

if n_nnei < nnei:
Expand All @@ -581,14 +582,14 @@ def _format_nlist(
# make a copy before revise
m_real_nei = nlist >= 0
ret = xp.where(m_real_nei, nlist, 0)
coord0 = extended_coord[:, :n_nloc, :]
coord0 = xp_take_first_n(extended_coord, 1, n_nloc)
index = xp.tile(ret.reshape(n_nf, n_nloc * n_nnei, 1), (1, 1, 3))
coord1 = xp.take_along_axis(extended_coord, index, axis=1)
coord1 = xp_take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3)
rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1)
rr = xp.where(m_real_nei, rr, float("inf"))
rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1)
ret = xp.take_along_axis(ret, ret_mapping, axis=2)
ret = xp_take_along_axis(ret, ret_mapping, axis=2)
ret = xp.where(rr > rcut, -1, ret)
ret = ret[..., :nnei]
# not extra_nlist_sort and n_nnei <= nnei:
Expand Down
21 changes: 13 additions & 8 deletions deepmd/dpmodel/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.dpmodel.array_api import (
Array,
xp_take_along_axis,
xp_take_first_n,
)


Expand Down Expand Up @@ -131,18 +132,22 @@ def build_type_exclude_mask(
],
axis=-1,
)
type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1)
# nf x nloc x nnei
index = xp.reshape(
xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei)
type_i = xp.reshape(xp_take_first_n(atype_ext, 1, nloc), (nf, nloc)) * (
self.ntypes + 1
)
type_j = xp_take_along_axis(ae, index, axis=1)
# Map -1 entries to nall (the virtual atom index in ae)
nlist_for_type = xp.where(nlist == -1, xp.full_like(nlist, nall), nlist)
# Gather neighbor types using xp_take_along_axis along axis=1.
# This avoids flat (nf*(nall+1),) indexing that creates Ne(nall, nloc)
# constraints in torch.export, breaking NoPbc (nall == nloc).
nlist_for_gather = xp.reshape(nlist_for_type, (nf, nloc * nnei))
type_j = xp_take_along_axis(ae, nlist_for_gather, axis=1)
type_j = xp.reshape(type_j, (nf, nloc, nnei))
type_ij = type_i[:, :, None] + type_j
# nf x (nloc x nnei)
type_ij = xp.reshape(type_ij, (nf, nloc * nnei))
# (nf * nloc * nnei,)
type_ij_flat = xp.reshape(type_ij, (-1,))
mask = xp.reshape(
xp.take(self.type_mask[...], xp.reshape(type_ij, (-1,))),
xp.take(self.type_mask[...], type_ij_flat),
(nf, nloc, nnei),
)
return mask
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from deepmd.dpmodel.array_api import (
Array,
xp_take_along_axis,
xp_take_first_n,
)

from .region import (
Expand Down Expand Up @@ -243,8 +244,7 @@ def build_multiple_neighbor_list(
nlist = xp.concat([nlist, pad], axis=-1)
nsel = nsels[-1]
coord1 = xp.reshape(coord, (nb, -1, 3))
nall = coord1.shape[1]
coord0 = coord1[:, :nloc, :]
coord0 = xp_take_first_n(coord1, 1, nloc)
nlist_mask = nlist == -1
tnlist_0 = xp.where(nlist_mask, xp.zeros_like(nlist), nlist)
index = xp.tile(xp.reshape(tnlist_0, (nb, nloc * nsel, 1)), (1, 1, 3))
Expand Down
60 changes: 60 additions & 0 deletions deepmd/pt_expt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,54 @@ def train(
trainer.run()


def freeze(
model: str,
output: str = "frozen_model.pt2",
head: str | None = None,
) -> None:
"""Freeze a pt_expt training checkpoint to .pte or .pt2 format.

Parameters
----------
model : str
Path to the training checkpoint (.pt file).
output : str
Path for the frozen model output (.pte or .pt2).
head : str or None
Head to freeze in a multi-task model (not yet supported).
"""
import torch

from deepmd.pt_expt.model import (
get_model,
)
from deepmd.pt_expt.train.wrapper import (
ModelWrapper,
)
from deepmd.pt_expt.utils.env import (
DEVICE,
)
from deepmd.pt_expt.utils.serialization import (
deserialize_to_file,
)

state_dict = torch.load(model, map_location=DEVICE, weights_only=True)
if "model" in state_dict:
state_dict = state_dict["model"]
model_params = state_dict["_extra_state"]["model_params"]

# Reconstruct model and load weights
pt_expt_model = get_model(model_params).to(DEVICE)
wrapper = ModelWrapper(pt_expt_model)
wrapper.load_state_dict(state_dict)
pt_expt_model.eval()

# Serialize to dict and export
model_dict = pt_expt_model.serialize()
deserialize_to_file(output, {"model": model_dict})
log.info(f"Saved frozen model to {output}")
Comment on lines +163 to +208
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don't silently ignore --head.

If the caller passes a head today, freeze() still exports the default/full model and never tells them the request was ignored. Fail fast until per-head freezing is actually implemented.

🛑 Suggested guard
 def freeze(
     model: str,
     output: str = "frozen_model.pt2",
     head: str | None = None,
 ) -> None:
@@
     from deepmd.pt_expt.utils.serialization import (
         deserialize_to_file,
     )

+    if head is not None:
+        raise NotImplementedError(
+            "--head is not supported for the pt_expt freeze command yet."
+        )
+
     state_dict = torch.load(model, map_location=DEVICE, weights_only=True)

As per coding guidelines, "Always run ruff check . and ruff format . before committing changes or CI will fail."

🧰 Tools
🪛 Ruff (0.15.4)

[warning] 166-166: Unused function argument: head

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/entrypoints/main.py` around lines 163 - 208, The freeze
function currently ignores the head parameter; update freeze (the function
handling model loading, get_model, ModelWrapper, serialize, deserialize_to_file)
to fail-fast when head is not None by raising a clear exception (e.g.,
ValueError) indicating per-head freezing is not yet supported, so callers are
not silently misled; add the check near the start of freeze before
reconstructing the model and ensure the error message references the unsupported
head behavior, then run ruff check . and ruff format . to satisfy linting/style
rules.



def main(args: list[str] | argparse.Namespace | None = None) -> None:
"""Entry point for the pt_expt backend CLI.

Expand Down Expand Up @@ -195,6 +243,18 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
skip_neighbor_stat=FLAGS.skip_neighbor_stat,
output=FLAGS.output,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
checkpoint_path = Path(FLAGS.checkpoint_folder)
latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Handle pt_expt checkpoint directories that have no `checkpoint` pointer file

Reading `checkpoint_path / "checkpoint"` unconditionally breaks `dp --pt_expt freeze -c <dir>` for directories produced by the pt_expt trainer. `deepmd/pt_expt/train/training.py` writes `model.ckpt-<step>.pt` plus a `model.ckpt.pt` symlink and does not create a `checkpoint` pointer file. In that common flow (including the default `-c .`), `freeze` raises `FileNotFoundError` instead of resolving the latest checkpoint.

Useful? React with 👍 / 👎.

FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
Comment on lines +247 to +250
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Strip the checkpoint pointer before joining the path.

`Path.read_text()` preserves trailing newlines, so a line-based `checkpoint` file can leave `FLAGS.model` with a trailing newline (`\n`), which makes `torch.load()` fail to find the file.

🧩 Suggested fix
-            latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
-            FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
+            latest_ckpt_file = (checkpoint_path / "checkpoint").read_text().strip()
+            FLAGS.model = str(checkpoint_path / latest_ckpt_file)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if Path(FLAGS.checkpoint_folder).is_dir():
checkpoint_path = Path(FLAGS.checkpoint_folder)
latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
if Path(FLAGS.checkpoint_folder).is_dir():
checkpoint_path = Path(FLAGS.checkpoint_folder)
latest_ckpt_file = (checkpoint_path / "checkpoint").read_text().strip()
FLAGS.model = str(checkpoint_path / latest_ckpt_file)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt_expt/entrypoints/main.py` around lines 247 - 250, The code reads
the checkpoint pointer into latest_ckpt_file with Path(...).read_text() which
can include a trailing newline, causing FLAGS.model to point to a non-existent
path; update the logic that sets latest_ckpt_file (used when
FLAGS.checkpoint_folder is a directory) to strip whitespace/newlines (e.g.,
.strip()) before calling checkpoint_path.joinpath(...) so FLAGS.model is set to
the clean filename used by torch.load().

else:
FLAGS.model = FLAGS.checkpoint_folder
# Default to .pt2; user can specify .pte via -o flag
suffix = Path(FLAGS.output).suffix
if suffix not in (".pte", ".pt2"):
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pt2"))
freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
else:
raise RuntimeError(
f"Unsupported command '{FLAGS.command}' for the pt_expt backend."
Expand Down
Loading
Loading