Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 172c55d

Browse files
committed
adding numpy.fromfunction; acquire GIL for immediate execution/run, release when waiting
1 parent f97f1e8 commit 172c55d

File tree

11 files changed

+154
-40
lines changed

11 files changed

+154
-40
lines changed

ddptensor/numpy/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
1-
import dtensor
1+
from .. import empty, float32
2+
3+
4+
def fromfunction(function, shape, *, dtype=float32):
    """numpy.fromfunction-style constructor.

    Creates an empty tensor of the given ``shape``/``dtype`` and fills it
    by mapping ``function`` over the tensor: each element is set to the
    value returned by ``function`` called with that element's global index.
    NOTE(review): unlike numpy.fromfunction, ``function`` is called once per
    element with scalar index arguments rather than with index arrays —
    confirm this is the intended contract.
    """
    t = empty(shape, dtype)
    # delegate to the C++ map() which applies `function` elementwise
    t._t.map(function)
    return t

src/Creator.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,11 @@ struct DeferredFull : public Deferred {
151151
::imex::ptensor::DType dtyp;
152152
::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);
153153

154-
auto team = ::imex::createIndex(
155-
loc, builder, reinterpret_cast<uint64_t>(getTransceiver()));
154+
auto team = /*getTransceiver()->nranks() <= 1
155+
? ::mlir::Value()
156+
:*/
157+
::imex::createIndex(loc, builder,
158+
reinterpret_cast<uint64_t>(getTransceiver()));
156159

157160
dm.addVal(this->guid(),
158161
builder.create<::imex::ptensor::CreateOp>(loc, shp, dtyp, val,
@@ -206,8 +209,12 @@ struct DeferredArange : public Deferred {
206209
auto stop = ::imex::createInt(loc, builder, _end);
207210
auto step = ::imex::createInt(loc, builder, _step);
208211
// ::mlir::Value
209-
auto team = ::imex::createIndex(
210-
loc, builder, reinterpret_cast<uint64_t>(getTransceiver()));
212+
auto team = /*getTransceiver()->nranks() <= 1
213+
? ::mlir::Value()
214+
:*/
215+
::imex::createIndex(loc, builder,
216+
reinterpret_cast<uint64_t>(getTransceiver()));
217+
211218
dm.addVal(this->guid(),
212219
builder.create<::imex::ptensor::ARangeOp>(loc, start, stop, step,
213220
nullptr, team),

src/Deferred.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
2020
#include <oneapi/tbb/concurrent_queue.h>
2121

22+
#include <pybind11/pybind11.h>
23+
namespace py = pybind11;
24+
2225
#include <iostream>
2326

2427
// thread-safe FIFO queue holding deferred objects
@@ -148,9 +151,10 @@ void process_promises() {
148151
} // no else needed
149152

150153
// now we execute the deferred action which could not be compiled
151-
if (d)
154+
if (d) {
155+
py::gil_scoped_acquire acquire;
152156
d->run();
157+
d.reset();
158+
}
153159
} while (!done);
154160
}
155-
156-
void sync_promises() { (void)Service::run().get(); }

src/SetGetItem.cpp

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,21 @@ template <typename T> struct wrap_array {
167167
}
168168
};
169169

170+
// Wrap the local data of a tensor as a Python array object (zero-copy),
// keeping `handle` alive as the owner of the underlying buffer.
py::object wrap(DDPTensorImpl::ptr_type tnsr, const py::handle &handle) {
  const auto nd = tnsr->ndims();
  const auto lShape = tnsr->local_shape();
  const auto lStrides = tnsr->local_strides();
  const auto itemSize = sizeof_dtype(tnsr->dtype());

  // convert element strides to byte strides as required by the buffer API
  std::vector<ssize_t> byteStrides(nd);
  for (auto i = 0; i < nd; ++i) {
    byteStrides[i] = itemSize * lStrides[i];
  }

  return dispatch<wrap_array>(tnsr->dtype(),
                              std::vector<ssize_t>(lShape, &lShape[nd]),
                              byteStrides, tnsr->data(), handle);
}
184+
170185
// ***************************************************************************
171186

172187
struct DeferredGetLocal
@@ -182,17 +197,7 @@ struct DeferredGetLocal
182197
auto aa = std::move(Registry::get(_a).get());
183198
auto a_ptr = std::dynamic_pointer_cast<DDPTensorImpl>(aa);
184199
assert(a_ptr);
185-
auto tmp_shp = a_ptr->local_shape();
186-
auto tmp_str = a_ptr->local_strides();
187-
auto nd = a_ptr->ndims();
188-
auto eSz = sizeof_dtype(a_ptr->dtype());
189-
std::vector<ssize_t> strides(nd);
190-
for (auto i = 0; i < nd; ++i) {
191-
strides[i] = eSz * tmp_str[i];
192-
}
193-
auto res = dispatch<wrap_array>(a_ptr->dtype(),
194-
std::vector<ssize_t>(tmp_shp, &tmp_shp[nd]),
195-
strides, a_ptr->data(), _handle);
200+
auto res = wrap(a_ptr, _handle);
196201
set_value(res);
197202
}
198203

