7 changes: 7 additions & 0 deletions deepmd/infer/deep_pot.py
@@ -186,6 +186,9 @@ def eval(
when atomic is True.
hessian
The Hessian matrix of the system, in shape (nframes, 3 * natoms, 3 * natoms). Returned when available.
grad_aparam
The gradient of the energy w.r.t. the atomic parameters, in shape (nframes, natoms, dim_aparam).
Returned when grad_aparam=True is requested, aparam is provided, and the model has dim_aparam > 0.
"""
# This method has been used by:
# documentation python.md
@@ -251,6 +254,10 @@ def eval(
nframes, 3 * natoms, 3 * natoms
)
result = (*list(result), hessian)
if "grad_aparam" in results:
dim_aparam = self.get_dim_aparam()
grad_aparam = results["grad_aparam"].reshape(nframes, natoms, dim_aparam)
result = (*list(result), grad_aparam)
return result


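A minimal usage sketch of the new return value, assuming the top-level `DeepPot.eval` forwards extra keyword arguments (here `grad_aparam=True`) to the backend shown in `deep_eval.py` below; the model file name and array contents are illustrative:

```python
import numpy as np
from deepmd.infer import DeepPot

dp = DeepPot("model.pth")  # hypothetical frozen model with dim_aparam > 0
nframes, natoms = 1, 2
coords = np.random.rand(nframes, natoms * 3)
cells = None  # non-periodic system
atom_types = [0, 1]
aparam = np.random.rand(nframes, natoms * dp.get_dim_aparam())

# For a plain energy model the returned tuple is (energy, force, virial, grad_aparam).
e, f, v, grad_ap = dp.eval(
    coords, cells, atom_types, aparam=aparam, grad_aparam=True
)
print(grad_ap.shape)  # (nframes, natoms, dim_aparam)
```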
53 changes: 46 additions & 7 deletions deepmd/pt/infer/deep_eval.py
@@ -373,9 +373,16 @@ def eval(
coords, atom_types, len(atom_types.shape) > 1
)
request_defs = self._get_request_defs(atomic)
compute_grad_aparam = kwargs.pop("grad_aparam", False) and aparam is not None
if "spin" not in kwargs or kwargs["spin"] is None:
out = self._eval_func(self._eval_model, numb_test, natoms)(
coords, cells, atom_types, fparam, aparam, request_defs
coords,
cells,
atom_types,
fparam,
aparam,
request_defs,
compute_grad_aparam,
)
else:
out = self._eval_func(self._eval_model_spin, numb_test, natoms)(
@@ -387,12 +394,14 @@ def eval(
aparam,
request_defs,
)
return dict(
zip(
[x.name for x in request_defs],
out,
)
)
n_request = len(request_defs)
if isinstance(out, tuple):
result = dict(zip([x.name for x in request_defs], out[:n_request]))
if compute_grad_aparam and len(out) > n_request:
result["grad_aparam"] = out[n_request]
else:
result = dict(zip([x.name for x in request_defs], out))
return result

def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]:
"""Get the requested output definitions.
@@ -487,6 +496,7 @@ def _eval_model(
fparam: np.ndarray | None,
aparam: np.ndarray | None,
request_defs: list[OutputVariableDef],
compute_grad_aparam: bool = False,
) -> tuple[np.ndarray, ...]:
model = self.dp.to(DEVICE)
prec = NP_PRECISION_DICT[RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION]]
@@ -528,6 +538,11 @@
)
else:
aparam_input = None

# If grad_aparam requested, enable grad tracking on aparam
if compute_grad_aparam and aparam_input is not None:
aparam_input = aparam_input.detach().requires_grad_(True)

do_atomic_virial = any(
x.category == OutputVariableCategory.DERV_C for x in request_defs
)
@@ -554,6 +569,30 @@
results.append(
np.full(np.abs(shape), np.nan, dtype=prec)
) # this is kinda hacky

# Compute dE/d(aparam) via autograd
if compute_grad_aparam:
# Find total energy in batch_output
energy_tensor = None
for key in ["energy_redu", "energy"]:
if key in batch_output:
energy_tensor = batch_output[key]
break
if energy_tensor is not None:
grad_ap = torch.autograd.grad(
energy_tensor.sum(),
aparam_input,
create_graph=False,
retain_graph=False,
)[0]
results.append(
grad_ap.reshape(nframes, natoms, -1).detach().cpu().numpy()
)
else:
# Energy not available or no grad path; fill with NaN
dim_ap = self.get_dim_aparam()
results.append(np.full((nframes, natoms, dim_ap), np.nan, dtype=prec))

return tuple(results)

def _eval_model_spin(
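The autograd call in `_eval_model` is the standard PyTorch pattern for differentiating a scalar output w.r.t. an input tensor; a self-contained toy sketch (the quadratic "energy" and the shapes are made up) showing why `aparam_input` needs `requires_grad_(True)` before the forward pass:

```python
import torch

nframes, natoms, dim_aparam = 2, 4, 3
aparam = torch.rand(nframes, natoms, dim_aparam)

def toy_energy(ap: torch.Tensor) -> torch.Tensor:
    # Stand-in for the model: any differentiable map from aparam to per-frame energies.
    return (ap**2).sum(dim=(1, 2))  # shape (nframes,)

# Mirror of _eval_model: detach, enable grad tracking, run the forward, differentiate.
aparam_in = aparam.detach().requires_grad_(True)
energy = toy_energy(aparam_in)
(grad_ap,) = torch.autograd.grad(energy.sum(), aparam_in)
print(grad_ap.shape)  # torch.Size([2, 4, 3]), i.e. dE/d(aparam) per atom and component
```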
83 changes: 81 additions & 2 deletions deepmd/pt/loss/ener.py
@@ -57,6 +57,9 @@ def __init__(
inference: bool = False,
use_huber: bool = False,
huber_delta: float = 0.01,
start_pref_ap: float = 0.0,
limit_pref_ap: float = 0.0,
numb_aparam: int = 0,
**kwargs: Any,
) -> None:
r"""Construct a layer to compute loss on energy, force and virial.
@@ -109,6 +112,12 @@ def __init__(
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
huber_delta : float
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
start_pref_ap : float
The prefactor of aparam gradient loss at the start of the training.
limit_pref_ap : float
The prefactor of aparam gradient loss at the end of the training.
numb_aparam : int
The dimension of atomic parameters. Required when aparam gradient loss is enabled.
**kwargs
Other keyword arguments.
"""
@@ -151,6 +160,15 @@ def __init__(
"Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
)

self.has_ap = start_pref_ap != 0.0 or limit_pref_ap != 0.0
if self.has_ap and numb_aparam == 0:
raise RuntimeError(
"numb_aparam must be > 0 when aparam gradient loss is enabled"
)
self.start_pref_ap = start_pref_ap
self.limit_pref_ap = limit_pref_ap
self.numb_aparam = numb_aparam

def forward(
self,
input_dict: dict[str, torch.Tensor],
@@ -182,7 +200,18 @@ def forward(
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
model_pred = model(**input_dict)
ap_for_grad: torch.Tensor | None = None
# Capture the grad-enabled state before any enable_grad context.
in_training = torch.is_grad_enabled()
if self.has_ap and input_dict.get("aparam") is not None:
ap_for_grad = input_dict["aparam"].detach()
ap_for_grad.requires_grad_(True)
input_dict = {**input_dict, "aparam": ap_for_grad}
# Use enable_grad so gradient can be computed even inside no_grad (inference).
with torch.enable_grad():
model_pred = model(**input_dict)
else:
model_pred = model(**input_dict)
coef = learning_rate / self.starter_learning_rate
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef
@@ -402,6 +431,41 @@ def forward(
rmse_ae.detach(), find_atom_ener
)

if self.has_ap and ap_for_grad is not None and "energy" in model_pred:
energy_pred_ap = model_pred["energy"] # [nf, 1]
# Compute d(sum_E)/d(aparam_raw), shape [nf, nloc, numb_aparam].
# Use enable_grad so this works both in training and no_grad inference.
with torch.enable_grad():
grad_ap_pred = torch.autograd.grad(
[energy_pred_ap.sum()],
[ap_for_grad],
create_graph=in_training,
retain_graph=not self.inference, # keep graph alive for subsequent loss.backward in training
)[0]
assert grad_ap_pred is not None
# Always expose aparam_grad in model_pred (useful for inference output).
model_pred = dict(model_pred)
model_pred["aparam_grad"] = (
grad_ap_pred.detach() if not in_training else grad_ap_pred
)
if "grad_aparam" in label:
find_grad_ap = label.get("find_grad_aparam", 0.0)
pref_ap = (
self.limit_pref_ap
+ (self.start_pref_ap - self.limit_pref_ap) * coef
) * find_grad_ap
grad_ap_label = label["grad_aparam"].to(grad_ap_pred.dtype)
diff_ap = (grad_ap_label - grad_ap_pred).reshape(-1)
l2_ap_loss = torch.mean(torch.square(diff_ap))
if not self.inference:
more_loss["l2_grad_aparam_loss"] = self.display_if_exist(
l2_ap_loss.detach(), find_grad_ap
)
loss += (pref_ap * l2_ap_loss).to(GLOBAL_PT_FLOAT_PRECISION)
more_loss["rmse_grad_aparam"] = self.display_if_exist(
l2_ap_loss.sqrt().detach(), find_grad_ap
)

if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return model_pred, loss, more_loss
@@ -482,6 +546,16 @@ def label_requirement(self) -> list[DataRequirementItem]:
default=1.0,
)
)
if self.has_ap:
label_requirement.append(
DataRequirementItem(
"grad_aparam",
ndof=self.numb_aparam,
atomic=True,
must=False,
high_prec=False,
)
)
return label_requirement

def serialize(self) -> dict:
@@ -492,7 +566,7 @@ def serialize(self) -> dict:
dict
The serialized loss module
"""
return {
data = {
"@class": "EnergyLoss",
"@version": 2,
"starter_learning_rate": self.starter_learning_rate,
@@ -514,6 +588,11 @@ def serialize(self) -> dict:
"use_huber": self.use_huber,
"huber_delta": self.huber_delta,
}
if self.has_ap:
data["start_pref_ap"] = self.start_pref_ap
data["limit_pref_ap"] = self.limit_pref_ap
data["numb_aparam"] = self.numb_aparam
return data

@classmethod
def deserialize(cls, data: dict) -> "TaskLoss":
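The `DataRequirementItem` added to `label_requirement` above means a data system may optionally supply reference dE/d(aparam) values; a sketch of writing such a label, assuming the usual convention that atomic quantities are stored flattened to (nframes, natoms * ndof) inside each set directory (path and values are illustrative):

```python
import numpy as np

nframes, natoms, numb_aparam = 10, 64, 2
# Reference dE/d(aparam), e.g. from finite differences or a higher-level model.
grad_aparam = np.random.rand(nframes, natoms, numb_aparam)

# Atomic labels are flattened per frame: (nframes, natoms * ndof).
np.save(
    "data_system/set.000/grad_aparam.npy",
    grad_aparam.reshape(nframes, natoms * numb_aparam),
)
```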
10 changes: 10 additions & 0 deletions deepmd/pt/train/training.py
@@ -1687,9 +1687,19 @@ def get_loss(
loss_type = loss_params.get("type", "ener")
if whether_hessian(loss_params):
loss_params["starter_learning_rate"] = start_lr
if (
loss_params.get("start_pref_ap", 0.0) != 0.0
or loss_params.get("limit_pref_ap", 0.0) != 0.0
):
loss_params["numb_aparam"] = _model.get_dim_aparam()
return EnergyHessianStdLoss(**loss_params)
elif loss_type == "ener":
loss_params["starter_learning_rate"] = start_lr
if (
loss_params.get("start_pref_ap", 0.0) != 0.0
or loss_params.get("limit_pref_ap", 0.0) != 0.0
):
loss_params["numb_aparam"] = _model.get_dim_aparam()
return EnergyStdLoss(**loss_params)
elif loss_type == "dos":
loss_params["starter_learning_rate"] = start_lr
16 changes: 16 additions & 0 deletions deepmd/utils/argcheck.py
@@ -3072,6 +3072,8 @@ def loss_ener() -> list[Argument]:
doc_limit_pref_pf = limit_pref("atomic prefactor force")
doc_start_pref_gf = start_pref("generalized force", label="drdq", abbr="gf")
doc_limit_pref_gf = limit_pref("generalized force")
doc_start_pref_ap = start_pref("aparam gradient", label="grad_aparam", abbr="ap")
doc_limit_pref_ap = limit_pref("aparam gradient")
doc_numb_generalized_coord = "The dimension of generalized coordinates. Required when generalized force loss is used."
doc_relative_f = "If provided, relative force error will be used in the loss. The difference of force will be normalized by the magnitude of the force in the label with a shift given by `relative_f`, i.e. DF_i / ( || F || + relative_f ) with DF denoting the difference between prediction and label and || F || denoting the L2 norm of the label."
doc_enable_atom_ener_coeff = "If true, the energy will be computed as \\sum_i c_i E_i. c_i should be provided by file atom_ener_coeff.npy in each data system, otherwise it's 1."
@@ -3211,6 +3213,20 @@ def loss_ener() -> list[Argument]:
default=0.01,
doc=doc_huber_delta,
),
Argument(
"start_pref_ap",
[float, int],
optional=True,
default=0.0,
doc=doc_start_pref_ap,
),
Argument(
"limit_pref_ap",
[float, int],
optional=True,
default=0.0,
doc=doc_limit_pref_ap,
),
]
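With the two arguments registered, the new term can be switched on from the loss section of a training input; a sketch written as a Python dict mirroring the JSON keys (the numeric values are placeholders, and the remaining prefactors follow their usual defaults):

```python
# Loss section of a training input; a non-zero start/limit_pref_ap enables the term.
loss_config = {
    "type": "ener",
    "start_pref_e": 0.02,
    "limit_pref_e": 1.0,
    "start_pref_f": 1000,
    "limit_pref_f": 1.0,
    "start_pref_v": 0.0,
    "limit_pref_v": 0.0,
    "start_pref_ap": 1.0,
    "limit_pref_ap": 1.0,
}
```

`numb_aparam` does not need to appear here; `get_loss` in `training.py` above fills it in from the model's `get_dim_aparam()` whenever either prefactor is non-zero.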

