70 changes: 69 additions & 1 deletion docs/src/usage/numpy.rst
@@ -76,6 +76,8 @@ PyTorch
-------

PyTorch supports DLPack inputs and can import MLX arrays directly.
MLX can also import PyTorch tensors through DLPack with ``mx.array`` or
``mx.from_dlpack``.

.. code-block:: python

@@ -84,7 +86,73 @@ PyTorch supports DLPack inputs and can import MLX arrays directly.

a = mx.arange(3)
b = torch.tensor(a)
c = mx.array(b.cpu())
c = mx.array(b)

Creating an MLX array from a CPU tensor copies the data into MLX-owned storage.
The arrays do not share memory:

.. code-block:: python

b = torch.arange(3)
c = mx.array(b)

b += 10
print(c.tolist()) # [0, 1, 2]

Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to
``mx.array``, or to ``mx.from_dlpack`` with ``copy=None`` or ``copy=False``,
MLX imports the underlying Metal buffer without copying it. The PyTorch tensor
and the MLX array then share the same storage. MLX arrays exported to PyTorch
through DLPack are likewise shared without a copy.

Since the buffer is shared across frameworks, synchronization has to be managed
explicitly. After PyTorch writes to an MPS tensor, call
``torch.mps.synchronize()`` before reading the shared data from MLX. After MLX
writes to the shared array, call ``mx.eval`` on the MLX result before reading
the shared data from PyTorch. Without these synchronization points, the
consuming framework may read the shared buffer before the producer has finished
writing and observe stale data.

.. code-block:: python

b = torch.arange(3, device="mps", dtype=torch.float32)
torch.mps.synchronize()
c = mx.array(b) # zero-copy Metal DLPack import

b.add_(10)
torch.mps.synchronize()
print(c.tolist()) # [10.0, 11.0, 12.0]

Updates made by MLX can also be observed from PyTorch after the MLX computation
has been evaluated:

.. code-block:: python

b = torch.arange(3, device="mps", dtype=torch.float32)
torch.mps.synchronize()
c = mx.array(b)

c += 10
mx.eval(c)
print(b.cpu()) # tensor([10., 11., 12.])

For MLX arrays exported to PyTorch, the sharing is tied to the buffer that was
exported. MLX updates made after the export may rebind the MLX array to a new
buffer, while the PyTorch tensor continues to reference the originally exported
buffer.
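This rebinding behavior can be sketched framework-agnostically. The ``Buffer``
and ``Array`` classes below are hypothetical stand-ins, not MLX or PyTorch
APIs: an export captures a reference to the array's current buffer, and a later
update rebinds the array to fresh storage while the export keeps the original
buffer.

```python
# Hypothetical stand-ins for illustration; not MLX or PyTorch APIs.
class Buffer:
    def __init__(self, data):
        self.data = list(data)


class Array:
    def __init__(self, data):
        self.buffer = Buffer(data)

    def export(self):
        # An exporter (e.g. a DLPack consumer) holds the current buffer.
        return self.buffer

    def add(self, x):
        # An update rebinds the array to a newly allocated buffer.
        self.buffer = Buffer(v + x for v in self.buffer.data)


a = Array([0, 1, 2])
exported = a.export()
a.add(10)
print(a.buffer.data)   # [10, 11, 12]
print(exported.data)   # [0, 1, 2] -- still the buffer captured at export time
```

The exported view stays valid, so consumers never observe a dangling buffer,
but they also stop seeing updates once the producer rebinds.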

Use ``mx.from_dlpack`` when you need to control the copy behavior. Specifying
``copy=True`` asks MLX to create a new array instead of sharing the Metal
buffer:

.. code-block:: python

b = torch.arange(3, device="mps", dtype=torch.float32)
torch.mps.synchronize()
c = mx.from_dlpack(b, copy=True)

b.add_(10)
torch.mps.synchronize()
print(c.tolist()) # [0.0, 1.0, 2.0]

JAX
---
4 changes: 4 additions & 0 deletions mlx/allocator.h
@@ -21,6 +21,10 @@ class MLX_API Buffer {
// Get the raw data pointer from the buffer
void* raw_ptr();

// Whether raw_ptr() can return a host-accessible pointer without moving or
// copying the buffer.
bool is_host_accessible() const;

// Get the buffer pointer from the buffer
const void* ptr() const {
return ptr_;
4 changes: 4 additions & 0 deletions mlx/backend/cuda/allocator.cpp
@@ -416,6 +416,10 @@ void* Buffer::raw_ptr() {
return cbuf.data;
}

bool Buffer::is_host_accessible() const {
return true;
}

} // namespace allocator

size_t get_active_memory() {
18 changes: 17 additions & 1 deletion mlx/backend/metal/allocator.cpp
@@ -24,7 +24,23 @@ void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<MTL::Buffer*>(ptr_)->contents();
auto* buf = static_cast<MTL::Buffer*>(ptr_);
auto* contents = buf->contents();
if (!contents && buf->length() > 0) {
throw std::runtime_error(
"[metal::Buffer::raw_ptr] Cannot access Metal buffer contents on the "
"host. The buffer is not CPU-addressable, for example because it uses "
"private storage.");
}
return contents;
}

bool Buffer::is_host_accessible() const {
if (!ptr_) {
return true;
}
auto* buf = static_cast<MTL::Buffer*>(ptr_);
return buf->storageMode() != MTL::StorageModePrivate;
}

} // namespace allocator
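The Metal storage-mode check above can be sketched in a few lines of Python.
The ``StorageMode`` enum and ``Buf`` class are illustrative stand-ins for the
Metal constants and ``MTL::Buffer``: a null buffer is treated as trivially
host-accessible, and otherwise only non-private storage can be read from the
CPU.

```python
from enum import Enum


class StorageMode(Enum):
    # Illustrative stand-ins for the Metal storage modes.
    SHARED = 0
    MANAGED = 1
    PRIVATE = 2


class Buf:
    def __init__(self, storage_mode):
        self.storage_mode = storage_mode


def is_host_accessible(buf):
    # A null buffer is trivially host-accessible, matching the
    # `!ptr_` early return in the C++ implementation above.
    if buf is None:
        return True
    return buf.storage_mode is not StorageMode.PRIVATE


print(is_host_accessible(None))                      # True
print(is_host_accessible(Buf(StorageMode.SHARED)))   # True
print(is_host_accessible(Buf(StorageMode.PRIVATE)))  # False
```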
4 changes: 4 additions & 0 deletions mlx/backend/no_gpu/allocator.cpp
@@ -76,6 +76,10 @@ void* Buffer::raw_ptr() {
return static_cast<size_t*>(ptr_) + 1;
}

bool Buffer::is_host_accessible() const {
return true;
}

Buffer CommonAllocator::malloc(size_t size) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
10 changes: 10 additions & 0 deletions mlx/ops.cpp
@@ -294,6 +294,16 @@ array copy(array a, StreamOrDevice s /* = {} */) {
{std::move(a)});
}

array copy_to_new_buffer(array a, StreamOrDevice s /* = {} */) {
auto copied_shape = a.shape(); // |a| will be moved
auto dtype = a.dtype();
return array(
std::move(copied_shape),
dtype,
std::make_shared<AsType>(to_stream(s), dtype),
{std::move(a)});
}

array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
return array(
vals.shape(),
3 changes: 3 additions & 0 deletions mlx/ops.h
@@ -60,6 +60,9 @@ MLX_API array as_strided(
/** Copy another array. */
MLX_API array copy(array a, StreamOrDevice s = {});

/** Copy another array into newly allocated storage. */
MLX_API array copy_to_new_buffer(array a, StreamOrDevice s = {});

/** Fill an array of the given shape with the given value(s). */
MLX_API array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
MLX_API array full(Shape shape, array vals, StreamOrDevice s = {});
16 changes: 15 additions & 1 deletion mlx/utils.cpp
@@ -7,6 +7,7 @@
#include <vector>

#include "mlx/dtype_utils.h"
#include "mlx/ops.h"
#include "mlx/types/limits.h"
#include "mlx/utils.h"

@@ -210,6 +211,19 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
return os;
}

array host_accessible_array(array a) {
a.eval();
a.wait();
if (a.buffer().is_host_accessible()) {
return a;
}
auto out = copy_to_new_buffer(std::move(a), Device::gpu);
out.eval();
out.wait();
out.detach();
return out;
}

namespace {

template <typename T>
@@ -277,7 +291,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
}

std::ostream& operator<<(std::ostream& os, array a) {
a.eval();
a = host_accessible_array(std::move(a));
dispatch_all_types(a.dtype(), [&](auto type_tag) {
print_array<MLX_GET_TYPE(type_tag)>(os, a);
});
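The fallback that printing now relies on can be modeled compactly in Python.
``FakeArray`` is a hypothetical stand-in, not an MLX type: after evaluation, a
host-accessible array is returned unchanged, while a GPU-only array is first
copied into CPU-addressable storage.

```python
# Illustrative model of host_accessible_array; hypothetical types, not MLX.
class FakeArray:
    def __init__(self, data, host_accessible):
        self.data = list(data)
        self.host_accessible = host_accessible
        self.evaluated = False

    def eval_and_wait(self):
        # Stands in for eval() followed by wait() in the C++ code above.
        self.evaluated = True


def host_accessible_array(a):
    a.eval_and_wait()
    if a.host_accessible:
        return a  # already CPU-addressable: no copy
    # Copy into newly allocated, host-accessible storage.
    out = FakeArray(a.data, host_accessible=True)
    out.eval_and_wait()
    return out


gpu_only = FakeArray([1, 2, 3], host_accessible=False)
out = host_accessible_array(gpu_only)
print(out.host_accessible, out.data)  # True [1, 2, 3]
```

Because already-accessible arrays are returned as-is, the copy is only paid on
buffers that are not CPU-addressable.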
2 changes: 2 additions & 0 deletions mlx/utils.h
@@ -65,6 +65,8 @@ MLX_API void set_printoptions(PrintOptions options);

MLX_API PrintFormatter& get_global_formatter();

MLX_API array host_accessible_array(array a);

/** Print the exception and then abort. */
MLX_API void abort_with_exception(const std::exception& error);

24 changes: 23 additions & 1 deletion python/src/array.cpp
@@ -496,7 +496,29 @@ void init_array(nb::module_& m) {
new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt));
}
})
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
.def(
"__dlpack__",
[](const mx::array& a,
nb::object,
nb::object,
nb::object dl_device,
nb::object) {
std::optional<int> dl_device_type;
if (!dl_device.is_none()) {
auto device = nb::cast<nb::tuple>(dl_device);
if (nb::len(device) != 2) {
throw nb::type_error(
"dl_device must be None or a tuple[int, int]");
}
dl_device_type = nb::cast<int>(device[0]);
}
return mlx_to_dlpack(a, dl_device_type);
},
nb::kw_only(),
"stream"_a = nb::none(),
"max_version"_a = nb::none(),
"dl_device"_a = nb::none(),
"copy"_a = nb::none())
.def(
"__dlpack_device__",
[](const mx::array& a) {
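The keyword handling in the new ``__dlpack__`` overload can be mirrored in
plain Python. ``parse_dl_device`` is a hypothetical helper for illustration:
``dl_device`` must be ``None`` or a two-element tuple, and only the device
type (the first element) is forwarded.

```python
def parse_dl_device(dl_device):
    """Mirror of the binding's validation (illustrative, not an MLX API):
    dl_device must be None or a (device_type, device_id) tuple, and only
    the device type is forwarded."""
    if dl_device is None:
        return None
    if not isinstance(dl_device, tuple) or len(dl_device) != 2:
        raise TypeError("dl_device must be None or a tuple[int, int]")
    return int(dl_device[0])


print(parse_dl_device(None))    # None
print(parse_dl_device((1, 0)))  # 1 (kDLCPU in the DLPack device enum)
```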
8 changes: 8 additions & 0 deletions python/src/buffer.h
@@ -91,6 +91,14 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) {
{
nb::gil_scoped_release nogil;
a.eval();
a.wait();
}
if (!a.buffer().is_host_accessible()) {
PyErr_SetString(
PyExc_BufferError,
"Cannot provide a buffer for an array whose storage is not "
"CPU-addressable.");
return -1;
}

std::vector<Py_ssize_t> shape(a.shape().begin(), a.shape().end());