Changes from all commits
21 commits
002853e  add dplr(torch)  (ChiahsinChu, Jan 8, 2026)
dc3367d  update torch_admp version; minor update in dipole_charge non-pbc asse…  (ChiahsinChu, Jan 8, 2026)
8603ef0  minor code improvement based on @coderabbitai  (ChiahsinChu, Jan 8, 2026)
23af13f  change source of torch_admp from github to pypi in pyproject.toml; mi…  (ChiahsinChu, Jan 8, 2026)
28caa3f  add torch-admp to dependencies  (njzjz, Jan 10, 2026)
85d3311  pass tuple  (njzjz, Jan 10, 2026)
b55a793  Fix pop method for torch_static_requirement  (njzjz, Jan 10, 2026)
3b3f074  docs: add PyTorch backend documentation to DPLR model  (ChiahsinChu, Jan 10, 2026)
6098731  update freeze/eval with model (de-)serialization  (ChiahsinChu, Jan 15, 2026)
539ae90  Feat: Refactor DipoleChargeModifier to support direct model object in…  (ChiahsinChu, Jan 15, 2026)
0c8ef92  feat: Add batch processing support to DipoleChargeModifier for improv…  (ChiahsinChu, Jan 16, 2026)
7d652e7  feat: add DipoleChargeModifier.eval_np  (ChiahsinChu, Jan 16, 2026)
775997c  fix(modifier): resolve DipoleChargeModifier initialization and execut…  (ChiahsinChu, Jan 17, 2026)
6a66a3e  update required ver of torch-admp  (ChiahsinChu, Jan 17, 2026)
91ff702  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jan 17, 2026)
c1cf079  fix bug in tf dipole_charge modifier  (ChiahsinChu, Jan 21, 2026)
2f12b7c  add DPModifier  (ChiahsinChu, Jan 21, 2026)
1318e3d  refactor(pt): refactor DipoleChargeModifier to inherit from DPModifier  (ChiahsinChu, Jan 21, 2026)
99287fb  update DPModifier.__init__  (ChiahsinChu, Jan 21, 2026)
9421022  fix(pt): resolve issues with dipole charge modifier  (ChiahsinChu, Jan 22, 2026)
3b81c22  docs(pt): update docstrings in dipole_charge.py  (ChiahsinChu, Jan 23, 2026)
5 changes: 4 additions & 1 deletion backend/dynamic_metadata.py
@@ -48,8 +48,11 @@ def dynamic_metadata(
]
optional_dependencies["lmp"].extend(find_libpython_requires)
optional_dependencies["ipi"].extend(find_libpython_requires)
torch_static_requirement = optional_dependencies.pop("torch", ())
return {
**optional_dependencies,
**get_tf_requirement(tf_version),
**get_pt_requirement(pt_version),
**get_pt_requirement(
pt_version, static_requirement=tuple(torch_static_requirement)
),
}
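
The pop() result is wrapped in tuple() because get_pt_requirement (next file) is decorated with @lru_cache, and cached arguments must be hashable; a list would raise a TypeError. A minimal sketch of that constraint, using a hypothetical cached helper:

from functools import lru_cache

@lru_cache
def cached_requirement(extras: tuple[str, ...]) -> list[str]:
    # Tuples are hashable, so they are valid lru_cache keys.
    return list(extras)

cached_requirement(("torch-admp",))   # fine
# cached_requirement(["torch-admp"])  # TypeError: unhashable type: 'list'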
10 changes: 9 additions & 1 deletion backend/find_pytorch.py
@@ -90,7 +90,10 @@ def find_pytorch() -> tuple[str | None, list[str]]:


@lru_cache
def get_pt_requirement(pt_version: str = "") -> dict:
def get_pt_requirement(
pt_version: str = "",
static_requirement: tuple[str] | None = None,
) -> dict:
"""Get PyTorch requirement when PT is not installed.

If pt_version is not given and the environment variable `PYTORCH_VERSION` is set, use it as the requirement.
@@ -99,6 +102,8 @@ def get_pt_requirement(pt_version: str = "") -> dict:
----------
pt_version : str, optional
PT version
static_requirement : tuple[str] or None, optional
Static requirements

Returns
-------
@@ -125,6 +130,8 @@ def get_pt_requirement(pt_version: str = "") -> dict:
mpi_requirement = ["mpich"]
else:
mpi_requirement = []
if static_requirement is None:
static_requirement = ()

return {
"torch": [
@@ -138,6 +145,7 @@ def get_pt_requirement(pt_version: str = "") -> dict:
else "torch>=2.1.0",
*mpi_requirement,
*cibw_requirement,
*static_requirement,
],
}
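
A rough usage sketch of the extended signature; the version string below is illustrative, not taken from the PR:

# Hypothetical call, mirroring the backend/dynamic_metadata.py change above:
reqs = get_pt_requirement("2.3.0", static_requirement=("torch-admp",))
# reqs["torch"] pins torch to the requested version (or falls back to
# "torch>=2.1.0" when none is resolved) and appends the MPI, cibuildwheel,
# and static requirements such as torch-admp.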

1 change: 0 additions & 1 deletion deepmd/dpmodel/modifier/base_modifier.py
@@ -32,7 +32,6 @@ def serialize(self) -> dict:
dict
The serialized data
"""
pass

@classmethod
def deserialize(cls, data: dict) -> "BaseModifier":
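
The removed pass statement was redundant: a docstring is itself a valid function body, so an abstract method needs nothing else, as in this standalone example:

from abc import ABC, abstractmethod

class ExampleModifier(ABC):
    @abstractmethod
    def serialize(self) -> dict:
        """The docstring alone is a complete method body; no pass is needed."""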
20 changes: 8 additions & 12 deletions deepmd/pt/entrypoints/main.py
@@ -1,10 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import argparse
import copy
import io
import json
import logging
import os
import pickle
from pathlib import (
Path,
)
@@ -401,17 +401,13 @@ def freeze(
model.eval()
model = torch.jit.script(model)

dm_output = "data_modifier.pth"
extra_files = {dm_output: ""}
if tester.modifier is not None:
dm = tester.modifier
dm.eval()
buffer = io.BytesIO()
torch.jit.save(
torch.jit.script(dm),
buffer,
)
extra_files = {dm_output: buffer.getvalue()}
extra_files = {"modifier_data": ""}
dm = tester.modifier
if dm is not None:
# dict from dm.serialize() includes np.ndarray
# use pickle rather than json
bytes_data = pickle.dumps(dm.serialize())
extra_files = {"modifier_data": bytes_data}
torch.jit.save(
model,
output,
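
The serialized modifier now travels inside the TorchScript archive through the extra-files mechanism instead of a nested torch.jit.save of a scripted modifier. A minimal, self-contained sketch of that round trip; the Identity module, file name, and dict contents are illustrative, not the PR's model or modifier:

import pickle
import torch

class Identity(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

model = torch.jit.script(Identity())
extra_files = {"modifier_data": pickle.dumps({"type": "dipole_charge", "@version": 3})}
torch.jit.save(model, "frozen_model.pth", _extra_files=extra_files)

# Loading side: request the same key back; its value comes back as bytes.
loaded_extra = {"modifier_data": ""}
loaded_model = torch.jit.load("frozen_model.pth", _extra_files=loaded_extra)
modifier_dict = pickle.loads(loaded_extra["modifier_data"])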
18 changes: 11 additions & 7 deletions deepmd/pt/infer/deep_eval.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import io
import json
import logging
import pickle
from collections.abc import (
Callable,
)
@@ -49,6 +49,9 @@
from deepmd.pt.model.network.network import (
TypeEmbedNetConsistent,
)
from deepmd.pt.modifier import (
BaseModifier,
)
from deepmd.pt.train.wrapper import (
ModelWrapper,
)
@@ -172,19 +175,20 @@ def __init__(
self.dp = ModelWrapper(model)
self.dp.load_state_dict(state_dict)
elif str(self.model_path).endswith(".pth"):
extra_files = {"data_modifier.pth": ""}
extra_files = {"modifier_data": ""}
model = torch.jit.load(
model_file, map_location=env.DEVICE, _extra_files=extra_files
)
modifier = None
# Load modifier if it exists in extra_files
if len(extra_files["data_modifier.pth"]) > 0:
# Create a file-like object from the in-memory data
modifier_data = extra_files["data_modifier.pth"]
if len(extra_files["modifier_data"]) > 0:
modifier_data = extra_files["modifier_data"]
if isinstance(modifier_data, bytes):
modifier_data = io.BytesIO(modifier_data)
modifier_data = pickle.loads(modifier_data)
# Load the modifier directly from the file-like object
modifier = torch.jit.load(modifier_data, map_location=env.DEVICE)
modifier = BaseModifier.get_class_by_type(
modifier_data["type"]
).deserialize(modifier_data)
self.dp = ModelWrapper(model, modifier=modifier)
self.modifier = modifier
model_def_script = self.dp.model["Default"].get_model_def_script()
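
Loading mirrors freezing: the pickled dict is read back from the archive and dispatched on its "type" key via the modifier class registry. A simplified sketch of that path; restore_modifier is a hypothetical helper, not a function added by the PR:

import pickle

from deepmd.pt.modifier import BaseModifier

def restore_modifier(extra_files: dict) -> "BaseModifier | None":
    # extra_files is the dict populated by torch.jit.load(..., _extra_files=...)
    raw = extra_files.get("modifier_data", b"")
    if not raw:
        return None
    data = pickle.loads(raw)
    # Resolve the concrete modifier class from the serialized "type" field.
    return BaseModifier.get_class_by_type(data["type"]).deserialize(data)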
4 changes: 4 additions & 0 deletions deepmd/pt/modifier/__init__.py
@@ -7,9 +7,13 @@
from .base_modifier import (
BaseModifier,
)
from .dipole_charge import (
DipoleChargeModifier,
)

__all__ = [
"BaseModifier",
"DipoleChargeModifier",
"get_data_modifier",
]

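
With the re-export in place, the PyTorch modifier can be imported directly from the subpackage:

from deepmd.pt.modifier import DipoleChargeModifier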
1 change: 1 addition & 0 deletions deepmd/pt/modifier/base_modifier.py
@@ -47,6 +47,7 @@ def serialize(self) -> dict:
data = {
"@class": "Modifier",
"type": self.modifier_type,
"use_cache": self.use_cache,
"@version": 3,
}
return data
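
Since use_cache is now part of the serialized dict, a serialize/deserialize round trip can restore it along with the type tag and the bumped version. A hedged sketch, assuming modifier is an existing BaseModifier instance and that deserialize honors the stored fields:

data = modifier.serialize()
# data carries "@class", "type", "use_cache", and "@version" == 3
restored = type(modifier).deserialize(data)
assert restored.use_cache == modifier.use_cache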