
Improve DLPack-compatible array imports#3495

Open
XXXXRT666 wants to merge 5 commits into ml-explore:main from XXXXRT666:mps-dlpack-cpu-fallback

Conversation

Contributor

@XXXXRT666 XXXXRT666 commented May 8, 2026

Proposed changes

This PR improves mx.array import behavior for DLPack-compatible inputs.

Changes include:

  • Update the mx.array constructor signature to use a DLPackCompatible protocol instead of naming numpy.ndarray directly.
  • Reject non-CPU DLPack inputs consistently, including both mx.array(...) and operator argument conversion paths.
  • Add test coverage for PyTorch MPS tensors to ensure non-CPU DLPack inputs raise an explicit error.
  • Update the PyTorch interoperability docs to use DLPack directly via torch.tensor(a) and then move the tensor to CPU before importing it back into MLX.

PyTorch supports the DLPack import path directly: torch.tensor(data) checks for __dlpack__ and routes through torch.utils.dlpack.from_dlpack in tensor_new.cpp.
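The import path described above can be sketched in pure Python. This is a minimal, hypothetical producer (not part of MLX or PyTorch): it only exposes the `__dlpack__` / `__dlpack_device__` protocol, and the consumer (`np.from_dlpack` here, `torch.tensor` / `torch.utils.dlpack.from_dlpack` in PyTorch) imports the data through that protocol without ever touching the producer's concrete type.

```python
import numpy as np

class DLPackProducer:
    """Hypothetical DLPack-compatible wrapper; delegates to a NumPy array.

    A consumer that understands DLPack can import from this object even
    though it is not an ndarray, a tensor, or any other known array type.
    """

    def __init__(self, data):
        self._arr = np.asarray(data)

    def __dlpack__(self, **kwargs):
        # Hand out a DLPack capsule describing the underlying buffer;
        # forward consumer-negotiated kwargs (e.g. stream) unchanged.
        return self._arr.__dlpack__(**kwargs)

    def __dlpack_device__(self):
        # (device_type, device_id); NumPy reports (1, 0), i.e. kDLCPU.
        return self._arr.__dlpack_device__()

# The consumer only sees the DLPack protocol, not the wrapper's internals.
imported = np.from_dlpack(DLPackProducer([1.0, 2.0, 3.0]))
```

The same negotiation is what lets `mx.array` accept any DLPack-compatible input instead of special-casing `numpy.ndarray` in its signature.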

This PR intentionally does not add direct Metal buffer import or zero-copy semantics. That path depends on nanobind support that is not upstream yet: ndarray import needs to request the producer's DLPack device with dl_device and copy=False for non-CPU devices, and MLX needs access to the raw DLPack data handle and byte_offset to wrap the Metal allocation directly.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Collaborator

@zcbenz zcbenz left a comment


I'm good with most changes in this PR, but we should not silently convert GPU tensors to CPU: it would be surprising behavior for most users, and it would become a breaking change once we support taking GPU tensors in the future.

I think a better choice is to simply throw an error when a non-CPU tensor is passed.

@XXXXRT666
Contributor Author

> I think a better choice is to simply throw an error when a non-CPU tensor is passed.

That makes sense. I'll update this PR so it no longer silently converts GPU tensors to CPU.

> I'm good with most changes in this PR but we should not silently convert GPU tensors to CPU, because it would be a surprising behavior for most users, and would become a breaking change when we support taking GPU tensors in future.

After this PR lands, I can add proper GPU tensor support in a follow up PR. At that point we should be able to support zero-copy where the producer exposes a compatible GPU buffer, instead of falling back to an implicit CPU copy.

@XXXXRT666 XXXXRT666 force-pushed the mps-dlpack-cpu-fallback branch from 2a15f3d to 5adb198 on May 9, 2026 02:46
@XXXXRT666 XXXXRT666 changed the title from "Support DLPack-compatible inputs via CPU fallback" to "Improve DLPack-compatible array imports" on May 9, 2026
Comment thread python/src/convert.cpp Outdated
@XXXXRT666 XXXXRT666 force-pushed the mps-dlpack-cpu-fallback branch 2 times, most recently from c43a05b to 8101c39 on May 9, 2026 15:04
@XXXXRT666 XXXXRT666 force-pushed the mps-dlpack-cpu-fallback branch from 8101c39 to 90f9a24 on May 9, 2026 15:14
@XXXXRT666
Contributor Author

XXXXRT666 commented May 9, 2026

Updated this PR to use nb::ndarray::device_type() after casting once to a generic contiguous ndarray.

The non-CPU guard now lives in nd_array_to_mlx(), so all CPU-copy paths reject non-CPU DLPack inputs consistently. That covers both mx.array(torch_mps_tensor) and operator argument conversion such as mx.array([1]) + torch.tensor([2]).to("mps"). I also removed the float64 import change from this PR, since Metal does not support float64 computation.
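The guard's behavior can be sketched in Python (the real check lives in C++ via nb::ndarray::device_type() in nd_array_to_mlx()). Here kDLCPU and kDLMetal are the DLPack device-type enum values, and FakeMetalTensor is a hypothetical stand-in for a PyTorch MPS tensor so the sketch runs without a GPU; import_cpu_dlpack is an illustrative helper, not an MLX API.

```python
import numpy as np

kDLCPU = 1    # DLPack device-type enum: CPU
kDLMetal = 8  # DLPack device-type enum: Metal

def import_cpu_dlpack(obj):
    """Sketch of the guard: reject producers whose DLPack device is not CPU."""
    device_type, _device_id = obj.__dlpack_device__()
    if device_type != kDLCPU:
        raise ValueError(
            "Cannot import a non-CPU DLPack input; move it to CPU first."
        )
    return np.from_dlpack(obj)

class FakeMetalTensor:
    """Hypothetical stand-in for a torch MPS tensor: reports a Metal device."""
    def __dlpack_device__(self):
        return (kDLMetal, 0)

cpu_result = import_cpu_dlpack(np.arange(3))  # CPU input: imported normally

try:
    import_cpu_dlpack(FakeMetalTensor())      # non-CPU input: explicit error
except ValueError as exc:
    print(exc)
```

Because the check sits on the shared conversion path, both the constructor call and the operator-conversion case fail the same way, matching the behavior this PR tests for.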

Collaborator

@zcbenz zcbenz left a comment


Nice work, thanks!

Comment thread python/src/convert.cpp Outdated