|
5 | 5 | */ |
6 | 6 |
|
#include "ddptensor/IO.hpp"

#include "ddptensor/DDPTensorImpl.hpp"
#include "ddptensor/Factory.hpp"
#include "ddptensor/SetGetItem.hpp"
#include "ddptensor/Transceiver.hpp"
#include "ddptensor/TypeDispatch.hpp"

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <stdexcept>
#include <vector>

namespace py = pybind11;
| 17 | + |
| 18 | +// *************************************************************************** |
| 19 | + |
| 20 | +/// @brief form a ddptensor from local numpy arrays (inplace - no copy) |
| 21 | +struct DeferredFromLocal : public Deferred { |
| 22 | + py::array _npa; |
| 23 | + |
| 24 | + DeferredFromLocal() = default; |
| 25 | + DeferredFromLocal(py::array npa) |
| 26 | + : Deferred(getDTypeId(npa.dtype()), |
| 27 | + {npa.shape(), npa.shape() + npa.ndim()}, 0, true), |
| 28 | + _npa(npa) {} |
| 29 | + |
| 30 | + // get our DTypeId from py::dtype |
| 31 | + DTypeId getDTypeId(const py::dtype &dtype) { |
| 32 | + auto bw = dtype.itemsize(); |
| 33 | + auto kind = dtype.kind(); |
| 34 | + switch (kind) { |
| 35 | + case 'i': |
| 36 | + switch (bw) { |
| 37 | + case 1: |
| 38 | + return INT8; |
| 39 | + case 2: |
| 40 | + return INT16; |
| 41 | + case 4: |
| 42 | + return INT32; |
| 43 | + case 8: |
| 44 | + return INT64; |
| 45 | + }; |
| 46 | + case 'f': |
| 47 | + switch (bw) { |
| 48 | + case 4: |
| 49 | + return FLOAT32; |
| 50 | + case 8: |
| 51 | + return FLOAT64; |
| 52 | + }; |
| 53 | + }; |
| 54 | + throw std::runtime_error("Unsupported dtype"); |
| 55 | + } |
| 56 | + |
| 57 | + void run() override { |
| 58 | + auto _strides = _npa.strides(); |
| 59 | + auto shape = _npa.shape(); |
| 60 | + auto data = _npa.mutable_data(); |
| 61 | + auto dtype = _npa.dtype(); |
| 62 | + auto ndim = _npa.ndim(); |
| 63 | + auto eSz = dtype.itemsize(); |
| 64 | + |
| 65 | + // py::array stores strides in bytes, not elements |
| 66 | + std::vector<intptr_t> strides(ndim); |
| 67 | + for (auto i = 0; i < ndim; ++i) { |
| 68 | + strides[i] = _strides[i] / eSz; |
| 69 | + } |
| 70 | + |
| 71 | + auto res = mk_tnsr(getDTypeId(dtype), ndim, shape, strides.data(), data); |
| 72 | + // make sure we do not delete numpy's memory before the numpy array is dead |
| 73 | + // notice: py::objects have ref-counting) |
| 74 | + res->set_base(new SharedBaseObject<py::object>(_npa)); |
| 75 | + set_value(std::move(res)); |
| 76 | + } |
| 77 | + |
| 78 | + bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc, |
| 79 | + jit::DepManager &dm) override { |
| 80 | + return true; |
| 81 | + } |
| 82 | + |
| 83 | + FactoryId factory() const { return F_FROMLOCALS; } |
| 84 | + |
| 85 | + template <typename S> void serialize(S &ser) {} |
| 86 | +}; |
| 87 | + |
13 | 88 | GetItem::py_future_type IO::to_numpy(const ddptensor &a) { |
14 | 89 | assert(!getTransceiver()->is_cw() || getTransceiver()->rank() == 0); |
15 | 90 | return GetItem::gather(a, getTransceiver()->is_cw() ? 0 : REPLICATED); |
16 | 91 | } |
| 92 | + |
| 93 | +ddptensor *IO::from_locals(const std::vector<py::array> &a) { |
| 94 | + assert(a.size() == 1); |
| 95 | + return new ddptensor(defer<DeferredFromLocal>(a.front())); |
| 96 | +} |
| 97 | + |
| 98 | +FACTORY_INIT(DeferredFromLocal, F_FROMLOCALS); |
0 commit comments