
mesh: enable ShardTensor support for mesh conversion/geometry paths #1608

Open
loliverhennigh wants to merge 20 commits into NVIDIA:main from loliverhennigh:mesh-shardtensor-mesh-support

Conversation

@loliverhennigh
Collaborator

PhysicsNeMo Pull Request

Description

@copy-pr-bot

copy-pr-bot Bot commented Apr 30, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Comment thread physicsnemo/domain_parallel/__init__.py Outdated

from ._shard_tensor_spec import ShardTensorSpec
from .shard_tensor import ShardTensor, scatter_tensor
from .shard_tensor import ShardTensor, replicated_zeros_like, scatter_tensor
Collaborator Author

Hey @coreyjadams, if I am adding the zeros_like correctly, I might add a few similar ops just for consistency even though they are not needed.

@greptile-apps
Contributor

greptile-apps Bot commented Apr 30, 2026

Greptile Summary

This PR adds ShardTensor support to the mesh conversion and geometry computation paths by introducing replicated_zeros_like/new_replicated_zeros helpers and a _cross_wrapper handler for torch.linalg.cross/torch.cross, replacing advanced-index patterns in mesh.py with expand/reshape, and wrapping torch.zeros calls in _scatter_ops.py with ShardTensor-aware allocators. Test coverage is added for both dense and sharded modes via an opt-in environment variable.

Important Files Changed

Filename Overview
physicsnemo/domain_parallel/shard_tensor.py Adds new_replicated_zeros, replicated_zeros_like, and _cross_wrapper; the wrapper has a default-dim mismatch for torch.cross calls and a hardcoded error message.
physicsnemo/domain_parallel/__init__.py Exports replicated_zeros_like in both the available and stub branches; straightforward and correct.
physicsnemo/mesh/mesh.py Replaces advanced-indexing (cell_values[cell_indices]) with unsqueeze/expand/reshape to avoid ShardTensor incompatibility; semantics are equivalent.
physicsnemo/mesh/utilities/_scatter_ops.py Introduces _is_sharded_tensor (duck-types on private _spec) and _replicated_zeros_like; the private-attribute check is fragile across PyTorch versions.
test/mesh/mesh/test_data_conversion.py Adds ShardTensor-aware fixtures and mode-switching; _single_rank_dist_group uses NamedTemporaryFile which pre-creates the rendezvous file before init_process_group.
test/mesh/mesh/test_geometry_properties.py Adds ShardTensor geometry test modes; shares the same NamedTemporaryFile concern as test_data_conversion.py and duplicates several helpers that could live in a shared conftest.

Comments Outside Diff (2)

  1. physicsnemo/mesh/utilities/_scatter_ops.py, line 173-174 (link)

    P2 Duck-type detection relies on a private attribute

    _is_sharded_tensor checks for hasattr(tensor, "_spec") — a private attribute. If ShardTensor/DTensor renames or reorganises that attribute in a future PyTorch release the check will silently fall through to the plain-tensor branch, causing incorrect zero-tensor allocation in distributed runs rather than a clear error. Prefer an isinstance check against the optionally-imported ShardTensor type:

    try:
        from physicsnemo.domain_parallel import ShardTensor as _ShardTensor
    except ImportError:
        _ShardTensor = None
    
    def _is_sharded_tensor(tensor: torch.Tensor) -> bool:
        return _ShardTensor is not None and isinstance(tensor, _ShardTensor)
  2. test/mesh/mesh/test_data_conversion.py, line 411-430 (link)

    P2 NamedTemporaryFile creates the file before init_process_group

    tempfile.NamedTemporaryFile(delete=True) creates and opens the file immediately, so it already exists when dist.init_process_group(init_method=f"file://{f.name}", …) is called. PyTorch's FileStore (file-based rendezvous) expects to create/own the file itself; a pre-existing file may cause silent corruption or an AssertionError on certain platforms.

    A more robust pattern is to use a path that does not yet exist:

    with tempfile.TemporaryDirectory(prefix="mesh_shard_pg_") as tmpdir:
        dist.init_process_group(
            backend="gloo",
            init_method=f"file://{tmpdir}/rendezvous",
            rank=0,
            world_size=1,
        )

    The same pattern is duplicated in test/mesh/mesh/test_geometry_properties.py and should be updated there as well.

Reviews (1): Last reviewed commit: "mesh: enable ShardTensor support for mes..."

Comment on lines +1027 to +1036
dim = kwargs.get("dim", -1)
if len(args) > 2:
    dim = args[2]

if not isinstance(input_tensor, ShardTensor) or not isinstance(
    other_tensor, ShardTensor
):
    raise RuntimeError(
        "torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
    )
Contributor

P2 torch.cross dim defaulting diverges from original behavior

When _cross_wrapper is invoked via the torch.cross handler (registered on line 1062) and the caller omits dim, the wrapper defaults to dim=-1. However, torch.cross (pre-deprecation) auto-detects the first dimension of size 3, which may not be the last dimension. Any call like torch.cross(a_shard, b_shard) where the cross-product axis isn't the last one will silently produce a wrong result instead of matching the original op's semantics.

For torch.linalg.cross, dim is keyword-only and defaults to -1, so the current default is correct there. For the torch.cross handler, consider either (a) raising explicitly when dim is absent to force callers to use the unambiguous form, or (b) documenting that only the torch.linalg.cross semantic (dim=-1) is supported.
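
For reference, a minimal sketch of the legacy auto-detection that torch.cross performs when dim is omitted, assuming the wrapper's existing input_tensor argument; the helper name _infer_legacy_cross_dim is hypothetical and not part of this PR:

    def _infer_legacy_cross_dim(input_tensor) -> int:
        # Legacy torch.cross behavior: use the first dimension of size 3,
        # which need not be the last dimension.
        for d, size in enumerate(input_tensor.shape):
            if size == 3:
                return d
        raise RuntimeError("torch.cross requires a dimension of size 3")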

Comment on lines +1031 to +1036
if not isinstance(input_tensor, ShardTensor) or not isinstance(
    other_tensor, ShardTensor
):
    raise RuntimeError(
        "torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
    )
Contributor

P2 Hardcoded error message for both torch.cross and torch.linalg.cross

The error string says "torch.linalg.cross with ShardTensor inputs…" but this wrapper is registered for both torch.linalg.cross (line 1060) and torch.cross (line 1062). When the handler fires via torch.cross, the message will mislead users.

Suggested change
if not isinstance(input_tensor, ShardTensor) or not isinstance(
    other_tensor, ShardTensor
):
    raise RuntimeError(
        "torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
    )
    raise RuntimeError(
        f"{func.__module__}.{func.__name__} with ShardTensor inputs requires both arguments to be ShardTensor."
    )

@loliverhennigh
Collaborator Author

Blaa, some stuff got messed up and it's not quite ready for review. I'll fix tonight...



def _is_sharded_tensor(tensor: torch.Tensor) -> bool:
    return hasattr(tensor, "_spec") and hasattr(type(tensor), "from_local")
Collaborator

@coreyjadams is this best practice for type-narrowing ShardTensor?

Feels like isinstance(tensor, ShardTensor) would be better, unless there's something I'm missing?

(And if this does get replaced with isinstance, then this can be inlined rather than keeping a separate function.)

)


def _mesh_to_mode(mesh: Mesh, *, mesh_tensor_mode: str, mesh_shard_device_mesh) -> Mesh:
Collaborator

@peterdsharpe Apr 30, 2026

Docstrings needed (here and in above functions)



@pytest.fixture(params=_MESH_TENSOR_MODES)
def mesh_tensor_mode(request):
Collaborator

Type hints?



@pytest.fixture
def mesh_shard_device_mesh(mesh_tensor_mode, _single_rank_dist_group):
Collaborator

@peterdsharpe Apr 30, 2026

Type hints + docstrings needed

Comment thread test/mesh/mesh/test_data_conversion.py Outdated
    n_cells: int,
    point_placement: Replicate | Shard,
    cell_placement: Replicate | Shard,
):
Collaborator

docstring needed

r"""Create zeros matching a tensor's device/mesh semantics.

For ``ShardTensor`` inputs this returns a replicated ``ShardTensor`` on the
same mesh. For regular tensors this falls back to ``torch.zeros`` on the
Collaborator

Given that this is part of the public API, please add complete NumPy-style docstring

Comment thread physicsnemo/mesh/mesh.py
src_data=cell_values[cell_indices],
src_data=cell_values.unsqueeze(1)
.expand(-1, n_vertices_per_cell, *cell_values.shape[1:])
.reshape(-1, *cell_values.shape[1:]),
Collaborator

@peterdsharpe Apr 30, 2026

Why is Mesh.cell_data_to_point_data() updated, but Mesh.point_data_to_cell_data() is not? The asymmetry seems suspect - seems like they should either both need updates for ShardTensor or neither should, but I might be missing something?

Or is this just an unrelated change (and if so, what motivated this)?

Collaborator Author

It's because we needed a ShardTensor to index, hence the change. In point_data_to_cell_data we use cells to index, but that is already a ShardTensor.
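
For reference, a quick dense-tensor check of the equivalence; the repeat_interleave construction of cell_indices below is an assumption based on the torch.arange remark, not the PR's exact code:

    import torch

    n_cells, n_vertices_per_cell = 4, 3
    cell_values = torch.randn(n_cells, 2)
    # Assumed stand-in for the old arange-based index construction.
    cell_indices = torch.arange(n_cells).repeat_interleave(n_vertices_per_cell)

    via_index = cell_values[cell_indices]
    via_expand = (
        cell_values.unsqueeze(1)
        .expand(-1, n_vertices_per_cell, *cell_values.shape[1:])
        .reshape(-1, *cell_values.shape[1:])
    )
    assert torch.equal(via_index, via_expand)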

Comment thread test/mesh/mesh/test_data_conversion.py Outdated


def _convert_data_for_mode(
    data: dict[str, torch.Tensor] | None,
Collaborator

@peterdsharpe Apr 30, 2026

This function incorrectly:
a) accepts dict[str, torch.Tensor] instead of TensorDict
b) iterates only over top-level keys.

Combined with downstream uses that pass in TensorDict, this causes issues in hierarchical dictionaries. (Specifically, this will cause nested values to remain as torch.Tensor.)

Consider using TensorDict.apply() instead.
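
A hedged sketch of that suggestion, assuming TensorDict here is the tensordict-package class (whose apply() visits nested leaves) and using a hypothetical _to_shard_tensor conversion helper:

    from tensordict import TensorDict

    def _convert_data_for_mode(data: TensorDict | None, mode: str) -> TensorDict | None:
        if data is None or mode == "dense":  # mode names are illustrative
            return data
        # apply() recurses into nested TensorDicts, so hierarchical values
        # are converted too, not just top-level keys.
        return data.apply(_to_shard_tensor)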

Collaborator

@coreyjadams left a comment

Hi @loliverhennigh - I have left a number of comments on the PR. I think there are some updates to make, though I'm happy to see the changes overall are not too extensive. I think there are some design choices that cut against the philosophy of ShardTensor we should tweak, but it doesn't look like it's going to be too much.

Happy to discuss this in as much detail as you'd like offline!


return _ToTorchTensor.apply(self, grad_placements)

def new_replicated_zeros(
Collaborator

Adding this function as an API call on ShardTensor, vs. supporting the underlying dispatch call, has to have a really good motivation.

What is the value of this vs. supporting the backend of torch.zeros_like(a) when a is a ShardTensor? And, in fact, I think that should already work?

Comment on lines +936 to +940
return ShardTensor.from_local(
    local,
    self._spec.mesh,
    [Replicate() for _ in range(self._spec.mesh.ndim)],
    sharding_shapes="infer",
Collaborator

This is unusual here: typically I'd object loudly to passing "infer" as sharding shapes, since that can trigger a blocking allreduce and is a major perf headache. But you're replicating on all ranks. I don't think I understand this function's role, really.

Comment on lines +1000 to +1005
def replicated_zeros_like(
    tensor: torch.Tensor,
    shape: Sequence[int] | torch.Size,
    *,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
Collaborator

I still think it makes more sense to support torch.zeros_like(a) for a as a replicated tensor.

tl;dr the ShardTensor design philosophy is to make zero code changes on the user side, whenever possible, so we support torch calls on shard tensors first and foremost rather than introduce new API. Is it possible to implement your work without this?
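
A minimal sketch of the user-side contrast being described, assuming st is an existing replicated ShardTensor; whether zeros_like already dispatches correctly for that case is exactly the open question:

    import torch

    # Preferred per the ShardTensor philosophy: the standard torch call
    # should just work on a ShardTensor and preserve its placements.
    zeros = torch.zeros_like(st)

    # What this PR adds instead (new helper with an explicit shape):
    # zeros = replicated_zeros_like(st, shape)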

Comment on lines +1018 to +1023
def _cross_wrapper(func, types, args, kwargs):
    if kwargs is None:
        kwargs = {}

    if kwargs.get("out", None) is not None:
        raise RuntimeError("torch.linalg.cross(out=...) is not supported for ShardTensor.")
Collaborator

If this is a function overload, it's in the wrong place. shard_tensor.py is for the core tensor object only. There is a sub folder for ops.

Comment on lines +1034 to +1036
raise RuntimeError(
    "torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
)
Collaborator

It looks like what we want to be doing is implementing a wrapper for torch.linalg.cross on shard tensor objects. There is not a need to check that all objects are ShardTensor at this time.

Comment on lines +96 to +117
aggregated_data = torch.zeros((n_dst, *data_shape), dtype=dtype, device=device)
aggregated_data = _replicated_zeros_like(
    src_data, (n_dst, *data_shape), dtype
)
Collaborator

The distributed-friendly update to make here is, if mesh can support it, to port torch.zeros away from this shape and onto torch.zeros_like. So whatever is building data_shape as input, we build zeros_like(that_object) and then mesh can work on single device and sharded inputs too.
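
One hedged reading of that suggestion, assuming data_shape is src_data.shape[1:] (not confirmed by the diff): derive the zeros from the source tensor itself so dtype/device, and potentially mesh placement, come along automatically:

    # Hypothetical alternative to the _replicated_zeros_like helper; whether
    # Tensor.new_zeros preserves replicated ShardTensor semantics is the
    # open design question in this thread.
    aggregated_data = src_data.new_zeros((n_dst, *src_data.shape[1:]), dtype=dtype)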

Comment on lines +105 to +129
weights = torch.ones(len(src_to_dst_mapping), dtype=dtype, device=device)
if _is_sharded_tensor(src_data):
    weights = torch.ones_like(src_to_dst_mapping, dtype=dtype)
else:
    weights = torch.ones(len(src_to_dst_mapping), dtype=dtype, device=device)
Collaborator

This is much closer to the "right" way for domain parallelism, but in fact we can probably consolidate to just "ones_like" for all paths.

dtype=dtype,
device=device,
)
aggregated_data = _replicated_zeros_like(src_data, (n_dst, *data_shape), dtype)
Collaborator

Same here re:zeros_like

Comment on lines +137 to +157
weight_sums = torch.zeros(n_dst, dtype=dtype, device=device)
weight_sums = _replicated_zeros_like(src_data, (n_dst,), dtype)
Collaborator

Same here re:zeros_like

Comment thread physicsnemo/mesh/mesh.py

converted = self.cell_data.apply(
    lambda cell_values: scatter_aggregate(
        src_data=cell_values[cell_indices],
Collaborator

Indexing a ShardTensor with a ShardTensor index should work fine?

Collaborator Author

cell_indices is not a ShardTensor the way it was before; it was coming from just the torch.arange function.

@loliverhennigh
Collaborator Author

\blossom-ci

1 similar comment
@loliverhennigh
Collaborator Author

\blossom-ci

@loliverhennigh
Collaborator Author

Hi @coreyjadams - quick follow-up on this PR. The latest GitHub Actions checks are green, and I re-triggered Blossom with \blossom-ci, but I still do not see a Blossom status/check showing up on the PR. Could you take a look when you have a chance, or let me know if there is another validation step I should trigger?

Comment thread physicsnemo/mesh/utilities/_scatter_ops.py
Comment thread test/mesh/mesh/test_geometry_properties.py
Comment thread test/mesh/mesh/test_geometry_properties.py Outdated
@loliverhennigh
Collaborator Author

\blossom-ci

@loliverhennigh
Collaborator Author

/blossom-ci

@coreyjadams added the ! - Release (PRs or Issues relating to a release) label May 11, 2026
@coreyjadams removed the ! - Release (PRs or Issues relating to a release) label May 11, 2026
@coreyjadams
Collaborator

I took a look at trying to update the tests to align them with the domain parallelism tests; the test coverage here was only ~60 tests, all on CPU, for domain parallelism. It isn't aligned with the torchrun syntax that domain parallel tests expect (look at /test/plugins/distributed_fixtures.py for more info).

The cross operation looks like it's pretty much implemented OK but the tests aren't quite aligned with the way the op testing is done in the domain parallelism suite, so we're falling short there too.

I just don't think there is sufficient test coverage to say if we've achieved domain parallelism on PhysicsNeMo mesh. I don't think this is ready to go for the RC. There are ~2000 tests for mesh, and while we don't need all of them, being able to test against most of the core functionality + things we would want in datapipes, models, preprocessing etc for domain parallelism is, I think, a prerequisite for merge here. 60 tests (3% coverage max, if we're generous) all on CPU I think is probably short?

@peterdsharpe Any suggestions of what core mesh functionality should be prioritized for testing in a distributed mode? I'm thinking we need generic mesh operations (slicing, scatter_ops, vertex_to_cell and the opposite) as well as manifold projection operations probably. What else?

I think we'd also want to be able to have good distributed-mesh-init testing, like how to initialize a mesh with sharded tensors and make sure it works, gradients flow, we can resize / reshape / redistribute it, etc. Core tensor ops are covered with shard tensor already, so just mesh-specific things are fine.

@peterdsharpe
Collaborator

Yeah, I think these are all great calls @coreyjadams. Re Mesh testing, the high-risk areas where it might be good to add test coverage for a sharded Mesh are as follows:

  • test/mesh/mesh and test/mesh/geometry - these are kind of the "most core" Mesh data structure ops (cheap, unlikely to be broken by sharding, but extremely widespread usage)
  • test/mesh/test_transformations.py - geometry transformations. This is low-risk (in the sense that I'm very confident this will work fine with a sharded Mesh), but the stakes are high: these are super critical in datapipes.
  • test/mesh/subdivision - this involves allocating a new-size Mesh, so there are potentially some things we'd want to sanity-check here - do the subdivisions happen evenly; do children live on the same device as parents?
  • test/mesh/calculus - involves message-passing across shard boundaries
  • test/mesh/neighbors - this is a very important one that is probably the most direct stress-test of contiguous topology across shards
  • test/mesh/smoothing - like calculus, this is also a nonlocal operation
  • test/mesh/repair - things like "are duplicate vertices correctly detected if they're duplicated across shards"
  • test/mesh/spatial - something like a BVH tree construction could get dicey with sharded ops?
  • test/mesh/boundaries - watertightness checks might be interesting on a Sharded mesh?

If it's possible to add coverage for these tests in a relatively low-code-duplication way, that would be great - we definitely don't want a full copy of this test suite that will inevitably fall out of sync. Maybe a pytest.mark.parametrize decorator could work here?
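
A low-duplication sketch along those lines; the fixture, mode strings, and helpers (_sharding_available, _shard_mesh, base_mesh) are all hypothetical, not existing test infrastructure:

    import pytest

    @pytest.fixture(params=["dense", "sharded"])
    def mesh_under_test(request, base_mesh):
        """Yield the same Mesh in dense or ShardTensor-backed form."""
        if request.param == "dense":
            return base_mesh
        if not _sharding_available():
            pytest.skip("distributed backend not available")
        return _shard_mesh(base_mesh)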

@loliverhennigh
Collaborator Author

/blossom-ci

Comment thread test/mesh/test_shard_tensor_mesh_ops.py Outdated
)


def _surface_mesh() -> Mesh:
Collaborator

Would recommend using the meshes from physicsnemo.mesh.primitives here, which are intended for re-use and in general can have much more interesting behavior (to more easily catch edge cases).

@loliverhennigh
Collaborator Author

/blossom-ci
