Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit bb83030

Browse files
authored
introducing dt.from_local for single-process (#38)
* introducing dt.from_locals for single-process * get_local -> get_locals
1 parent 426d5b4 commit bb83030

File tree

11 files changed

+174
-26
lines changed

11 files changed

+174
-26
lines changed

ddptensor/spmd.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
from . import _ddptensor as _cdt
2+
from . import dtensor
23

34

45
def get_slice(obj, *args):
56
return _cdt._get_slice(obj._t, *args)
67

78

8-
def get_local(obj):
9-
return _cdt._get_local(obj._t, obj)
9+
def get_locals(obj):
10+
return _cdt._get_locals(obj._t, obj)
11+
12+
13+
def from_locals(objs):
14+
arg = objs if isinstance(objs, (list, tuple)) else [objs]
15+
return dtensor(_cdt._from_locals(arg))
1016

1117

1218
def gather(obj, root=_cdt._Ranks._REPLICATED):

src/DDPTensorImpl.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,33 @@ DDPTensorImpl::DDPTensorImpl(const int64_t *shape, uint64_t N, rank_type owner)
6262
assert(!_transceiver || _transceiver == getTransceiver());
6363
}
6464

65+
// from numpy
66+
DDPTensorImpl::DDPTensorImpl(DTypeId dtype, ssize_t ndims, const ssize_t *shape,
67+
const intptr_t *strides, void *data)
68+
: _owner(NOOWNER), _gShape(shape, shape + ndims),
69+
_lo_allocated(
70+
static_cast<uint64_t *>(calloc(ndims, sizeof_dtype(dtype)))),
71+
_lo_aligned(_lo_allocated),
72+
_lData(ndims, data, data, 0, reinterpret_cast<const intptr_t *>(shape),
73+
reinterpret_cast<const intptr_t *>(strides)),
74+
_dtype(dtype) {}
75+
76+
void DDPTensorImpl::set_base(const tensor_i::ptr_type &base) {
77+
_base = new SharedBaseObject<tensor_i::ptr_type>(base);
78+
}
79+
void DDPTensorImpl::set_base(BaseObj *obj) { _base = obj; }
80+
6581
DDPTensorImpl::~DDPTensorImpl() {
6682
if (!_base) {
6783
// FIXME it seems possible that halos get reallocated even when there
68-
// is a base _lhsHalo.freeData(); FIXME lhs and rhs can be identical
84+
// is a base
85+
if (_lhsHalo._allocated != _rhsHalo._allocated)
86+
_lhsHalo.freeData(); // lhs and rhs can be identical
6987
_lData.freeData();
7088
_rhsHalo.freeData();
7189
}
7290
free(_lo_allocated);
91+
delete _base;
7392
}
7493

