Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ae35ea8
refactor(dpmodel): plumb comm_dict and extract _exchange_ghosts hook
Apr 25, 2026
bfe650f
feat(op): expose deepmd::border_op_backward as a standalone op
Apr 25, 2026
3af514a
feat(pt_expt): add deepmd_export::border_op opaque wrapper + block ov…
Apr 25, 2026
2936bd4
fix(pt_expt): plumb comm_dict through SpinModel + guards
Apr 25, 2026
b22feb7
feat(pt_expt): two-mode AOTInductor export with comm_dict
Apr 25, 2026
4b707a7
test(pt_expt): add comm_dict eager parity + export round-trip suite
Apr 26, 2026
0bd131a
fix(cc): link TORCH_LIBRARIES in api_cc tests so pt_expt tests run
Apr 26, 2026
cdef9d5
feat(cc): wire DeepPotPTExpt and DeepSpinPTExpt for multi-rank GNN
Apr 26, 2026
1ad6103
feat(gnn-mpi): wire up multi-rank LAMMPS path end-to-end
Apr 26, 2026
8b2501d
test(gnn-mpi): expand multi-rank coverage; address Phase 5 follow-up …
Apr 26, 2026
c43bd8b
test(gnn-mpi): tighten multi-rank LAMMPS test assertions
Apr 26, 2026
1706435
fix(cc): handle empty subdomain in copy_from_nlist; expand MPI tests …
Apr 26, 2026
b54b8f3
Merge branch 'master' from upstream into feat-pt-expt-gnn-mpi
Apr 26, 2026
a81fc10
test: cover DPA2 multi-rank dispatch + fix opaque-op import order
Apr 27, 2026
ece5c3d
test: extend MPI coverage with N>2 decompositions and schema-drift un…
Apr 30, 2026
0ef1bfc
test: cover NULL-type atoms (atype<0) under mpirun
Apr 30, 2026
0c95b3a
test: cover three NULL-type edge cases (isolated / all-null-rank / nl…
Apr 30, 2026
ad7761c
test: NULL atoms cross rank boundary; prune redundant decomposition
Apr 30, 2026
b25e00c
test: mixed-direction NULL velocities + real-atom thermal motion
Apr 30, 2026
124dc5e
test: empty-subdomain test exercises cached mapping_tensor path
Apr 30, 2026
5fef5c6
test(spin-mpi): cover spin GNN multi-rank end-to-end
Apr 30, 2026
803b2a4
test(spin-mpi): cover empty-subdomain and NULL-type for spin DPA3
May 1, 2026
47f0c29
test(spin-mpi): drop committed deeppot_dpa3_spin_mpi.yaml
May 1, 2026
c6a38e6
Merge remote-tracking branch 'upstream/master' into feat-pt-expt-gnn-mpi
May 1, 2026
3c9ee65
fix(jax): accept comm_dict kwarg in forward_common_atomic
May 1, 2026
87c9f3f
fix(pt_expt): auto-load underlying ops in comm.py
May 2, 2026
4865c4e
chore: drop redundant ``import deepmd.pt`` preloads
May 2, 2026
bf1685f
fix: address coderabbitai review on PR 5430
May 2, 2026
a429fc9
refactor: replace _has_message_passing hack with descriptor API
May 2, 2026
08805b6
fix(test): build comm_dict control tensors on CPU for repflow_parallel
May 3, 2026
afa99c7
fix(test): build comm_dict control tensors on CPU for repformer_parallel
May 3, 2026
8463af8
Merge remote-tracking branch 'upstream/master' into feat-pt-expt-gnn-mpi
May 3, 2026
e19108d
fix(op): dispatch border_op self-send on tensor device, not MPI state
May 4, 2026
4f8240e
fix(op): drain pending MPI eager-send ACKs in border_op via Barrier
May 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def forward_common_atomic(
mapping: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
comm_dict: dict | None = None,
) -> dict[str, Array]:
"""Common interface for atomic inference.

Expand All @@ -252,6 +253,9 @@ def forward_common_atomic(
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam
comm_dict
MPI communication metadata for parallel inference. ``None`` for
non-parallel inference (default).

Returns
-------
Expand Down Expand Up @@ -279,6 +283,7 @@ def forward_common_atomic(
mapping=mapping,
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict,
)
ret_dict = self.apply_out_stat(ret_dict, atype)

Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def forward_atomic(
mapping: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
comm_dict: dict | None = None,
) -> dict[str, Array]:
"""Models' atomic predictions.

Expand All @@ -174,6 +175,9 @@ def forward_atomic(
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda
comm_dict
MPI communication metadata for parallel inference. ``None`` for
non-parallel inference (default). Forwarded to the descriptor.

Returns
-------
Expand Down Expand Up @@ -215,6 +219,7 @@ def forward_atomic(
nlist,
mapping=mapping,
fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
comm_dict=comm_dict,
)
ret = self.fitting_net(
descriptor,
Expand Down
6 changes: 6 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def forward_atomic(
mapping: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
comm_dict: dict | None = None,
) -> dict[str, Array]:
"""Return atomic prediction.

Expand All @@ -241,6 +242,10 @@ def forward_atomic(
frame parameter. (nframes, ndf)
aparam
atomic parameter. (nframes, nloc, nda)
comm_dict
MPI communication metadata. Forwarded to each sub-model so GNN
sub-descriptors can perform parallel ghost exchange. ``None`` for
non-parallel inference (default).

Returns
-------
Expand Down Expand Up @@ -280,6 +285,7 @@ def forward_atomic(
mapping,
fparam,
aparam,
comm_dict,
)["energy"]
)
weights = self._compute_weight(extended_coord, extended_atype, nlists_)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ def forward_atomic(
mapping: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
comm_dict: dict | None = None,
) -> dict[str, Array]:
del comm_dict # pairtab is local; no MPI ghost exchange needed.
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
nframes, nloc, nnei = nlist.shape
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))
Expand Down
9 changes: 9 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,14 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def has_message_passing_across_ranks(self) -> bool:
    """Whether per-layer node embeddings require MPI ghost exchange.

    DPA1 (se_atten) runs a single descriptor layer and never sends
    atomic features across rank boundaries, matching the plain
    se_e2_a behavior.
    """
    return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return self.se_atten.need_sorted_nlist_for_lower()
Expand Down Expand Up @@ -500,6 +508,7 @@ def call(
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
) -> Array:
"""Compute the descriptor.

Expand Down
32 changes: 29 additions & 3 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,16 @@ def has_message_passing(self) -> bool:
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def has_message_passing_across_ranks(self) -> bool:
    """Whether per-layer node embeddings require MPI ghost exchange.

    The repformer block always carries ``g1`` in ``[nb, nall, n_dim]``
    layout (there is no block-level ``use_loc_mapping`` opt-out), so a
    multi-rank deployment must exchange per-atom features between
    layers; the answer is deferred to the repformer block itself.
    """
    return self.repformers.has_message_passing_across_ranks()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return True
Expand Down Expand Up @@ -831,6 +841,7 @@ def call(
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
) -> tuple[Array, Array, Array, Array, Array]:
"""Compute the descriptor.

Expand All @@ -844,6 +855,11 @@ def call(
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, maps extended region index to local region.
comm_dict
MPI communication metadata for parallel inference. Forwarded to
the repformer block (the message-passing part). The repinit
sub-block does no message passing and does not receive it.
``None`` for non-parallel inference (default).

Returns
-------
Expand Down Expand Up @@ -912,9 +928,18 @@ def call(
assert self.tebd_transform is not None
g1 = g1 + self.tebd_transform(g1_inp)
# mapping g1
assert mapping is not None
mapping_ext = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1]))
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
if comm_dict is None:
# non-parallel: gather g1 -> g1_ext via mapping, hand the
# nall-sized embedding to the repformer block.
assert mapping is not None
mapping_ext = xp.tile(
xp.expand_dims(mapping, axis=-1), (1, 1, g1.shape[-1])
)
g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1)
else:
# parallel mode: hand the local-only g1 to the repformer block;
# its per-layer override fills ghosts via the MPI exchange.
g1_ext = g1
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
Expand All @@ -926,6 +951,7 @@ def call(
atype_ext,
g1_ext,
mapping,
comm_dict=comm_dict,
)
if self.concat_output_tebd:
g1 = xp.concat([g1, g1_inp], axis=-1)
Expand Down
16 changes: 16 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,17 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.repflows.has_message_passing()

def has_message_passing_across_ranks(self) -> bool:
    """Whether per-layer node embeddings require MPI ghost exchange.

    The answer comes from the repflows block: ``False`` when
    ``use_loc_mapping=True`` (per-layer messages never leave a rank's
    local atoms), ``True`` when ``use_loc_mapping=False`` (the ghost
    slots of the ``[nb, nall, n_dim]`` layout must be filled by
    cross-rank exchange ahead of every layer).
    """
    return self.repflows.has_message_passing_across_ranks()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return True
Expand Down Expand Up @@ -616,6 +627,7 @@ def call(
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
) -> tuple[Array, Array, Array, Array, Array]:
"""Compute the descriptor.

Expand All @@ -629,6 +641,9 @@ def call(
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, maps extended region index to local region.
comm_dict
MPI communication metadata for parallel inference. Forwarded to
the repflows block. ``None`` for non-parallel inference (default).

Returns
-------
Expand Down Expand Up @@ -695,6 +710,7 @@ def call(
atype_ext,
node_ebd_ext,
mapping,
comm_dict=comm_dict,
)
if self.concat_output_tebd:
node_ebd = xp.concat([node_ebd, node_ebd_inp], axis=-1)
Expand Down
15 changes: 14 additions & 1 deletion deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,16 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def has_message_passing_across_ranks(self) -> bool:
    """Whether per-layer node embeddings require MPI ghost exchange.

    A hybrid needs cross-rank message passing as soon as any one of
    its children does (e.g. when it wraps a DPA3 built with
    ``use_loc_mapping=False``).
    """
    for sub_descrpt in self.descrpt_list:
        if sub_descrpt.has_message_passing_across_ranks():
            return True
    return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return True
Expand Down Expand Up @@ -276,6 +286,7 @@ def call(
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
) -> tuple[
Array,
Array | None,
Expand Down Expand Up @@ -332,7 +343,9 @@ def call(
# mixed_types is True, but descrpt.mixed_types is False
assert nl_distinguish_types is not None
nl = nl_distinguish_types[:, :, nci]
odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping)
odescriptor, gr, _g2, _h2, _sw = descrpt(
coord_ext, atype_ext, nl, mapping, comm_dict=comm_dict
)
out_descriptor.append(odescriptor)
if gr is not None:
out_gr.append(gr)
Expand Down
18 changes: 18 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,24 @@ def mixed_types(self) -> bool:
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

def has_message_passing_across_ranks(self) -> bool:
    """Whether the descriptor's message passing crosses rank boundaries.

    In other words: does the forward pass need cross-rank exchange of
    intermediate atomic features (per-layer node embeddings)?  This is
    distinct from the generic ghost-coordinate/force exchange every
    LAMMPS pair_style performs.  The pt_expt backend consults this to
    decide whether a second "with-comm" AOTI artifact must be compiled
    for multi-rank deployment.

    The default is a concrete ``False`` (non-GNN behavior), so pt and
    pd descriptors that subclass ``BaseDescriptor`` directly need not
    implement it until they gain a multi-rank GNN path.  GNN
    descriptors that do require MPI ghost-feature exchange (DPA2, DPA3
    with ``use_loc_mapping=False``, hybrids wrapping such children)
    override this to return ``True``.
    """
    return False

@abstractmethod
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
Expand Down
60 changes: 53 additions & 7 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,32 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

def _exchange_ghosts(
self,
node_ebd: Array,
mapping_tiled: Array | None,
comm_dict: dict | None,
nall: int,
nloc: int,
) -> Array:
"""Build node_ebd_ext (the ghost-aware embedding) for the per-layer loop.

Default: array-api gather via the pre-tiled `mapping_tiled`, or pass the
local-only `node_ebd` through when ``self.use_loc_mapping`` is set.
``comm_dict``, ``nall``, ``nloc`` are unused in this default impl; they
exist so the pt_expt subclass can perform the per-layer MPI ghost
exchange (``deepmd_export::border_op``) when ``comm_dict is not None``.
"""
del comm_dict, nall, nloc
if self.use_loc_mapping:
return node_ebd
if mapping_tiled is None:
raise ValueError(
"`mapping` is required when use_loc_mapping=False unless "
"`_exchange_ghosts` is overridden for parallel comm handling."
)
return xp_take_along_axis(node_ebd, mapping_tiled, axis=1)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def call(
self,
nlist: Array,
Expand All @@ -514,6 +540,7 @@ def call(
atype_embd_ext: Array | None = None,
mapping: Array | None = None,
type_embedding: Array | None = None,
comm_dict: dict | None = None,
) -> tuple[Array, Array, Array, Array, Array]:
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
nframes, nloc, nnei = nlist.shape
Expand Down Expand Up @@ -641,15 +668,24 @@ def call(
# nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
angle_ebd = self.angle_embd(angle_input)

# nb x nall x n_dim
mapping = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.n_dim))
# nb x nall x n_dim (pre-tiled mapping reused across layers when not
# using comm_dict). Skip the tile when mapping is None — pt_expt's
# parallel-mode override consults comm_dict instead.
mapping_tiled = (
xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.n_dim))
if mapping is not None
else None
)
for idx, ll in enumerate(self.layers):
# node_ebd: nb x nloc x n_dim
# node_ebd_ext: nb x nall x n_dim
node_ebd_ext = (
node_ebd
if self.use_loc_mapping
else xp_take_along_axis(node_ebd, mapping, axis=1)
# node_ebd_ext: nb x nall x n_dim (or nb x nloc x n_dim when
# use_loc_mapping=True)
node_ebd_ext = self._exchange_ghosts(
node_ebd,
mapping_tiled,
comm_dict,
nall,
nloc,
)
node_ebd, edge_ebd, angle_ebd = ll.call(
node_ebd_ext,
Expand Down Expand Up @@ -696,6 +732,16 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True

def has_message_passing_across_ranks(self) -> bool:
    """Whether per-layer node embeddings require MPI ghost exchange.

    With ``use_loc_mapping=True`` the per-layer ``node_ebd`` stays in
    ``[nb, nloc, n_dim]`` layout and messages never leave the rank's
    local atoms.  With ``use_loc_mapping=False`` it is carried as
    ``[nb, nall, n_dim]`` and the ghost slots must be populated by
    cross-rank exchange before every layer.
    """
    return False if self.use_loc_mapping else True

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return True
Expand Down
Loading
Loading