
Conversation

@OutisLi
Collaborator

@OutisLi OutisLi commented Jan 15, 2026

Summary by CodeRabbit

  • New Features

    • Models now emit total virial, per-atom virial, and extended virial alongside energies and forces, including in spin-aware flows.
  • Behavioral Improvements

    • Training can include virial loss with per-atom scaling; RMSE/MAE virial metrics optionally recorded.
    • Improved propagation of coordinate corrections to ensure accurate virial computation across model stages.
  • Tests

    • Added and extended tests to validate virial behavior for spin and non-spin scenarios; test flags support optional virial outputs.


Copilot AI review requested due to automatic review settings January 15, 2026 11:13
@dosubot dosubot bot added the new feature label Jan 15, 2026
Contributor

Copilot AI left a comment


Pull request overview

This pull request adds support for spin virial calculations in the PyTorch backend. Spin models represent magnetic moments with virtual atoms displaced from the real atoms; the changes introduce coordinate corrections that account for these virtual atoms so that virial tensors are computed correctly for spin systems.

Changes:

  • Added virial correction mechanism for spin models through coordinate corrections
  • Updated C/C++ API to enable virial output for spin models
  • Extended test coverage to include spin virial validation
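The correction mechanism in the bullets above can be sketched as follows. This is an illustrative reconstruction, not the deepmd-kit implementation: the function name `virial_with_correction` and the sign convention (virial as the negative force-position outer product) are assumptions for the sketch. The key point is that, by linearity, adding a per-atom coordinate correction term is equivalent to evaluating the virial at the corrected positions.

```python
import torch

def virial_with_correction(extended_force, extended_coord, coord_corr):
    """Sketch: virial under one common sign convention,
    W = -sum_i f_i (x) r_i, plus a per-atom coordinate correction
    (zero for real atoms, the negative spin displacement for virtual
    atoms) that folds virtual-atom contributions back onto the real
    atoms they represent."""
    # baseline virial from the extended (real + virtual) system
    virial = -torch.einsum("bij,bik->bjk", extended_force, extended_coord)
    # correction term: same outer product, against the coordinate shift
    virial += -torch.einsum("bij,bik->bjk", extended_force, coord_corr)
    return virial
```

Because both terms are linear in the coordinates, this equals the virial computed directly at `extended_coord + coord_corr`, which is what makes the correction a cheap post-processing step rather than a second force evaluation.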

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated no comments.

Show a summary per file
File Description
deepmd/pt/model/model/spin_model.py Core implementation of spin virial correction through coordinate corrections returned by process_spin_input methods
deepmd/pt/model/model/transform_output.py Added extended_coord_corr parameter to apply virial corrections
deepmd/pt/model/model/make_model.py Propagated coord_corr_for_virial parameter through forward methods
deepmd/pt/loss/ener_spin.py Added virial loss computation for spin models
source/api_cc/src/DeepSpinPT.cc Uncommented virial computation code to enable spin virial output
source/api_c/src/c_api.cc Uncommented virial assignment code
source/api_c/include/deepmd.hpp Uncommented virial indexing loops
source/tests/universal/pt/model/test_model.py Enabled spin virial testing for PT backend
source/tests/universal/common/cases/model/utils.py Updated test logic to conditionally test spin virial
source/tests/pt/model/test_autodiff.py Added spin virial test class and updated test to include spin
source/tests/pt/model/test_ener_spin_model.py Updated test to handle new return value from process_spin_input
deepmd/pt/model/network/utils.py Added Optional import (unused in shown diff)
deepmd/pt/model/descriptor/repformers.py Code formatting changes (no functional change)
deepmd/pt/model/descriptor/repflow_layer.py Code formatting changes (no functional change)
source/tests/pt/model/test_nosel.py New test file for DPA3 descriptor with dynamic selection
Comments suppressed due to low confidence (2)

deepmd/pt/model/model/spin_model.py:56

  • The return type annotation is incorrect. The method now returns three values (coord_spin, atype_spin, coord_corr) on line 69, but the annotation only specifies two return values. Update to tuple[torch.Tensor, torch.Tensor, torch.Tensor].
    def process_spin_input(
        self, coord: torch.Tensor, atype: torch.Tensor, spin: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:

deepmd/pt/model/model/spin_model.py:78

  • The return type annotation is incorrect. The method now returns five values on lines 111-117 (extended_coord_updated, extended_atype_updated, nlist_updated, mapping_updated, extended_coord_corr), but the annotation only specifies four return values. Update to tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor].
    def process_spin_input_lower(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        extended_spin: torch.Tensor,
        nlist: torch.Tensor,
        mapping: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:


@coderabbitai
Contributor

coderabbitai bot commented Jan 15, 2026

📝 Walkthrough

Walkthrough

Propagates coordinate-correction tensors for virial through spin and backbone models, applies virial correction in output fitting, adds conditional virial loss computation, enables virial I/O in C/C++ APIs, and expands tests to validate virial and related spin flows.

Changes

Cohort / File(s) Summary
PyTorch Loss Layer
deepmd/pt/loss/ener_spin.py
Adds conditional virial-loss handling: reads label virial, computes L2 (and optional MAE) virial loss, scales by pref_v and atom_norm, records diagnostics, and adds to total loss.
Model Backbone (coord-corr propagation)
deepmd/pt/model/model/make_model.py
Adds coord_corr_for_virial param to forward_common, computes/gathers extended_coord_corr, threads it into forward_common_lower and downstream fitting.
Spin Model (input & outputs)
deepmd/pt/model/model/spin_model.py
process_spin_input / process_spin_input_lower now return coord-correction for virial; forward paths pass these through; outputs may include virial, atom_virial, and extended_virial when gradients/flags enabled.
Output Transform
deepmd/pt/model/model/transform_output.py
fit_output_to_model_output accepts optional extended_coord_corr and, when present and differentiable, applies it to dc (virial correction).
C API (virial output enabled)
source/api_c/include/deepmd.hpp, source/api_c/src/c_api.cc
Activates previously-commented virial-population code paths to flatten and copy virial into C output buffers.
C++ API (virial output enabled)
source/api_cc/src/DeepSpinPT.cc
Enables conditional reading/assignment of virial, extended_virial, and atom_virial from model outputs; clears containers when absent.
PyTorch Tests - Autodiff & Spin Virial
source/tests/pt/model/test_autodiff.py
Adds finite-difference cell-energy helper, passes spin into inference, expects force_mag when spinning, and adds spin-virial tests (including shear variant).
PyTorch Tests - Spin Model unpacking
source/tests/pt/model/test_ener_spin_model.py
Updates tests to unpack and ignore extra return value from process_spin_input / process_spin_input_lower.
Universal Test Utilities & PT Tests
source/tests/universal/common/cases/model/utils.py, source/tests/universal/pt/model/test_model.py
Adds test_spin_virial flag and opts PT tests into spin-virial evaluation; includes spin in virial inputs when present.
C/C++ Tests & LAMMPS Tests
source/api_cc/tests/*, source/api_cc/tests/CMakeLists.txt, source/api_cc/src/*, source/lmp/tests/*, source/api_c/*, source/tests/tf/*
Enables and validates virial outputs in many tests: adds expected virial arrays, conditional assertions, test env var, and small TF/C++ compute-path reorderings to populate virial.
Descriptor / Misc Tests
deepmd/tf/descriptor/se_a.py, source/tests/pt/model/test_nosel.py, other test updates
Adds spin-aware coordinate helpers in descriptor, new descriptor consistency test, and assorted test data/format updates to incorporate virial/velocity expectations.
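The autodiff tests mentioned above (test_autodiff.py) validate the analytic virial against a finite-difference derivative of the energy with respect to cell strain. A minimal sketch of that check, with a hypothetical `energy_fn(coord, cell)` standing in for model inference and an assumed sign convention W = -dE/d(strain):

```python
import torch

def finite_difference_virial(energy_fn, coord, cell, eps=1e-4):
    """Numerically differentiate the energy with respect to a small
    strain applied to both the cell and the coordinates, one strain
    component at a time (central differences)."""
    virial = torch.zeros(3, 3, dtype=coord.dtype)
    for a in range(3):
        for b in range(3):
            strain = torch.zeros(3, 3, dtype=coord.dtype)
            strain[a, b] = eps

            def deform(sign):
                # deform coordinates and cell by (I + sign * strain)
                d = torch.eye(3, dtype=coord.dtype) + sign * strain
                return energy_fn(coord @ d, cell @ d)

            # central difference; sign convention: W = -dE/d(strain)
            virial[a, b] = -(deform(+1) - deform(-1)) / (2 * eps)
    return virial
```

For a toy quadratic energy E = sum(coord**2) this reproduces the analytic result -2 * coord.T @ coord, which is the kind of agreement the real tests assert between the model's virial output and the strained-energy derivative.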

Sequence Diagram(s)

sequenceDiagram
    participant User as Caller
    participant Spin as SpinModel
    participant Backbone as BackboneModel
    participant Transform as TransformOutput
    participant Loss as EnergySpinLoss

    User->>Spin: forward(coord, atype, spin)
    Spin-->>User: coord_updated, atype_updated, coord_corr_for_virial
    User->>Backbone: forward_common(coord_updated, atype_updated, coord_corr_for_virial)
    Backbone-->>Transform: energy, force, energy_derv_c, extended_coord_corr
    Transform->>Transform: apply extended_coord_corr to dc (virial correction)
    Transform-->>Loss: energy, force, virial_pred, atom_virial
    Loss->>Loss: compute l2_virial_loss (+ optional MAE), scale and add to total loss
    Loss-->>User: total loss (includes virial component)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • iProzd
  • wanghan-iapcm
  • njzjz
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 21.33%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main feature being added: spin virial support in the PyTorch (pt) component, with proper context noting it's a rebased version.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
source/api_c/include/deepmd.hpp (1)

2695-2711: Consider explicitly documenting/handling spin atom_virial (still disabled).

This overload still resizes atom_virial[i] but does not populate it (block remains commented). If this is intentionally unsupported for spin models, consider documenting it in the API comment (or explicitly zeroing/clearing) to avoid consumers assuming it’s meaningful.

deepmd/pt/model/model/spin_model.py (1)

71-78: Update return type annotation.

The return type annotation still declares a 4-tuple, but the function now returns 5 values (including extended_coord_corr).

Proposed fix
     def process_spin_input_lower(
         self,
         extended_coord: torch.Tensor,
         extended_atype: torch.Tensor,
         extended_spin: torch.Tensor,
         nlist: torch.Tensor,
         mapping: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
🤖 Fix all issues with AI agents
In `@deepmd/pt/model/descriptor/repformers.py`:
- Around line 435-437: The assertion uses an undefined attribute self.n_dim in
DescrptBlockRepformers; change the check to use the correct embedding size
attribute self.g1_dim: after slicing extended_atype_embd into atype_embd
(variable names extended_atype_embd and atype_embd), update the assertion that
compares atype_embd.shape to expect [nframes, nloc, self.g1_dim] instead of
self.n_dim so it matches the class's defined embedding dimension.
♻️ Duplicate comments (1)
source/api_cc/src/DeepSpinPT.cc (1)

414-436: Same "virial" key-availability risk in standalone module.forward(...) path.

Same comment as the forward_lower path: if "virial" isn’t guaranteed, guard and error clearly.

🧹 Nitpick comments (7)
source/tests/pt/model/test_nosel.py (4)

31-31: Module-level dtype is unused and shadowed.

This variable is reassigned inside the test loop at line 66, making this module-level assignment dead code. Consider removing it to avoid confusion.

Suggested fix
-dtype = env.GLOBAL_PT_FLOAT_PRECISION
-
-
 class TestDescrptDPA3Nosel(unittest.TestCase, TestCaseSingleFrameWithNlist):

42-42: Prefix unused unpacked variables with underscore.

nf and nloc are unpacked but never used. Per Python convention, prefix them with _ to indicate they're intentionally unused.

Suggested fix
-        nf, nloc, nnei = self.nlist.shape
+        _nf, _nloc, nnei = self.nlist.shape

105-117: Mutating shared repflow object is fragile.

After passing repflow to dd0, mutating repflow.use_dynamic_sel = True before creating dd1 relies on DescrptDPA3 copying values during __init__ rather than holding a reference. This pattern is fragile and could break if the descriptor's implementation changes.

Consider creating separate RepFlowArgs instances for clarity and robustness:

Suggested approach
             # dpa3 new impl
             dd0 = DescrptDPA3(
                 self.nt,
                 repflow=repflow,
                 # kwargs for descriptor
                 exclude_types=[],
                 precision=prec,
                 use_econf_tebd=ect,
                 type_map=["O", "H"] if ect else None,
                 seed=GLOBAL_SEED,
             ).to(env.DEVICE)

-            repflow.use_dynamic_sel = True
+            repflow_dynamic = RepFlowArgs(
+                n_dim=20,
+                e_dim=10,
+                a_dim=10,
+                nlayers=3,
+                e_rcut=self.rcut,
+                e_rcut_smth=self.rcut_smth,
+                e_sel=nnei,
+                a_rcut=self.rcut - 0.1,
+                a_rcut_smth=self.rcut_smth,
+                a_sel=nnei,
+                a_compress_rate=acr,
+                n_multi_edge_message=nme,
+                axis_neuron=4,
+                update_angle=ua,
+                update_style=rus,
+                update_residual_init=ruri,
+                optim_update=optim,
+                smooth_edge_update=True,
+                sel_reduce_factor=1.0,
+                use_dynamic_sel=True,
+            )

             # dpa3 new impl
             dd1 = DescrptDPA3(
                 self.nt,
-                repflow=repflow,
+                repflow=repflow_dynamic,
                 # kwargs for descriptor

Alternatively, use copy.deepcopy(repflow) and then set the attribute on the copy.


143-205: Consider removing or documenting the commented-out test.

This large block of commented-out code for test_jit lacks an explanation for why it's disabled. If it's a work-in-progress, add a TODO comment explaining the intent and when it should be enabled. If it's no longer needed, consider removing it to reduce code clutter.

deepmd/pt/model/network/utils.py (1)

2-4: Unused Optional import.

The Optional import is added but not used in this file. The existing code uses PEP 604 union syntax (int | None on line 14) rather than Optional[int]. Consider removing this unused import unless it's needed for future changes in this PR.

🧹 Suggested fix
 # SPDX-License-Identifier: LGPL-3.0-or-later
-from typing import (
-    Optional,
-)
-
 import torch
source/api_cc/src/DeepSpinPT.cc (1)

252-274: Guard missing "virial" key (avoid runtime outputs.at() throw).

If a TorchScript spin model doesn’t emit "virial" (older checkpoints / config-dependent outputs), outputs.at("virial") will throw. Consider checking presence and raising a clearer error.

Proposed hardening (illustrative — please verify c10::Dict API)
-  c10::IValue virial_ = outputs.at("virial");
+  // NOTE: verify the correct c10::Dict presence-check API for your libtorch version.
+  // The goal is to fail with a clearer message than an unhandled `at()` exception.
+  if (!outputs.contains("virial")) {
+    throw deepmd::deepmd_exception(
+        "Spin model output dict is missing key 'virial' (model may not support virial).");
+  }
+  c10::IValue virial_ = outputs.at("virial");
deepmd/pt/loss/ener_spin.py (1)

282-297: Make virial loss shape/dtype handling more robust.

Right now it assumes label["virial"] is already (-1, 9) and that dtypes won’t conflict with the in-place loss += .... Reshaping label virial and casting the increment avoids brittle failures.

Proposed patch
         if self.has_v and "virial" in model_pred and "virial" in label:
             find_virial = label.get("find_virial", 0.0)
             pref_v = pref_v * find_virial
-            diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
+            virial_label = label["virial"].reshape(-1, 9)
+            virial_pred = model_pred["virial"].reshape(-1, 9)
+            diff_v = virial_label - virial_pred
             l2_virial_loss = torch.mean(torch.square(diff_v))
             if not self.inference:
                 more_loss["l2_virial_loss"] = self.display_if_exist(
                     l2_virial_loss.detach(), find_virial
                 )
-            loss += atom_norm * (pref_v * l2_virial_loss)
+            loss += (atom_norm * (pref_v * l2_virial_loss)).to(GLOBAL_PT_FLOAT_PRECISION)
             rmse_v = l2_virial_loss.sqrt() * atom_norm
             more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
             if mae:
                 mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
                 more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9553e6e and ac0ebb9.

📒 Files selected for processing (15)
  • deepmd/pt/loss/ener_spin.py
  • deepmd/pt/model/descriptor/repflow_layer.py
  • deepmd/pt/model/descriptor/repformers.py
  • deepmd/pt/model/model/make_model.py
  • deepmd/pt/model/model/spin_model.py
  • deepmd/pt/model/model/transform_output.py
  • deepmd/pt/model/network/utils.py
  • source/api_c/include/deepmd.hpp
  • source/api_c/src/c_api.cc
  • source/api_cc/src/DeepSpinPT.cc
  • source/tests/pt/model/test_autodiff.py
  • source/tests/pt/model/test_ener_spin_model.py
  • source/tests/pt/model/test_nosel.py
  • source/tests/universal/common/cases/model/utils.py
  • source/tests/universal/pt/model/test_model.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: 1azyking
Repo: deepmodeling/deepmd-kit PR: 4169
File: deepmd/pt/loss/ener_hess.py:341-348
Timestamp: 2024-10-08T15:32:11.479Z
Learning: In `deepmd/pt/loss/ener_hess.py`, the `label` uses the key `"atom_ener"` intentionally to maintain consistency with the forked version.

Applied to files:

  • deepmd/pt/loss/ener_spin.py
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4160
File: deepmd/dpmodel/utils/env_mat.py:52-64
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Negative indices in `nlist` are properly handled by masking later in the computation, so they do not cause issues in indexing operations.

Applied to files:

  • deepmd/pt/model/descriptor/repflow_layer.py
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR `#3905`.

Applied to files:

  • source/tests/pt/model/test_nosel.py
  • source/tests/universal/pt/model/test_model.py
📚 Learning: 2025-12-12T13:40:14.334Z
Learnt from: CR
Repo: deepmodeling/deepmd-kit PR: 0
File: AGENTS.md:0-0
Timestamp: 2025-12-12T13:40:14.334Z
Learning: Run core tests with `pytest source/tests/tf/test_dp_test.py::TestDPTestEner::test_1frame -v` to validate basic functionality

Applied to files:

  • source/tests/pt/model/test_nosel.py
🧬 Code graph analysis (6)
source/api_c/src/c_api.cc (1)
data/raw/copy_raw.py (1)
  • copy (11-71)
deepmd/pt/loss/ener_spin.py (3)
deepmd/driver.py (1)
  • label (42-75)
deepmd/pt/loss/loss.py (1)
  • display_if_exist (44-54)
deepmd/entrypoints/test.py (1)
  • mae (247-260)
deepmd/pt/model/descriptor/repflow_layer.py (2)
deepmd/pt/model/network/utils.py (1)
  • aggregate (10-50)
deepmd/pt/model/descriptor/repformer_layer.py (2)
  • list_update_res_residual (1292-1309)
  • list_update (1312-1322)
source/tests/pt/model/test_nosel.py (1)
deepmd/dpmodel/descriptor/dpa3.py (1)
  • RepFlowArgs (59-250)
deepmd/pt/model/model/spin_model.py (2)
deepmd/dpmodel/model/spin_model.py (2)
  • concat_switch_virtual (190-208)
  • process_spin_input (43-53)
deepmd/pt/utils/spin.py (1)
  • concat_switch_virtual (6-30)
source/tests/pt/model/test_ener_spin_model.py (1)
deepmd/pt/model/model/spin_model.py (1)
  • process_spin_input (54-69)
🪛 Ruff (0.14.11)
source/tests/pt/model/test_nosel.py

42-42: Unpacked variable nf is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


42-42: Unpacked variable nloc is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (42)
  • GitHub Check: Agent
  • GitHub Check: CodeQL analysis (python)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, true, true, false)
🔇 Additional comments (23)
deepmd/pt/model/model/transform_output.py (1)

152-206: LGTM! Virial correction logic is well-integrated.

The extended_coord_corr parameter is properly gated with a None check and only applied when vdef.c_differentiable is True. The tensor operations compute the virial correction by combining the force derivative with the coordinate correction.
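The gating and outer-product operation described in this comment can be sketched as below. Function and variable names are hypothetical (not the actual `fit_output_to_model_output` internals), and the sign of the correction term is illustrative; the point is the `None` check plus the per-atom force-times-correction outer product added to the atomic virial.

```python
import torch

def apply_virial_correction(atom_virial, extended_force, extended_coord_corr):
    """Sketch: when a coordinate correction is present, add the
    per-atom outer product f_i (x) dr_i to the per-atom virial.
    Shapes: atom_virial [nf, nall, 3, 3], the other two [nf, nall, 3]."""
    if extended_coord_corr is None:
        # no spin correction requested: pass the virial through unchanged
        return atom_virial
    corr = torch.einsum("fni,fnj->fnij", extended_force, extended_coord_corr)
    return atom_virial + corr
```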

deepmd/pt/model/descriptor/repflow_layer.py (1)

1-1243: LGTM - Formatting-only changes.

All modifications in this file are style/formatting changes that collapse multi-line expressions into single lines. No semantic or functional changes are introduced. The logic remains identical to the previous implementation.

source/tests/universal/pt/model/test_model.py (1)

747-748: LGTM - Appropriate test flag for spin virial support.

The test_spin_virial = True flag enables virial testing specifically for the PyTorch spin energy model. The accompanying comment clearly documents that this flag can be removed once other backends implement spin virial support, which is good for future maintainability.

source/api_c/include/deepmd.hpp (1)

2594-2607: Virial reshape now consistent with other outputs (good).

Copying virial_flat into per-model virial[i] aligns this path with energy/force/force_mag and avoids surprising “all zeros” virials in the model-deviation API.

source/api_c/src/c_api.cc (1)

865-869: Virial now actually returned from DP_DeepSpinModelDeviCompute_variant (good).

This brings the virial behavior in line with force/force_mag flattening and fixes the previously “silently empty” virial output when requested.

source/tests/pt/model/test_ener_spin_model.py (2)

118-120: LGTM!

The test correctly adapts to the updated process_spin_input signature that now returns a third element (coord_corr). Using _ to discard the unused value is appropriate here since this test focuses on coordinate and type transformations rather than virial corrections.


172-180: LGTM!

The test correctly handles the extended return signature of process_spin_input_lower, which now returns five values including extended_coord_corr. The underscore appropriately discards the virial correction tensor that isn't relevant to this particular test case.

source/tests/universal/common/cases/model/utils.py (3)

918-921: LGTM!

The test_spin_virial flag provides a clean mechanism to incrementally enable virial testing for spin models. The condition if not test_spin or test_spin_virial correctly maintains backward compatibility while allowing virial tests to run for spin configurations when the flag is explicitly set.


931-933: LGTM!

Correctly includes spin in the virial calculation input when test_spin is enabled. This ensures the virial finite difference test properly exercises the spin-virial code path.


952-954: LGTM!

Consistently passes spin to the model when test_spin is enabled for the actual virial computation, matching the finite difference setup above.

source/tests/pt/model/test_autodiff.py (3)

144-154: LGTM!

The VirialTest class is properly extended to support spin models. The spin tensor generation and test_keys logic mirror the pattern already established in ForceTest, ensuring consistent behavior across both test types.


166-167: LGTM!

The spin tensor is correctly passed to eval_model in the virial inference function, enabling proper virial computation for spin-enabled models.


261-268: LGTM!

The new TestEnergyModelSpinSeAVirial class properly mirrors TestEnergyModelSpinSeAForce, completing the autodiff test coverage for spin virial functionality. The test_spin = True flag correctly enables spin-aware virial testing.

deepmd/pt/model/model/make_model.py (3)

141-142: LGTM!

The new coord_corr_for_virial parameter is properly added as an optional argument with None default, maintaining backward compatibility with existing callers.


190-196: LGTM!

The coordinate correction is correctly extended to the neighbor list region using torch.gather with the mapping tensor. The dtype conversion to cc.dtype ensures consistency with the coordinate precision. This follows the same extension pattern used elsewhere in the codebase.
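The gather-based extension described here follows a standard pattern: replicate a per-local-atom tensor out to the extended (ghost-atom) region using the local-index mapping. A self-contained sketch with made-up shapes (the real code lives in make_model.py and also handles dtype conversion):

```python
import torch

# coord_corr: per-local-atom correction, [nframes, nloc, 3]
coord_corr = torch.randn(1, 4, 3)
# mapping[f, j] = local-atom index that extended atom j corresponds to
mapping = torch.tensor([[0, 1, 2, 3, 0, 1]])  # [nframes, nall]

# expand the mapping to index all 3 Cartesian components, then gather
index = mapping.unsqueeze(-1).expand(-1, -1, 3)        # [nframes, nall, 3]
extended_coord_corr = torch.gather(coord_corr, 1, index)  # [nframes, nall, 3]
```

Ghost atoms 4 and 5 end up carrying copies of the corrections of local atoms 0 and 1, so the downstream virial accumulation sees a consistent correction for every atom in the neighbor-list region.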


263-264: LGTM!

The extended_coord_corr parameter is properly propagated through the lower interface with appropriate documentation. The parameter is correctly forwarded to fit_output_to_model_output where it will be used for virial correction.

Also applies to: 290-291, 324-324

deepmd/pt/model/model/spin_model.py (7)

62-69: LGTM!

The virial correction computation is correctly implemented. The coord_corr tensor with [zeros_like(coord), -spin_dist] properly represents that real atoms have zero correction while virtual atoms need -spin_dist correction to account for their artificial displacement in virial calculations.
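The `[zeros_like(coord), -spin_dist]` layout described in this comment can be illustrated with a minimal sketch. Note the real implementation uses `concat_switch_virtual` to interleave real and virtual blocks per frame; the plain `torch.cat` here is a simplification to show the values, not the ordering:

```python
import torch

nf, nloc = 2, 3
coord = torch.randn(nf, nloc, 3)
spin_dist = torch.randn(nf, nloc, 3)  # displacement of each virtual atom

# real atoms: zero correction; virtual atoms: -spin_dist, which folds
# their virial contribution back onto the real-atom positions
coord_corr = torch.cat([torch.zeros_like(coord), -spin_dist], dim=1)
```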


89-100: LGTM!

The extended coordinate correction is correctly computed using concat_switch_virtual with torch.zeros_like(extended_coord) for real atoms and -extended_spin_dist for virtual atoms, maintaining consistency with process_spin_input.


419-432: LGTM!

The forward_common method correctly unpacks the new 3-tuple return from process_spin_input and passes coord_corr_for_virial to the backbone model's forward_common method.


473-495: LGTM!

The forward_common_lower method correctly unpacks the new 5-tuple return from process_spin_input_lower and passes extended_coord_corr_for_virial to the backbone model's forward_common_lower.


567-572: LGTM!

The translated_output_def properly includes virial and atom_virial output definitions when do_grad_c("energy") is true, enabling virial outputs for spin models.


600-604: LGTM!

The forward method correctly outputs virial and conditionally atom_virial based on do_grad_c and do_atomic_virial flags, consistent with non-spin energy models.


640-645: LGTM!

The forward_lower method correctly outputs virial and conditionally extended_virial (matching the naming convention for extended outputs in lower interfaces).


@codecov

codecov bot commented Jan 16, 2026

Codecov Report

❌ Patch coverage is 60.08772% with 91 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.88%. Comparing base (9553e6e) to head (714f587).
⚠️ Report is 5 commits behind head on master.

Files with missing lines Patch % Lines
source/api_cc/tests/test_deeppot_dpa_pt_spin.cc 29.41% 27 Missing and 9 partials ⚠️
source/api_cc/src/DeepSpinPT.cc 15.38% 29 Missing and 4 partials ⚠️
source/api_cc/tests/test_deeppot_tf_spin.cc 63.26% 18 Missing ⚠️
deepmd/pt/loss/ener_spin.py 84.61% 2 Missing ⚠️
source/api_c/src/c_api.cc 75.00% 0 Missing and 1 partial ⚠️
source/api_cc/src/DeepSpinTF.cc 66.66% 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5156      +/-   ##
==========================================
- Coverage   81.94%   81.88%   -0.06%     
==========================================
  Files         713      714       +1     
  Lines       73009    73631     +622     
  Branches     3617     3655      +38     
==========================================
+ Hits        59826    60294     +468     
- Misses      12021    12159     +138     
- Partials     1162     1178      +16     

☔ View full report in Codecov by Sentry.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@source/api_cc/tests/test_deeppot_dpa_pt_spin.cc`:
- Around line 193-195: The test is indexing atom_vir (and expected_atom_v)
without verifying their lengths; add size assertions before the loop that
compares elements: check that atom_vir.size() == natoms*9 and
expected_atom_v.size() == natoms*9 (or the expected length expression used in
this test) so the subsequent for-loop comparing fabs(atom_vir[ii] -
expected_atom_v[ii]) is safe; apply the same pre-checks wherever the test later
iterates over atom_vir (the second comparison block that mirrors this loop).
🧹 Nitpick comments (1)
source/api_cc/tests/test_deeppot_dpa_pt_spin.cc (1)

46-49: Include per‑atom virial (av) in the regen snippet.

Keeps the comment aligned with the new av expectations so values can be regenerated consistently.

📝 Suggested tweak
-// print(f"{e.ravel()=} {f.ravel()=} {v.ravel()=} {fm.ravel()=}
-// {ae.ravel()=}")
+// print(f"{e.ravel()=} {f.ravel()=} {v.ravel()=} {av.ravel()=} {fm.ravel()=}
+// {ae.ravel()=}")
@@
-// print(f"{e.ravel()=} {f.ravel()=} {v.ravel()=} {fm.ravel()=}
-// {ae.ravel()=}")
+// print(f"{e.ravel()=} {f.ravel()=} {v.ravel()=} {av.ravel()=} {fm.ravel()=}
+// {ae.ravel()=}")

Also applies to: 223-227

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f261d19 and 71cc203.

📒 Files selected for processing (1)
  • source/api_cc/tests/test_deeppot_dpa_pt_spin.cc
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR `#3905`.

Applied to files:

  • source/api_cc/tests/test_deeppot_dpa_pt_spin.cc
🧬 Code graph analysis (1)
source/api_cc/tests/test_deeppot_dpa_pt_spin.cc (2)
source/api_cc/tests/test_deeppot_a.cc (2)
  • ener (150-156)
  • ener (150-154)
source/api_cc/tests/test_deeppot_pt.cc (2)
  • ener (132-138)
  • ener (132-136)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)

✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@source/api_cc/tests/test_deeppot_dpa_pt_spin.cc`:
- Around line 54-60: The test's expected_f vector is missing one force component
(contains 17 values instead of natoms*3 == 18) causing the SetUp assertion to
fail; regenerate the full 18-value expected_f using the Python snippet in the
test comments (or the same calculation used to produce the other values),
replace the existing std::vector<VALUETYPE> expected_f in
test_deeppot_dpa_pt_spin.cc with the complete 18-element list, and ensure the
final element(s) correspond to the same ordering and precision used elsewhere in
the test so EXPECT_EQ(natoms * 3, expected_f.size()) passes.
🧹 Nitpick comments (1)
source/api_cc/tests/test_deeppot_dpa_pt_spin.cc (1)

115-119: Add expected_atom_v size assertion for consistency with TestInferDeepSpinDpaPtNopbc.

The TestInferDeepSpinDpaPtNopbc::SetUp validates both expected_tot_v.size() and expected_atom_v.size() (lines 297-298), but this class only validates expected_tot_v.size(). Add the missing check for consistency and early error detection.

🔧 Suggested fix
     EXPECT_EQ(9, expected_tot_v.size());
+    EXPECT_EQ(natoms * 9, expected_atom_v.size());
     expected_tot_e = 0.;
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71cc203 and 8d9f6e0.

📒 Files selected for processing (3)
  • source/api_cc/src/DeepSpinPT.cc
  • source/api_cc/tests/test_deeppot_dpa_pt_spin.cc
  • source/tests/pt/model/test_autodiff.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-09-19T04:25:12.408Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR `#3905`.

Applied to files:

  • source/api_cc/tests/test_deeppot_dpa_pt_spin.cc
🧬 Code graph analysis (3)
source/api_cc/src/DeepSpinPT.cc (2)
source/api_c/include/deepmd.hpp (2)
  • select_map (3508-3522)
  • select_map (3508-3511)
source/api_cc/src/common.cc (12)
  • select_map (935-961)
  • select_map (935-941)
  • select_map (964-982)
  • select_map (964-970)
  • select_map (1038-1044)
  • select_map (1046-1053)
  • select_map (1077-1083)
  • select_map (1085-1092)
  • select_map (1116-1122)
  • select_map (1124-1131)
  • select_map (1154-1161)
  • select_map (1163-1170)
source/api_cc/tests/test_deeppot_dpa_pt_spin.cc (2)
source/api_cc/tests/test_deeppot_a.cc (2)
  • ener (150-156)
  • ener (150-154)
source/api_cc/tests/test_deeppot_pt.cc (2)
  • ener (132-138)
  • ener (132-136)
source/tests/pt/model/test_autodiff.py (2)
source/tests/universal/common/cases/model/utils.py (1)
  • stretch_box (838-843)
source/tests/pt/common.py (1)
  • eval_model (48-307)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (14)
source/api_cc/src/DeepSpinPT.cc (4)

254-278: LGTM! Conditional virial handling follows correct patterns.

The check-then-access pattern using outputs.contains() before outputs.at() is safe, and clearing the vector when virial is absent prevents stale data issues.


300-313: LGTM! Extended virial handling with proper atom mapping.

Using "extended_virial" key is correct for the extended-atom compute path, and select_map with stride 9 properly handles the 3×3 per-atom virial tensor mapping back to original indices.
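The stride-9 mapping can be pictured with a small NumPy sketch (an illustration of the idea, not the actual C++ select_map implementation):

```python
import numpy as np

def select_map_py(values, fwd_map, stride):
    """Map per-atom data from extended-atom order back to original order.

    values: flat array of length n_extended * stride
    fwd_map: fwd_map[i] is the original index of extended atom i, or -1 for ghosts
    """
    n_real = max(fwd_map) + 1
    out = np.zeros(n_real * stride, dtype=np.asarray(values).dtype)
    for i, m in enumerate(fwd_map):
        if m >= 0:
            # copy the stride-sized slot (9 components of a 3x3 virial)
            out[m * stride : (m + 1) * stride] = values[i * stride : (i + 1) * stride]
    return out
```

Ghost entries (fwd_map == -1) are simply dropped, which is why the non-extended compute path can skip the mapping entirely.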


424-447: LGTM! Virial handling in standalone compute path.

The conditional virial extraction follows the same safe pattern as the first compute method.


456-466: LGTM! Correct use of "atom_virial" key for non-extended compute path.

Using "atom_virial" (instead of "extended_virial") is correct here since this compute method doesn't involve ghost/extended atoms and doesn't require index remapping via select_map.

source/api_cc/tests/test_deeppot_dpa_pt_spin.cc (6)

151-157: Conditional virial handling pattern looks good.

The early return when virial.empty() is a reasonable approach for incremental feature enablement. The size assertion before the comparison loop protects against out-of-bounds access.


188-204: Virial and atomic virial validation looks correct.

The size assertions at lines 191 and 201 properly guard against out-of-bounds access before the comparison loops. The early return pattern allows graceful handling when virial data is not yet available.


238-281: Expected value arrays are correctly sized.

All expected arrays (expected_e, expected_f, expected_fm, expected_tot_v, expected_atom_v) have the correct element counts for 6 atoms.


297-298: SetUp validation is complete.

Both expected_tot_v and expected_atom_v size checks are present, which is the correct pattern.


423-429: Virial handling in LMP nlist test is consistent.

The same conditional pattern is applied correctly here.


468-484: Atomic virial handling is consistent with other tests.

The size assertions at lines 471 and 481 properly guard the comparison loops.

source/tests/pt/model/test_autodiff.py (4)

60-90: Solid finite-difference helper for cell derivatives.
Clear central-difference implementation with good shape documentation.
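For readers without the diff at hand, the central-difference scheme for cell derivatives can be sketched roughly as follows (a generic illustration with a made-up energy_fn, not the helper's actual code):

```python
import numpy as np

def num_cell_deriv(energy_fn, cell, eps=1e-4):
    """Central-difference dE/dh for a 3x3 cell matrix h."""
    grad = np.zeros((3, 3))
    for i in range(3):
        for j in range(3):
            hp, hm = cell.copy(), cell.copy()
            hp[i, j] += eps
            hm[i, j] -= eps
            # second-order accurate central difference
            grad[i, j] = (energy_fn(hp) - energy_fn(hm)) / (2.0 * eps)
    return grad
```

Contracting such a gradient with the cell recovers the virial, which is what the finite-difference check compares against the model output.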


296-302: Spin virial coverage added cleanly.
The setup mirrors existing spin tests and is easy to follow.


304-347: Shear-based finite-difference virial test looks solid.
Good coverage for spin-enabled virial via cell perturbations.


177-200: No action needed—eval_model already properly guards spin inputs by model capability.

The eval_model function determines whether to pass spins to the model based on the model's has_spin attribute (line 133-135), not on whether the spins parameter is provided. Even if spins is passed for a non-spin model, the spin kwarg will not be added to the model call due to the guard on line 175-176 (if has_spin: input_dict["spin"] = batch_spin). The suggested conditional guard is redundant.

Likely an incorrect or invalid review comment.
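The guard pattern described above can be sketched in a few lines (hypothetical build_input helper, not the actual eval_model code):

```python
def build_input(coord, atype, spins=None, has_spin=False):
    # The spin entry is added only when the model itself supports spin,
    # so passing spins to a non-spin model is harmless.
    input_dict = {"coord": coord, "atype": atype}
    if has_spin:
        input_dict["spin"] = spins
    return input_dict
```

Because the kwarg is keyed off the model capability rather than the caller's arguments, an extra conditional at the call site would be redundant.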


Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@source/api_cc/tests/test_deeppot_dpa_pt_spin.cc`:
- Around line 112-116: Add a size guard for expected_atom_v before
cpu_build_nlist_atomic is called: check that expected_atom_v.size() equals
natoms * 3 (same pattern used for expected_f/expected_fm) to prevent
out-of-bounds indexing later in the test; insert an EXPECT_EQ(natoms * 3,
expected_atom_v.size()) immediately after natoms is set (near where
expected_f/expected_fm checks occur) so the nopbc-like assertion ensures safe
indexing in cpu_build_nlist_atomic.

@njzjz njzjz linked an issue Jan 20, 2026 that may be closed by this pull request
Copy link
Member

@njzjz njzjz left a comment


The LAMMPS tests can be restored if they pass:

def test_pair_deepmd_virial(lammps) -> None:
    lammps.pair_style(f"deepspin {pb_file.resolve()}")
    lammps.pair_coeff("* *")
    lammps.compute("peatom all pe/atom pair")
    lammps.compute("pressure all pressure NULL pair")
    lammps.compute("virial all centroid/stress/atom NULL pair")
    lammps.variable("eatom atom c_peatom")
    # for ii in range(9):
    #     jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii]
    #     lammps.variable(f"pressure{jj} equal c_pressure[{ii+1}]")
    # for ii in range(9):
    #     jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii]
    #     lammps.variable(f"virial{jj} atom c_virial[{ii+1}]")
    # lammps.dump(
    #     "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)])
    # )
    lammps.run(0)
    assert lammps.eval("pe") == pytest.approx(expected_e)
    for ii in range(4):
        assert lammps.atoms[ii].force == pytest.approx(
            expected_f[lammps.atoms[ii].id - 1]
        )
    idx_map = lammps.lmp.numpy.extract_atom("id")[: coord.shape[0]] - 1
    assert np.array(lammps.variables["eatom"].value) == pytest.approx(
        expected_ae[idx_map]
    )
    # vol = box[1] * box[3] * box[5]
    # for ii in range(6):
    #     jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii]
    #     assert np.array(
    #         lammps.variables[f"pressure{jj}"].value
    #     ) / constants.nktv2p == pytest.approx(
    #         -expected_v[idx_map, jj].sum(axis=0) / vol
    #     )
    # for ii in range(9):
    #     assert np.array(
    #         lammps.variables[f"virial{ii}"].value
    #     ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii])

@iProzd
Copy link
Collaborator

iProzd commented Jan 21, 2026

We might mark this PR as a draft until the modifications to the unit tests and the TF backend are finished.

@iProzd iProzd marked this pull request as draft January 21, 2026 11:37
@OutisLi OutisLi marked this pull request as ready for review January 23, 2026 09:53
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
source/lmp/tests/test_lammps_spin.py (1)

283-311: Restrict pressure variable loop to 6 components; compute pressure only outputs 6 components, not 9.

The compute pressure command outputs a global vector of 6 components ([1]=Pxx, [2]=Pyy, [3]=Pzz, [4]=Pxy, [5]=Pxz, [6]=Pyz), so accessing indices 7–9 via c_pressure[7], c_pressure[8], c_pressure[9] is invalid and will cause LAMMPS to error.

The virial loop (lines 286–288) correctly uses range(9) since compute centroid/stress/atom outputs 9 components. The pressure loop (lines 283–285) must be changed to range(6):

Fix for pressure loop
-    for ii in range(9):
-        jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii]
-        lammps.variable(f"pressure{jj} equal c_pressure[{ii + 1}]")
+    for ii in range(6):
+        jj = [0, 4, 8, 3, 6, 7][ii]
+        lammps.variable(f"pressure{jj} equal c_pressure[{ii + 1}]")
🤖 Fix all issues with AI agents
In `@deepmd/tf/descriptor/se_a.py`:
- Around line 1399-1429: The current natoms_not_match implementation misuses
tf.unique_with_counts on ghost_atype which returns counts only for present types
and leads to misaligned or out-of-range indexing when some types have zero
ghosts; replace the unique_with_counts approach with a fixed-length per-type
count (length self.ntypes) — e.g. compute ghost_natoms as
tf.math.bincount(ghost_atype, minlength=self.ntypes) or by summing boolean
equality masks per type — then build ghost_natoms_index = tf.concat([[0],
tf.cumsum(ghost_natoms)], axis=0) and proceed to slice coord using ghost_natoms
and ghost_natoms_index (keeping references to natoms_not_match, natoms_match,
self.ntypes, self.spin.use_spin, ghost_atype, ghost_natoms, ghost_natoms_index
to locate changes).
🧹 Nitpick comments (3)
source/lmp/tests/test_lammps_spin_nopbc.py (2)

74-191: Consider centralizing large expected_v fixtures.

These large inline arrays are duplicated across spin test files; extracting to a shared fixture or data file would reduce maintenance drift.


295-301: Avoid hard‑coding the atom count in velocity deviation.

Using the actual atom count makes this test resilient to future fixture changes.

♻️ Proposed fix
-    expected_md_v = (
-        np.std([np.sum(expected_v[:], axis=0), np.sum(expected_v2[:], axis=0)], axis=0)
-        / 4
-    )
+    atom_count = coord.shape[0]
+    expected_md_v = (
+        np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0)
+        / atom_count
+    )
source/lmp/tests/test_lammps_spin.py (1)

335-341: Avoid hard‑coding the atom count in velocity deviation.

Use the actual atom count so the test stays correct if fixtures change.

♻️ Proposed fix
-    expected_md_v = (
-        np.std([np.sum(expected_v[:], axis=0), np.sum(expected_v2[:], axis=0)], axis=0)
-        / 4
-    )
+    atom_count = coord.shape[0]
+    expected_md_v = (
+        np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0)
+        / atom_count
+    )

Comment on lines +1399 to +1429
def natoms_not_match(self, coord, natoms, atype):
    diff_coord_loc = self.natoms_match(coord, natoms)
    diff_coord_ghost = []
    aatype = atype[0, :]
    ghost_atype = aatype[natoms[0] :]
    _, _, ghost_natoms = tf.unique_with_counts(ghost_atype)
    ghost_natoms_index = tf.concat([[0], tf.cumsum(ghost_natoms)], axis=0)
    ghost_natoms_index += natoms[0]
    for i in range(self.ntypes):
        if i + self.ntypes_spin >= self.ntypes:
            diff_coord_ghost.append(
                tf.slice(
                    coord,
                    [0, ghost_natoms_index[i] * 3],
                    [-1, ghost_natoms[i] * 3],
                )
                - tf.slice(
                    coord,
                    [0, ghost_natoms_index[i - len(self.spin.use_spin)] * 3],
                    [-1, ghost_natoms[i - len(self.spin.use_spin)] * 3],
                )
            )
        else:
            diff_coord_ghost.append(
                tf.zeros(
                    [tf.shape(coord)[0], ghost_natoms[i] * 3], dtype=coord.dtype
                )
            )
    diff_coord_ghost = tf.concat(diff_coord_ghost, axis=1)
    diff_coord = tf.concat([diff_coord_loc, diff_coord_ghost], axis=1)
    return diff_coord
Copy link
Contributor


⚠️ Potential issue | 🟠 Major

Fix ghost type counting to avoid misaligned indices.

tf.unique_with_counts returns counts only for types that appear (and in appearance order). Indexing ghost_natoms[i] by type id will be wrong or out-of-range when some types have zero ghost atoms or the ghost list is not ordered by type. Use a fixed-length per-type count instead.

🔧 Suggested fix
-        _, _, ghost_natoms = tf.unique_with_counts(ghost_atype)
-        ghost_natoms_index = tf.concat([[0], tf.cumsum(ghost_natoms)], axis=0)
+        ghost_natoms = tf.math.bincount(
+            ghost_atype,
+            minlength=self.ntypes,
+            maxlength=self.ntypes,
+            dtype=natoms.dtype,
+        )
+        ghost_natoms_index = tf.concat([[0], tf.cumsum(ghost_natoms)], axis=0)
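The failure mode is easy to reproduce with a NumPy analogue (np.unique standing in for tf.unique_with_counts):

```python
import numpy as np

ghost_atype = np.array([2, 2, 0, 0])  # no ghost atoms of type 1
ntypes = 3

# Unique-style counting reports only the types that appear, so the
# counts array has length 2 and counts[type_id] is misaligned
# (indexing counts[2] would even be out of range).
_, counts = np.unique(ghost_atype, return_counts=True)

# bincount with minlength yields one slot per type id, zeros included,
# so counts[type_id] indexing is always valid.
per_type = np.bincount(ghost_atype, minlength=ntypes)
```

Here counts comes back as [2, 2] while per_type is [2, 0, 2], which is the fixed-length per-type count the suggested fix relies on.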
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def natoms_not_match(self, coord, natoms, atype):
    diff_coord_loc = self.natoms_match(coord, natoms)
    diff_coord_ghost = []
    aatype = atype[0, :]
    ghost_atype = aatype[natoms[0] :]
    _, _, ghost_natoms = tf.unique_with_counts(ghost_atype)
    ghost_natoms_index = tf.concat([[0], tf.cumsum(ghost_natoms)], axis=0)
    ghost_natoms_index += natoms[0]
    for i in range(self.ntypes):
        if i + self.ntypes_spin >= self.ntypes:
            diff_coord_ghost.append(
                tf.slice(
                    coord,
                    [0, ghost_natoms_index[i] * 3],
                    [-1, ghost_natoms[i] * 3],
                )
                - tf.slice(
                    coord,
                    [0, ghost_natoms_index[i - len(self.spin.use_spin)] * 3],
                    [-1, ghost_natoms[i - len(self.spin.use_spin)] * 3],
                )
            )
        else:
            diff_coord_ghost.append(
                tf.zeros(
                    [tf.shape(coord)[0], ghost_natoms[i] * 3], dtype=coord.dtype
                )
            )
    diff_coord_ghost = tf.concat(diff_coord_ghost, axis=1)
    diff_coord = tf.concat([diff_coord_loc, diff_coord_ghost], axis=1)
    return diff_coord

def natoms_not_match(self, coord, natoms, atype):
    diff_coord_loc = self.natoms_match(coord, natoms)
    diff_coord_ghost = []
    aatype = atype[0, :]
    ghost_atype = aatype[natoms[0] :]
    ghost_natoms = tf.math.bincount(
        ghost_atype,
        minlength=self.ntypes,
        maxlength=self.ntypes,
        dtype=natoms.dtype,
    )
    ghost_natoms_index = tf.concat([[0], tf.cumsum(ghost_natoms)], axis=0)
    ghost_natoms_index += natoms[0]
    for i in range(self.ntypes):
        if i + self.ntypes_spin >= self.ntypes:
            diff_coord_ghost.append(
                tf.slice(
                    coord,
                    [0, ghost_natoms_index[i] * 3],
                    [-1, ghost_natoms[i] * 3],
                )
                - tf.slice(
                    coord,
                    [0, ghost_natoms_index[i - len(self.spin.use_spin)] * 3],
                    [-1, ghost_natoms[i - len(self.spin.use_spin)] * 3],
                )
            )
        else:
            diff_coord_ghost.append(
                tf.zeros(
                    [tf.shape(coord)[0], ghost_natoms[i] * 3], dtype=coord.dtype
                )
            )
    diff_coord_ghost = tf.concat(diff_coord_ghost, axis=1)
    diff_coord = tf.concat([diff_coord_loc, diff_coord_ghost], axis=1)
    return diff_coord
🤖 Prompt for AI Agents
In `@deepmd/tf/descriptor/se_a.py` around lines 1399 - 1429, The current
natoms_not_match implementation misuses tf.unique_with_counts on ghost_atype
which returns counts only for present types and leads to misaligned or
out-of-range indexing when some types have zero ghosts; replace the
unique_with_counts approach with a fixed-length per-type count (length
self.ntypes) — e.g. compute ghost_natoms as tf.math.bincount(ghost_atype,
minlength=self.ntypes) or by summing boolean equality masks per type — then
build ghost_natoms_index = tf.concat([[0], tf.cumsum(ghost_natoms)], axis=0) and
proceed to slice coord using ghost_natoms and ghost_natoms_index (keeping
references to natoms_not_match, natoms_match, self.ntypes, self.spin.use_spin,
ghost_atype, ghost_natoms, ghost_natoms_index to locate changes).



Development

Successfully merging this pull request may close these issues.

[Feature Request] pt: support virial for the spin model

3 participants