7594
void *DDPTensorImpl::data() {
@@ -103,8 +122,9 @@ std::string DDPTensorImpl::__repr__() const {
103122
for (auto i = 0; i < nd; ++i)
104123
oss << _gShape[i] << (i == nd - 1 ? "" : ", ");
105124
oss << "), loff=(";
106-
for (auto i = 0; i < nd; ++i)
107-
oss << _lo_aligned[i] << (i == nd - 1 ? "" : ", ");
125+
if (_lo_aligned)
126+
for (auto i = 0; i < nd; ++i)
127+
oss << _lo_aligned[i] << (i == nd - 1 ? "" : ", ");
108128
oss << "), lsz=(";
109129
for (auto i = 0; i < nd; ++i)
110130
oss << _lData._sizes[i] << (i == nd - 1 ? "" : ", ");

src/IO.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,94 @@
55
*/
66

77
#include "ddptensor/IO.hpp"
8+
#include "ddptensor/DDPTensorImpl.hpp"
89
#include "ddptensor/Factory.hpp"
910
#include "ddptensor/SetGetItem.hpp"
1011
#include "ddptensor/Transceiver.hpp"
1112
#include "ddptensor/TypeDispatch.hpp"
1213

14+
#include <pybind11/numpy.h>
15+
#include <pybind11/pybind11.h>
16+
namespace py = pybind11;
17+
18+
// ***************************************************************************
19+
20+
/// @brief form a ddptensor from local numpy arrays (inplace - no copy)
21+
struct DeferredFromLocal : public Deferred {
22+
py::array _npa;
23+
24+
DeferredFromLocal() = default;
25+
DeferredFromLocal(py::array npa)
26+
: Deferred(getDTypeId(npa.dtype()),
27+
{npa.shape(), npa.shape() + npa.ndim()}, 0, true),
28+
_npa(npa) {}
29+
30+
// get our DTypeId from py::dtype
31+
DTypeId getDTypeId(const py::dtype &dtype) {
32+
auto bw = dtype.itemsize();
33+
auto kind = dtype.kind();
34+
switch (kind) {
35+
case 'i':
36+
switch (bw) {
37+
case 1:
38+
return INT8;
39+
case 2:
40+
return INT16;
41+
case 4:
42+
return INT32;
43+
case 8:
44+
return INT64;
45+
};
46+
case 'f':
47+
switch (bw) {
48+
case 4:
49+
return FLOAT32;
50+
case 8:
51+
return FLOAT64;
52+
};
53+
};
54+
throw std::runtime_error("Unsupported dtype");
55+
}
56+
57+
void run() override {
58+
auto _strides = _npa.strides();
59+
auto shape = _npa.shape();
60+
auto data = _npa.mutable_data();
61+
auto dtype = _npa.dtype();
62+
auto ndim = _npa.ndim();
63+
auto eSz = dtype.itemsize();
64+
65+
// py::array stores strides in bytes, not elements
66+
std::vector<intptr_t> strides(ndim);
67+
for (auto i = 0; i < ndim; ++i) {
68+
strides[i] = _strides[i] / eSz;
69+
}
70+
71+
auto res = mk_tnsr(getDTypeId(dtype), ndim, shape, strides.data(), data);
72+
// make sure we do not delete numpy's memory before the numpy array is dead
73+
// (notice: py::objects have ref-counting)
74+
res->set_base(new SharedBaseObject<py::object>(_npa));
75+
set_value(std::move(res));
76+
}
77+
78+
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
79+
jit::DepManager &dm) override {
80+
return true;
81+
}
82+
83+
FactoryId factory() const { return F_FROMLOCALS; }
84+
85+
template <typename S> void serialize(S &ser) {}
86+
};
87+
1388
GetItem::py_future_type IO::to_numpy(const ddptensor &a) {
1489
assert(!getTransceiver()->is_cw() || getTransceiver()->rank() == 0);
1590
return GetItem::gather(a, getTransceiver()->is_cw() ? 0 : REPLICATED);
1691
}
92+
93+
ddptensor *IO::from_locals(const std::vector<py::array> &a) {
94+
assert(a.size() == 1);
95+
return new ddptensor(defer<DeferredFromLocal>(a.front()));
96+
}
97+
98+
FACTORY_INIT(DeferredFromLocal, F_FROMLOCALS);

src/SetGetItem.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,29 +59,29 @@ py::object wrap(DDPTensorImpl::ptr_type tnsr, const py::handle &handle) {
5959

6060
// ***************************************************************************
6161

62-
struct DeferredGetLocal
62+
struct DeferredGetLocals
6363
: public DeferredT<GetItem::py_promise_type, GetItem::py_future_type> {
6464
id_type _a;
6565
py::handle _handle;
6666

67-
DeferredGetLocal() = default;
68-
DeferredGetLocal(const tensor_i::future_type &a, py::handle &handle)
67+
DeferredGetLocals() = default;
68+
DeferredGetLocals(const tensor_i::future_type &a, py::handle &handle)
6969
: _a(a.guid()), _handle(handle) {}
7070

7171
void run() override {
7272
auto aa = std::move(Registry::get(_a).get());
7373
auto a_ptr = std::dynamic_pointer_cast<DDPTensorImpl>(aa);
7474
assert(a_ptr);
7575
auto res = wrap(a_ptr, _handle);
76-
set_value(res);
76+
set_value(py::make_tuple(res));
7777
}
7878

7979
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
8080
jit::DepManager &dm) override {
8181
return true;
8282
}
8383

84-
FactoryId factory() const { return F_GETLOCAL; }
84+
FactoryId factory() const { return F_GETLOCALS; }
8585

8686
template <typename S> void serialize(S &ser) {
8787
ser.template value<sizeof(_a)>(_a);
@@ -345,8 +345,8 @@ ddptensor *GetItem::__getitem__(const ddptensor &a,
345345
return new ddptensor(defer<DeferredGetItem>(a.get(), std::move(slc)));
346346
}
347347

348-
GetItem::py_future_type GetItem::get_local(const ddptensor &a, py::handle h) {
349-
return defer<DeferredGetLocal>(a.get(), h);
348+
GetItem::py_future_type GetItem::get_locals(const ddptensor &a, py::handle h) {
349+
return defer<DeferredGetLocals>(a.get(), h);
350350
}
351351

352352
GetItem::py_future_type GetItem::gather(const ddptensor &a, rank_type root) {
@@ -377,4 +377,4 @@ FACTORY_INIT(DeferredGetItem, F_GETITEM);
377377
FACTORY_INIT(DeferredSetItem, F_SETITEM);
378378
FACTORY_INIT(DeferredMap, F_MAP);
379379
FACTORY_INIT(DeferredGather, F_GATHER);
380-
FACTORY_INIT(DeferredGetLocal, F_GETLOCAL);
380+
FACTORY_INIT(DeferredGetLocals, F_GETLOCALS);

src/ddptensor.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,11 @@ PYBIND11_MODULE(_ddptensor, m) {
135135
.def("sync", &sync_promises)
136136
.def("myrank", &myrank)
137137
.def("_get_slice", &GetItem::get_slice)
138-
.def("_get_local",
138+
.def("_get_locals",
139139
[](const ddptensor &f, py::handle h) {
140-
PY_SYNC_RETURN(GetItem::get_local(f, h));
140+
PY_SYNC_RETURN(GetItem::get_locals(f, h));
141141
})
142+
.def("_from_locals", &IO::from_locals)
142143
.def("_gather",
143144
[](const ddptensor &f, rank_type root = REPLICATED) {
144145
PY_SYNC_RETURN(GetItem::gather(f, root));

src/include/ddptensor/CppTypes.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,11 @@ enum FactoryId : int {
199199
F_ARANGE,
200200
F_EWBINOP,
201201
F_EWUNYOP,
202+
F_FROMLOCALS,
202203
F_FULL,
203204
F_GATHER,
204205
F_GETITEM,
205-
F_GETLOCAL,
206+
F_GETLOCALS,
206207
F_IEWBINOP,
207208
F_LINALGOP,
208209
F_LINSPACE,

src/include/ddptensor/DDPTensorImpl.hpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,25 @@
1717

1818
class Transceiver;
1919

20+
/// @brief use this to provide a base object to the tensor
21+
// such a base object can own shared data
22+
// you might need to implement reference counting
23+
struct BaseObj {
24+
virtual ~BaseObj() {}
25+
};
26+
27+
/// @brief Simple implementation of BaseObj for ref-counting types
28+
/// @tparam T ref-counting type, such as py::object or std::shared_ptr
29+
/// we keep an object of the ref-counting type. Normal ref-counting/destructors
30+
/// will take care of the rest.
31+
template <typename T> struct SharedBaseObject : public BaseObj {
32+
SharedBaseObject(const SharedBaseObject &) = default;
33+
SharedBaseObject(SharedBaseObject &&) = default;
34+
SharedBaseObject(const T &o) : _base(o) {}
35+
SharedBaseObject(T &&o) : _base(std::forward<T>(o)) {}
36+
T _base;
37+
};
38+
2039
/// The actual implementation of the DDPTensor, implementing the tensor_i
2140
/// interface. It holds the tensor data and some meta information. The member
2241
/// attributes are mostly inspired by the needs of interacting with MLIR. It
@@ -25,6 +44,7 @@ class Transceiver;
2544
/// Here, the halos are never used for anything except for interchanging with
2645
/// MLIR.
2746
class DDPTensorImpl : public tensor_i {
47+
2848
mutable rank_type _owner;
2949
Transceiver *_transceiver = nullptr;
3050
shape_type _gShape = {};
@@ -34,7 +54,7 @@ class DDPTensorImpl : public tensor_i {
3454
DynMemRef _lData;
3555
DynMemRef _rhsHalo;
3656
DTypeId _dtype = DTYPE_LAST;
37-
tensor_i::ptr_type _base;
57+
BaseObj *_base = nullptr;
3858

3959
public:
4060
using ptr_type = std::shared_ptr<DDPTensorImpl>;
@@ -63,8 +83,14 @@ class DDPTensorImpl : public tensor_i {
6383
// incomplete, useful for computing meta information
6484
DDPTensorImpl() : _owner(REPLICATED) { assert(ndims() <= 1); }
6585

86+
// From numpy
87+
// FIXME multi-proc
88+
DDPTensorImpl(DTypeId dtype, ssize_t ndims, const ssize_t *shape,
89+
const intptr_t *strides, void *data);
90+
6691
// set the base tensor
67-
void set_base(const tensor_i::ptr_type &base) { _base = base; }
92+
void set_base(const tensor_i::ptr_type &base);
93+
void set_base(BaseObj *obj);
6894

6995
virtual ~DDPTensorImpl();
7096

src/include/ddptensor/IO.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
#pragma once
88

99
#include "ddptensor/SetGetItem.hpp"
10+
#include <pybind11/numpy.h>
11+
namespace py = pybind11;
12+
#include <vector>
1013

1114
struct IO {
1215
static GetItem::py_future_type to_numpy(const ddptensor &a);
16+
static ddptensor *from_locals(const std::vector<py::array> &a);
1317
};

src/include/ddptensor/SetGetItem.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct GetItem {
1919
const std::vector<py::slice> &v);
2020
static py::object get_slice(const ddptensor &a,
2121
const std::vector<py::slice> &v);
22-
static py_future_type get_local(const ddptensor &a, py::handle h);
22+
static py_future_type get_locals(const ddptensor &a, py::handle h);
2323
static py_future_type gather(const ddptensor &a, rank_type root);
2424
};
2525

src/jit/mlir.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,12 @@ ::mlir::Value DepManager::getDependent(::mlir::OpBuilder &builder,
178178
auto rank = impl->ndims();
179179
::mlir::SmallVector<int64_t> lhShape(rank), ownShape(rank), rhShape(rank);
180180
for (size_t i = 0; i < rank; i++) {
181-
lhShape[i] = impl->lh_shape()[i];
181+
lhShape[i] = impl->lh_shape() ? impl->lh_shape()[i] : 0;
182182
ownShape[i] = impl->local_shape()[i];
183-
rhShape[i] = impl->rh_shape()[i];
183+
rhShape[i] = impl->rh_shape() ? impl->rh_shape()[i] : 0;
184184
}
185185
auto typ = getTType(
186-
builder, fut.dtype(),
186+
builder, impl->dtype(),
187187
::mlir::SmallVector<int64_t>(impl->shape(), impl->shape() + rank),
188188
lhShape, ownShape, rhShape, fut.team(), fut.balanced());
189189
_func.insertArgument(idx, typ, {}, loc);

0 commit comments

Comments
 (0)