mesh: enable ShardTensor support for mesh conversion/geometry paths #1608
loliverhennigh wants to merge 20 commits into
Conversation
  from ._shard_tensor_spec import ShardTensorSpec
- from .shard_tensor import ShardTensor, scatter_tensor
+ from .shard_tensor import ShardTensor, replicated_zeros_like, scatter_tensor
Hey @coreyjadams, if I'm adding the zeros_like correctly here, I might add a few similar ops for consistency, even though they are not strictly needed.
dim = kwargs.get("dim", -1)
if len(args) > 2:
    dim = args[2]

if not isinstance(input_tensor, ShardTensor) or not isinstance(
    other_tensor, ShardTensor
):
    raise RuntimeError(
        "torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
    )
torch.cross dim defaulting diverges from original behavior
When _cross_wrapper is invoked via the torch.cross handler (registered on line 1062) and the caller omits dim, the wrapper defaults to dim=-1. However, torch.cross (pre-deprecation) auto-detects the first dimension of size 3, which may not be the last dimension. Any call like torch.cross(a_shard, b_shard) where the cross-product axis isn't the last one will silently produce a wrong result instead of matching the original op's semantics.
For torch.linalg.cross, dim is keyword-only and defaults to -1, so the current default is correct there. For the torch.cross handler, consider either (a) raising explicitly when dim is absent to force callers to use the unambiguous form, or (b) documenting that only the torch.linalg.cross semantic (dim=-1) is supported.
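To make the divergence concrete, here's a minimal repro sketch (plain tensors, no sharding, just to show the legacy dim auto-detection):

```python
import torch

a = torch.randn(3, 5)  # the size-3 axis is dim 0, not the last dim
b = torch.randn(3, 5)

# Legacy torch.cross auto-detects the first dimension of size 3:
legacy = torch.cross(a, b)                    # computed along dim=0
explicit = torch.linalg.cross(a, b, dim=0)    # matches the legacy behavior
assert torch.allclose(legacy, explicit)

# A wrapper defaulting to dim=-1 targets the size-5 axis instead;
# torch.linalg.cross(a, b) would raise here because dim -1 isn't size 3.
```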
if not isinstance(input_tensor, ShardTensor) or not isinstance(
    other_tensor, ShardTensor
):
    raise RuntimeError(
        "torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
    )
Hardcoded error message for both torch.cross and torch.linalg.cross
The error string says "torch.linalg.cross with ShardTensor inputs…" but this wrapper is registered for both torch.linalg.cross (line 1060) and torch.cross (line 1062). When the handler fires via torch.cross, the message will mislead users.
  if not isinstance(input_tensor, ShardTensor) or not isinstance(
      other_tensor, ShardTensor
  ):
-     raise RuntimeError(
-         "torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
-     )
+     raise RuntimeError(
+         f"{func.__module__}.{func.__name__} with ShardTensor inputs requires both arguments to be ShardTensor."
+     )
Blaa, some stuff got messed up and it's not quite ready for review. I'll fix tonight...
def _is_sharded_tensor(tensor: torch.Tensor) -> bool:
    return hasattr(tensor, "_spec") and hasattr(type(tensor), "from_local")
@coreyjadams is this best-practices for type-narrowing ShardTensor?
Feels like isinstance(tensor, ShardTensor) would be better, unless there's something I'm missing?
(And if this does get replaced with isinstance, then this can be inlined rather than keeping a separate function.)
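Something like this, as a sketch (the ShardTensor import path is assumed):

```python
import torch

from physicsnemo.distributed import ShardTensor  # assumed import path

def _is_sharded_tensor(tensor: torch.Tensor) -> bool:
    # isinstance narrows the type explicitly instead of duck-typing on
    # private attributes; simple enough that callers could inline it.
    return isinstance(tensor, ShardTensor)
```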
)

def _mesh_to_mode(mesh: Mesh, *, mesh_tensor_mode: str, mesh_shard_device_mesh) -> Mesh:
Docstrings needed (here and in the functions above), e.g.:
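A minimal sketch for _mesh_to_mode (parameter descriptions are guesses from the signature, not verified against the implementation):

```python
def _mesh_to_mode(mesh: Mesh, *, mesh_tensor_mode: str, mesh_shard_device_mesh) -> Mesh:
    """Return a copy of ``mesh`` with its data converted to the given tensor mode.

    Parameters
    ----------
    mesh : Mesh
        Mesh whose point/cell data should be converted.
    mesh_tensor_mode : str
        Target tensor mode for the mesh data.
    mesh_shard_device_mesh
        Device mesh used when converting tensors to ``ShardTensor``.
    """
```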
@pytest.fixture(params=_MESH_TENSOR_MODES)
def mesh_tensor_mode(request):
@pytest.fixture
def mesh_shard_device_mesh(mesh_tensor_mode, _single_rank_dist_group):
Type hints + docstrings needed
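For instance, a hedged sketch for the first fixture (the return type is a guess):

```python
import pytest

@pytest.fixture(params=_MESH_TENSOR_MODES)
def mesh_tensor_mode(request: pytest.FixtureRequest) -> str:
    """Tensor mode under test, parametrized over ``_MESH_TENSOR_MODES``."""
    return request.param
```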
    n_cells: int,
    point_placement: Replicate | Shard,
    cell_placement: Replicate | Shard,
):
| r"""Create zeros matching a tensor's device/mesh semantics. | ||
|
|
||
| For ``ShardTensor`` inputs this returns a replicated ``ShardTensor`` on the | ||
| same mesh. For regular tensors this falls back to ``torch.zeros`` on the |
Given that this is part of the public API, please add a complete NumPy-style docstring, e.g.:
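A sketch along these lines (parameter semantics inferred from the surrounding diff; the dtype default is an assumption):

```python
def replicated_zeros_like(
    tensor: torch.Tensor,
    shape: Sequence[int] | torch.Size,
    *,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    r"""Create zeros matching a tensor's device/mesh semantics.

    Parameters
    ----------
    tensor : torch.Tensor
        Reference tensor supplying device (and, when sharded, mesh) semantics.
    shape : Sequence[int] or torch.Size
        Shape of the returned tensor.
    dtype : torch.dtype, optional
        Output dtype; defaults to ``tensor.dtype``.

    Returns
    -------
    torch.Tensor
        Zero-filled tensor of ``shape``; a replicated ``ShardTensor`` on the
        same mesh when ``tensor`` is a ``ShardTensor``.
    """
```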
- src_data=cell_values[cell_indices],
+ src_data=cell_values.unsqueeze(1)
+     .expand(-1, n_vertices_per_cell, *cell_values.shape[1:])
+     .reshape(-1, *cell_values.shape[1:]),
Why is Mesh.cell_data_to_point_data() updated, but Mesh.point_data_to_cell_data() is not? The asymmetry seems suspect: it seems like either both need updates for ShardTensor or neither does, but I might be missing something?
Or is this just an unrelated change (and if so, what motivated it)?
It's because we needed a ShardTensor to index, hence the change. In point_data_to_cell_data we use cells to index, but that is already a ShardTensor.
def _convert_data_for_mode(
    data: dict[str, torch.Tensor] | None,
This function incorrectly:
a) accepts dict[str, torch.Tensor] instead of TensorDict, and
b) iterates only over top-level keys.
Combined with downstream uses that pass in a TensorDict, this causes issues with hierarchical dictionaries (specifically, nested values will remain plain torch.Tensor).
Consider using TensorDict.apply() instead, for example:
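A sketch of the apply-based version (`_convert_tensor` is a hypothetical per-leaf helper, not from the PR):

```python
from tensordict import TensorDict

def _convert_data_for_mode(data: TensorDict | None, mesh_tensor_mode: str) -> TensorDict | None:
    if data is None:
        return None
    # TensorDict.apply recurses into nested TensorDicts, so hierarchical
    # entries are converted too; a top-level key loop would miss them.
    # _convert_tensor: hypothetical per-leaf conversion helper.
    return data.apply(lambda t: _convert_tensor(t, mesh_tensor_mode))
```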
coreyjadams left a comment
Hi @loliverhennigh - I have left a number of comments on the PR. I think there are some updates to make, though I'm happy to see the changes overall are not too extensive. I think there are some design choices that cut against the philosophy of ShardTensor that we should tweak, but it doesn't look like it's going to be too much.
Happy to discuss this in as much detail as you'd like offline!
return _ToTorchTensor.apply(self, grad_placements)

def new_replicated_zeros(
Adding this function as an API call on ShardTensor, vs. supporting the underlying dispatch call, has to have a really good motivation.
What is the value of this vs. supporting the backend of torch.zeros_like(a) when a is a ShardTensor? And, in fact, I think that should already work?
return ShardTensor.from_local(
    local,
    self._spec.mesh,
    [Replicate() for _ in range(self._spec.mesh.ndim)],
    sharding_shapes="infer",
This is unusual here: typically I'd object loudly to passing "infer" as sharding shapes, since that can trigger a blocking allreduce and is a major perf headache. But you're replicating on all ranks, so I don't think I understand this function's role, really.
def replicated_zeros_like(
    tensor: torch.Tensor,
    shape: Sequence[int] | torch.Size,
    *,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
I still think it makes more sense to support torch.zeros_like(a) for a as a replicated tensor.
tl;dr: the ShardTensor design philosophy is to make zero code changes on the user side whenever possible, so we support torch calls on ShardTensors first and foremost rather than introduce new API. Is it possible to implement your work without this?
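A rough sketch of what the dispatch-first version could look like (the registration hook name is assumed, not a verified physicsnemo API):

```python
import torch

def _zeros_like_wrapper(func, types, args, kwargs):
    # Build zeros from the local shard, then rewrap with the source's
    # mesh/placements so torch.zeros_like(a) "just works" for users.
    src = args[0]
    local = torch.zeros_like(src.to_local(), **(kwargs or {}))
    return ShardTensor.from_local(local, src._spec.mesh, src._spec.placements)

# Registration hook below is an assumed name:
# ShardTensor.register_function_handler(torch.zeros_like, _zeros_like_wrapper)
# User code then stays plain torch: b = torch.zeros_like(a)
```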
def _cross_wrapper(func, types, args, kwargs):
    if kwargs is None:
        kwargs = {}

    if kwargs.get("out", None) is not None:
        raise RuntimeError("torch.linalg.cross(out=...) is not supported for ShardTensor.")
If this is a function overload, it's in the wrong place. shard_tensor.py is for the core tensor object only; there is a subfolder for ops.
raise RuntimeError(
    "torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
)
It looks like what we want to be doing is implementing a wrapper for torch.linalg.cross on ShardTensor objects. There is no need to check that all objects are ShardTensor at this time.
- aggregated_data = torch.zeros((n_dst, *data_shape), dtype=dtype, device=device)
+ aggregated_data = _replicated_zeros_like(
+     src_data, (n_dst, *data_shape), dtype
+ )
The distributed-friendly update to make here is, if mesh can support it, to port torch.zeros away from this explicit shape and onto torch.zeros_like: whatever is building data_shape as input, we build zeros_like(that_object), and then mesh can work on single-device and sharded inputs too. For example:
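A minimal sketch of that idea (`dst_template` is a hypothetical stand-in for whatever upstream object already carries the (n_dst, *data_shape) shape):

```python
import torch

dst_template = torch.empty(4, 3)  # plain tensor here; a ShardTensor in the sharded path
aggregated_data = torch.zeros_like(dst_template)
# zeros_like dispatches on its input's type, so the same line yields a plain
# tensor on a single device and a sharded/replicated tensor otherwise.
```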
- weights = torch.ones(len(src_to_dst_mapping), dtype=dtype, device=device)
+ if _is_sharded_tensor(src_data):
+     weights = torch.ones_like(src_to_dst_mapping, dtype=dtype)
+ else:
+     weights = torch.ones(len(src_to_dst_mapping), dtype=dtype, device=device)
This is much closer to the "right" way for domain parallelism, but in fact we can probably consolidate to just "ones_like" for all paths.
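E.g., a consolidated sketch, assuming src_to_dst_mapping is a 1-D index tensor (plain or sharded):

```python
import torch

src_to_dst_mapping = torch.arange(8)  # stand-in; a ShardTensor in the sharded path
# ones_like dispatches on the input's type, so one line covers both paths:
weights = torch.ones_like(src_to_dst_mapping, dtype=torch.float32)
```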
-     dtype=dtype,
-     device=device,
- )
+ aggregated_data = _replicated_zeros_like(src_data, (n_dst, *data_shape), dtype)
Same here re: zeros_like.
- weight_sums = torch.zeros(n_dst, dtype=dtype, device=device)
+ weight_sums = _replicated_zeros_like(src_data, (n_dst,), dtype)
Same here re: zeros_like.
converted = self.cell_data.apply(
    lambda cell_values: scatter_aggregate(
        src_data=cell_values[cell_indices],
Indexing a ShardTensor with a ShardTensor index should work fine?
cell_indices is not a ShardTensor the way it was before; it was coming from just the torch.arange function.
\blossom-ci
\blossom-ci
Hi @coreyjadams - quick follow-up on this PR. The latest GitHub Actions checks are green, and I re-triggered Blossom with \blossom-ci.
\blossom-ci
/blossom-ci
I took a look at trying to update the tests to align them with the domain parallelism tests; the test coverage here was only ~60 tests, all on CPU, for domain parallelism. It isn't aligned with the torchrun syntax that domain parallel tests expect (look at /test/plugins/distributed_fixtures.py for more info). The cross operation looks like it's pretty much implemented OK, but the tests aren't quite aligned with the way the op testing is done in the domain parallelism suite, so we're falling short there too.

I just don't think there is sufficient test coverage to say whether we've achieved domain parallelism on PhysicsNeMo mesh, so I don't think this is ready to go for the RC. There are ~2000 tests for mesh, and while we don't need all of them, being able to test against most of the core functionality + things we would want in datapipes, models, preprocessing, etc. for domain parallelism is, I think, a prerequisite for merge here. 60 tests (3% coverage max, if we're generous), all on CPU, is probably short, I think?

@peterdsharpe Any suggestions of what core mesh functionality should be prioritized for testing in a distributed mode? I'm thinking we need generic mesh operations (slicing, scatter_ops, vertex_to_cell and the opposite) as well as, probably, manifold projection operations. What else?

I think we'd also want good distributed-mesh-init testing: how to initialize a mesh with sharded tensors and make sure it works, gradients flow, and we can resize / reshape / redistribute it, etc. Core tensor ops are covered with ShardTensor already, so just mesh-specific things are fine.
Yeah, I think these are all great calls @coreyjadams. Re Mesh testing, the high-risk areas where it might be good to add testing coverage for a sharded Mesh are as follows:
If it's possible to add coverage for these tests in a relatively low-code-duplication way, that would be great; we definitely don't want a full copy of this test suite that will inevitably fall out of sync. Maybe a
/blossom-ci |
This reverts commit 8bc39bf.
)

def _surface_mesh() -> Mesh:
Would recommend using the meshes from physicsnemo.mesh.primitives here, which are intended for reuse and in general can have much more interesting behavior (to more easily catch edge cases).
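Something like this, as a sketch (the exact factory names in physicsnemo.mesh.primitives are assumed, not verified):

```python
from physicsnemo.mesh import primitives  # module named in the comment above

def _surface_mesh() -> Mesh:
    # hypothetical primitive factory; use whichever surface primitive exists
    return primitives.sphere()
```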
/blossom-ci |
PhysicsNeMo Pull Request
Description