@@ -317,6 +322,58 @@ struct DeferredSetItem : public Deferred {
317322

318323
// ***************************************************************************
319324

325+
// Deferred operation that applies a Python callable elementwise to a tensor,
// backing numpy-style fromfunction: the callable receives the element's
// *global* index and its return value is stored in place.
// This op executes eagerly in run() (no MLIR is generated); run() calls back
// into Python for every element, so it must execute with the GIL held (the
// worker loop in process_promises acquires it before calling run()).
struct DeferredMap : public Deferred {
  id_type _a;        // guid of the tensor to map over (also the result)
  py::object _func;  // Python callable: (i0, i1, ...) -> element value

  DeferredMap() = default;
  DeferredMap(const tensor_i::future_type &a, py::object &func)
      : Deferred(a.id(), a.dtype(), a.rank(), a.balanced()), _a(a.id()),
        _func(func) {}

  void run() override {
    auto aa = std::move(Registry::get(_a).get());
    auto a_ptr = std::dynamic_pointer_cast<DDPTensorImpl>(aa);
    assert(a_ptr);
    auto nd = a_ptr->ndims();
    // offsets of the local partition within the global tensor
    auto lOffs = a_ptr->local_offsets();
    std::vector<int64_t> lIdx(nd); // local index, maintained by forall
    std::vector<int64_t> gIdx(nd); // global index, passed to _func

    dispatch(a_ptr->dtype(), a_ptr->data(), [&](auto *ptr) {
      forall(
          0, ptr, a_ptr->local_shape(), a_ptr->local_strides(), nd, lIdx,
          [&](const std::vector<int64_t> &idx, auto *elPtr) {
            // translate local index -> global index
            for (auto i = 0; i < nd; ++i) {
              gIdx[i] = idx[i] + lOffs[i];
            }
            auto pyIdx = _make_tuple(gIdx);
            // call _func with the index tuple unpacked as positional args
            // and cast the result to the tensor's element type
            *elPtr =
                _func(*pyIdx)
                    .cast<
                        typename std::remove_pointer<decltype(elPtr)>::type>();
          });
    });

    this->set_value(aa);
  };

  // Nothing to generate: this op runs eagerly in run().
  bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
                     jit::DepManager &dm) override {
    return true;
  }

  FactoryId factory() const { return F_MAP; }

  // Not serializable: _func is a live Python object and cannot cross
  // process boundaries, hence the hard assert.
  template <typename S> void serialize(S &ser) {
    assert(false);
    ser.template value<sizeof(_a)>(_a);
    // nope ser.template value<sizeof(_func)>(_func);
  }
};
374+
375+
// ***************************************************************************
376+
320377
struct DeferredGetItem : public Deferred {
321378
id_type _a;
322379
NDSlice _slc;
@@ -407,14 +464,18 @@ GetItem::py_future_type GetItem::gather(const ddptensor &a, rank_type root) {
407464

408465
// Assign b (a ddptensor, or a Python object convertible to one) to the
// slice(s) v of tensor a. Returns &a for chaining.
ddptensor *SetItem::__setitem__(ddptensor &a, const std::vector<py::slice> &v,
                                const py::object &b) {
  // mk_future converts b into a tensor future; presumably bb.second flags
  // that a temporary was created which we own — TODO confirm against
  // Creator::mk_future before relying on this ownership convention.
  auto bb = Creator::mk_future(b);
  a.put(defer<DeferredSetItem>(a.get(), bb.first->get(), v));
  if (bb.second)
    delete bb.first;
  return &a;
}
417473

474+
// Apply the Python callable b elementwise to tensor a (in place); b is
// called with each element's global index (see DeferredMap). Returns &a.
// Used to implement numpy-style fromfunction on the Python side.
ddptensor *SetItem::map(ddptensor &a, py::object &b) {
  a.put(defer<DeferredMap>(a.get(), b));
  return &a;
}
478+
418479
py::object GetItem::get_slice(const ddptensor &a,
419480
const std::vector<py::slice> &v) {
420481
const auto aa = std::move(a.get());
@@ -423,5 +484,6 @@ py::object GetItem::get_slice(const ddptensor &a,
423484

424485
FACTORY_INIT(DeferredGetItem, F_GETITEM);
425486
FACTORY_INIT(DeferredSetItem, F_SETITEM);
487+
FACTORY_INIT(DeferredMap, F_MAP);
426488
FACTORY_INIT(DeferredGather, F_GATHER);
427-
FACTORY_INIT(DeferredGather, F_GETLOCAL);
489+
FACTORY_INIT(DeferredGetLocal, F_GETLOCAL);

src/ddptensor.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ extern bool finied;
5555

5656
// users currently need to call fini to make MPI terminate gracefully
5757
void fini() {
58+
py::gil_scoped_release release;
5859
if (finied)
5960
return;
6061
fini_mediator(); // stop task is sent in here
@@ -92,15 +93,22 @@ void init(bool cw) {
9293
finied = false;
9394
}
9495

96+
// Block until all deferred work queued so far has been executed.
// The GIL is released first: the worker loop (process_promises) acquires
// the GIL before running deferreds immediately, so waiting here while
// holding it would deadlock.
void sync_promises() {
  py::gil_scoped_release release;
  (void)Service::run().get();
}
100+
95101
// #########################################################################
96102

97103
/// Trigger compile&run and return future value.
/// Releases the GIL before blocking on .get() so the worker thread can
/// acquire it for immediate execution (see process_promises).
/// NOTE: declares a local `release`, so use at most once per scope; it is
/// deliberately not wrapped in do{}while(0) because it must `return` from
/// the enclosing function.
#define PY_SYNC_RETURN(_f)                                                     \
  py::gil_scoped_release release;                                              \
  Service::run();                                                              \
  return (_f).get()

/// Trigger compile&run and return given attribute _a of the future's value.
/// Same GIL-release caveats as PY_SYNC_RETURN above.
#define SYNC_RETURN(_f, _a)                                                    \
  py::gil_scoped_release release;                                              \
  Service::run();                                                              \
  return (_f).get().get()->_a()
106114

@@ -188,7 +196,8 @@ PYBIND11_MODULE(_ddptensor, m) {
188196
[](const ddptensor &f) { REPL_SYNC_RETURN(f, __int__); })
189197
// attributes returning a new ddptensor
190198
.def("__getitem__", &GetItem::__getitem__)
191-
.def("__setitem__", &SetItem::__setitem__);
199+
.def("__setitem__", &SetItem::__setitem__)
200+
.def("map", &SetItem::map);
192201
#undef REPL_SYNC_RETURN
193202
#undef SYNC_RETURN
194203

src/include/ddptensor/CppTypes.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,19 +198,20 @@ enum FactoryId : int {
198198
F_EWUNYOP,
199199
F_FROMSHAPE,
200200
F_FULL,
201+
F_GATHER,
201202
F_GETITEM,
203+
F_GETLOCAL,
202204
F_IEWBINOP,
203205
F_LINALGOP,
204206
F_MANIPOP,
207+
F_MAP,
205208
F_RANDOM,
206209
F_REDUCEOP,
210+
F_REPLICATE,
207211
F_SERVICE,
208212
F_SETITEM,
209213
F_SORTOP,
210214
F_UNYOP,
211-
F_GATHER,
212-
F_GETLOCAL,
213-
F_REPLICATE,
214215
FACTORY_LAST
215216
};
216217

src/include/ddptensor/DDPTensorImpl.hpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,20 +209,41 @@ template <typename... Ts> static tensor_i::future_type mk_ftx(Ts &&...args) {
209209

210210
// Execute an OP on all elements of a tensor represented by
// dimensionality/ptr/sizes/strides.
//
// forall_ is the recursive worker: it iterates dimension d and recurses
// into d+1 until the innermost dimension, where it invokes op on each
// element. When PASSIDX is true, *idx is kept up to date with the current
// multi-dimensional index and passed to op alongside the element pointer.
template <typename T, typename OP, bool PASSIDX>
void forall_(uint64_t d, T *cptr, const int64_t *sizes, const int64_t *strides,
             uint64_t nd, OP op, std::vector<int64_t> *idx) {
  assert(!PASSIDX || idx);
  auto stride = strides[d];
  auto sz = sizes[d];
  if (d == nd - 1) {
    // innermost dimension: apply op element by element
    // int64_t counter: extents are int64_t, a plain int could truncate
    for (int64_t i = 0; i < sz; ++i) {
      if constexpr (PASSIDX) {
        (*idx)[d] = i;
        op(*idx, &cptr[i * stride]);
      } else {
        op(&cptr[i * stride]);
      }
    }
  } else {
    for (int64_t i = 0; i < sz; ++i) {
      T *tmp = cptr;
      if constexpr (PASSIDX) {
        (*idx)[d] = i;
      }
      forall_<T, OP, PASSIDX>(d + 1, cptr, sizes, strides, nd, op, idx);
      cptr = tmp + strides[d];
    }
  }
}

// Apply op(elementPtr) to every element, starting at dimension d.
template <typename T, typename OP>
void forall(uint64_t d, T *cptr, const int64_t *sizes, const int64_t *strides,
            uint64_t nd, OP op) {
  forall_<T, OP, false>(d, cptr, sizes, strides, nd, op, nullptr);
}

// Apply op(index, elementPtr) to every element; idx is caller-provided
// scratch space of size nd that receives the current multi-dim index.
template <typename T, typename OP>
void forall(uint64_t d, T *cptr, const int64_t *sizes, const int64_t *strides,
            uint64_t nd, std::vector<int64_t> &idx, OP op) {
  forall_<T, OP, true>(d, cptr, sizes, strides, nd, op, &idx);
}

src/include/ddptensor/Deferred.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "tensor_i.hpp"
1717

1818
extern void process_promises();
19-
extern void sync_promises();
2019

2120
// interface for promises/tasks to generate MLIR or execute immediately.
2221
struct Runable {

src/include/ddptensor/PyTypes.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ template <typename T> py::tuple _make_tuple(const std::vector<T> &v) {
115115
[](const V &v, int i) { return v[i]; });
116116
}
117117

118+
// Build a py::tuple of length n from `ptr`, reading elements via ptr[i].
// Delegates to the generic 3-argument _make_tuple(value, size-fn, getter).
// NOTE(review): T is passed by value and is expected to be a pointer or a
// cheap indexable handle supporting operator[] — confirm against callers.
template <typename T> py::tuple _make_tuple(const T ptr, size_t n) {
  return _make_tuple(
      ptr, [n](const T &) { return n; },
      [](const T &v, int i) { return v[i]; });
}
123+
118124
template <typename T> T to_native(const py::object &o) { return o.cast<T>(); }
119125

120126
inline void compute_slice(const py::slice &slc, uint64_t &offset,

src/include/ddptensor/SetGetItem.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ struct GetItem {
2626
// Python-facing mutation operations on ddptensor.
struct SetItem {
  // Assign b to the slice(s) v of a; returns &a.
  static ddptensor *__setitem__(ddptensor &a, const std::vector<py::slice> &v,
                                const py::object &b);
  // Apply Python callable b elementwise to a (called with each element's
  // global index, see DeferredMap); returns &a.
  static ddptensor *map(ddptensor &a, py::object &b);
};

0 commit comments

Comments
 (0)