Improve DLPack-compatible array imports #3495
XXXXRT666 wants to merge 5 commits into ml-explore:main from […]
Conversation
zcbenz left a comment:
I'm good with most changes in this PR, but we should not silently convert GPU tensors to CPU: it would be surprising behavior for most users, and it would become a breaking change once we support taking GPU tensors in the future.
I think a better choice is to simply throw an error when a non-CPU tensor is passed.
That makes sense. I'll update this PR so it no longer silently converts GPU tensors to CPU.
After this PR lands, I can add proper GPU tensor support in a follow-up PR. At that point we should be able to support zero-copy when the producer exposes a compatible GPU buffer, instead of falling back to an implicit CPU copy.
Force-pushed 2a15f3d to 5adb198
Force-pushed c43a05b to 8101c39
Force-pushed 8101c39 to 90f9a24
Updated this PR to use […]. The non-CPU guard now lives in […].
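The guard itself isn't quoted in this thread; a minimal sketch of the kind of check being discussed, using an illustrative helper name rather than the PR's actual code, could look like:

```python
# Illustrative only: reject DLPack producers whose buffer is not on the CPU
# instead of silently copying. Device type 1 is kDLCPU in the DLPack spec.
kDLCPU = 1


def ensure_cpu_dlpack(obj):
    """Raise if the producer's buffer lives on a non-CPU device."""
    device_type, _device_id = obj.__dlpack_device__()
    if device_type != kDLCPU:
        raise ValueError(
            "importing non-CPU DLPack tensors is not supported yet; "
            "move the tensor to the CPU first"
        )
    return obj
```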
Proposed changes
This PR improves `mx.array` import behavior for DLPack-compatible inputs.

Changes include:
- Changes the `mx.array` constructor signature to use a `DLPackCompatible` protocol instead of naming `numpy.ndarray` directly (see the protocol sketch below).
- Covers both the `mx.array(...)` and operator argument conversion paths.
- […] `torch.tensor(a)` and then move the tensor to CPU before importing it back into MLX.
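The protocol definition itself is not shown in this description; a minimal sketch of what such a structural `DLPackCompatible` protocol can look like (the method signatures follow the standard DLPack Python protocol and may differ in detail from the PR) is:

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class DLPackCompatible(Protocol):
    """Any object exposing the standard DLPack exchange methods."""

    def __dlpack__(self, *, stream: Any = None) -> Any:
        """Return a PyCapsule wrapping a DLManagedTensor."""
        ...

    def __dlpack_device__(self) -> Tuple[int, int]:
        """Return the (device_type, device_id) pair of the underlying buffer."""
        ...
```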
PyTorch supports the DLPack import path directly: `torch.tensor(data)` checks for `__dlpack__` and routes through `torch.utils.dlpack.from_dlpack` in `tensor_new.cpp`.
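As a usage sketch (assuming the import behavior this PR adds, and a CPU tensor so the non-CPU guard does not trigger), the round trip works in both directions:

```python
import mlx.core as mx
import torch

# CPU torch tensor -> MLX: mx.array(...) consumes the tensor's DLPack capsule.
t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
a = mx.array(t)

# MLX -> torch: torch.from_dlpack accepts any __dlpack__-capable producer.
t2 = torch.from_dlpack(a)
print(t2.shape)  # torch.Size([2, 3])
```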
This PR intentionally does not add direct Metal buffer import or zero-copy semantics. That path depends on nanobind support that is not upstream yet: ndarray import needs to request the producer's DLPack device with `dl_device` and `copy=False` for non-CPU devices, and MLX needs access to the raw DLPack `data` handle and `byte_offset` to wrap the Metal allocation directly.
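For reference, under the DLPack 2023.12 protocol a consumer would request such a device-native, no-copy capsule roughly like this (illustrative only; whether a given producer honors `dl_device` and `copy` depends on its version, and `kDLMetal` is device type 8 in the DLPack `DLDeviceType` enum):

```python
kDLMetal = 8  # DLPack DLDeviceType value for Apple Metal


def request_metal_capsule(producer, device_id=0):
    """Ask the producer for a DLPack capsule on its own Metal device,
    refusing an implicit copy; producers that cannot comply should raise."""
    return producer.__dlpack__(
        max_version=(1, 0), dl_device=(kDLMetal, device_id), copy=False
    )
```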
Checklist

Put an x in the boxes that apply.

- I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes