Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 23fa550

Browse files
committed
deferring non-tensor-returning promises; adding to_numpy
1 parent ecc2cc7 commit 23fa550

19 files changed

+244
-93
lines changed

ddptensor/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def init(cw=None):
4343
cw = _ddpt_cw if cw is None else cw
4444
_init(cw)
4545

46+
def to_numpy(a):
47+
return _cdt.to_numpy(a._t)
48+
4649
for op in api.api_categories["EWBinOp"]:
4750
if not op.startswith("__"):
4851
OP = op.upper()

ddptensor/spmd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ def get_slice(obj, *args):
66
def get_local(obj):
77
return _cdt._get_local(obj._t, obj)
88

9-
def gather(obj):
10-
return _cdt._gather(obj._t)
9+
def gather(obj, root=_cdt._Ranks._REPLICATED):
10+
return _cdt._gather(obj._t, root)

src/Deferred.cpp

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,51 +4,57 @@
44
#include "include/ddptensor/Mediator.hpp"
55
#include "include/ddptensor/Registry.hpp"
66

7-
static tbb::concurrent_bounded_queue<Deferred::ptr_type> _deferred;
7+
static tbb::concurrent_bounded_queue<Runable::ptr_type> _deferred;
88

9-
Deferred::future_type Deferred::get_future()
9+
void push_runable(Runable::ptr_type && r)
1010
{
11-
return {std::move(tensor_i::promise_type::get_future()), _guid};
11+
_deferred.push(std::move(r));
1212
}
1313

14-
#if 0
15-
void Deferred::set_value(tensor_i::ptr_type && v)
14+
void _dist(const Runable * p)
1615
{
17-
if(_guid != Registry::NOGUID) {
18-
Registry::put(_guid, v);
19-
}
20-
tensor_i::promise_type::set_value(std::forward<tensor_i::ptr_type>(v));
16+
if(is_cw() && theTransceiver->rank() == 0)
17+
theMediator->to_workers(p);
2118
}
22-
#endif
2319

24-
Deferred::future_type Deferred::defer(Deferred::ptr_type && d, bool is_global)
20+
Deferred::future_type Deferred::get_future()
2521
{
22+
return {std::move(tensor_i::promise_type::get_future().share()), _guid};
23+
}
24+
25+
Deferred::future_type defer_tensor(Runable::ptr_type && _d, bool is_global)
26+
{
27+
Deferred * d = dynamic_cast<Deferred*>(_d.get());
28+
if(!d) throw std::runtime_error("Expected Deferred Tensor promise");
2629
if(is_global) {
27-
if(is_cw() && theTransceiver->rank() == 0) theMediator->to_workers(d);
28-
if(d) d->_guid = Registry::get_guid();
30+
_dist(d);
31+
d->_guid = Registry::get_guid();
2932
}
30-
auto f = d ? d->get_future() : Deferred::future_type();
33+
auto f = d->get_future();
3134
Registry::put(f);
32-
_deferred.push(std::move(d));
35+
push_runable(std::move(_d));
3336
return f;
3437
}
3538

36-
Deferred::ptr_type Deferred::undefer_next()
39+
void Deferred::defer(Runable::ptr_type && p)
40+
{
41+
defer_tensor(std::move(p), true);
42+
}
43+
44+
void Runable::defer(Runable::ptr_type && p)
3745
{
38-
Deferred::ptr_type r;
39-
_deferred.pop(r);
40-
return r;
46+
push_runable(std::move(p));
4147
}
4248

43-
void Deferred::fini()
49+
void Runable::fini()
4450
{
4551
_deferred.clear();
4652
}
4753

4854
void process_promises()
4955
{
5056
while(true) {
51-
Deferred::ptr_type d;
57+
Runable::ptr_type d;
5258
_deferred.pop(d);
5359
if(d) d->run();
5460
else break;

src/IO.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "ddptensor/IO.hpp"
2+
#include "ddptensor/SetGetItem.hpp"
3+
#include "ddptensor/TypeDispatch.hpp"
4+
#include "ddptensor/Factory.hpp"
5+
6+
using promise_type = std::promise<py::object>;
7+
using future_type = std::shared_future<py::object>;
8+
9+
struct DeferredToNumpy : public DeferredT<promise_type, future_type>
10+
{
11+
id_type _a;
12+
13+
DeferredToNumpy() = default;
14+
DeferredToNumpy(const tensor_i::future_type & a)
15+
: _a(a.id())
16+
{}
17+
18+
void run()
19+
{
20+
const auto a = std::move(Registry::get(_a).get());
21+
set_value(GetItem::do_gather(a, is_cw() ? 0 : REPLICATED));
22+
}
23+
24+
FactoryId factory() const
25+
{
26+
return F_TONUMPY;
27+
}
28+
29+
template<typename S>
30+
void serialize(S & ser)
31+
{
32+
ser.template value<sizeof(_a)>(_a);
33+
}
34+
};
35+
36+
py::object IO::to_numpy(const ddptensor & a)
37+
{
38+
assert(!is_cw() || theTransceiver->rank() == 0);
39+
auto f = defer<DeferredToNumpy>(a.get());
40+
auto x = f.get();
41+
return x;
42+
}
43+
44+
FACTORY_INIT(DeferredToNumpy, F_TONUMPY);

src/MPIMediator.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ constexpr static int DEFER_TAG = 14714;
1919
constexpr static int EXIT_TAG = 14715;
2020
static std::mutex ak_mutex;
2121

22-
void send_to_workers(const Deferred::ptr_type & dfrd, bool self, MPI_Comm comm);
22+
void send_to_workers(const Runable * dfrd, bool self, MPI_Comm comm);
2323

2424
MPIMediator::MPIMediator()
2525
: _listener(nullptr)
@@ -82,7 +82,7 @@ void MPIMediator::pull(rank_type from, id_type guid, const NDSlice & slice, void
8282
if(cnt != sz) throw(std::runtime_error("Received unexpected message size."));
8383
}
8484

85-
void send_to_workers(const Deferred::ptr_type & dfrd, bool self, MPI_Comm comm)
85+
void send_to_workers(const Runable * dfrd, bool self, MPI_Comm comm)
8686
{
8787
int rank, sz;
8888
MPI_Comm_rank(comm, &rank);
@@ -124,7 +124,7 @@ void send_to_workers(const Deferred::ptr_type & dfrd, bool self, MPI_Comm comm)
124124
}
125125
}
126126

127-
void MPIMediator::to_workers(const Deferred::ptr_type & dfrd)
127+
void MPIMediator::to_workers(const Runable * dfrd)
128128
{
129129
send_to_workers(dfrd, false, _comm);
130130
}
@@ -158,7 +158,8 @@ void MPIMediator::listen()
158158
case DEFER_TAG: {
159159
FactoryId fctryid;
160160
ser.value<sizeof(fctryid)>(fctryid);
161-
Deferred::defer(Factory::get(fctryid)->create(ser), true);
161+
auto uptr = Factory::get(fctryid)->create(ser);
162+
uptr.get()->defer(std::move(uptr)); // grmpf
162163
break;
163164
}
164165
case PULL_TAG: {
@@ -181,7 +182,7 @@ void MPIMediator::listen()
181182
break;
182183
}
183184
case EXIT_TAG:
184-
Deferred::defer(nullptr, false);
185+
defer(nullptr);
185186
return;
186187
default:
187188
throw(std::runtime_error("Received unexpected message tag."));

src/MPITransceiver.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,27 @@ void MPITransceiver::alltoall(const void* buffer_send,
180180
_comm);
181181
}
182182

183-
void MPITransceiver::allgather(void* buffer,
184-
const int* counts,
185-
const int* displacements,
186-
DTypeId datatype)
183+
void MPITransceiver::gather(void* buffer,
184+
const int* counts,
185+
const int* displacements,
186+
DTypeId datatype,
187+
rank_type root)
187188
{
188-
MPI_Allgatherv(MPI_IN_PLACE, 0, to_mpi(datatype),
189-
buffer, counts, displacements, to_mpi(datatype),
190-
_comm);
189+
if(root == REPLICATED) {
190+
MPI_Allgatherv(MPI_IN_PLACE, 0, to_mpi(datatype),
191+
buffer, counts, displacements, to_mpi(datatype),
192+
_comm);
193+
} else {
194+
if(root == _rank) {
195+
MPI_Gatherv(MPI_IN_PLACE, 0, to_mpi(datatype),
196+
buffer, counts, displacements, to_mpi(datatype),
197+
root, _comm);
198+
} else {
199+
MPI_Gatherv(buffer, counts[_rank], to_mpi(datatype),
200+
nullptr, nullptr, nullptr, to_mpi(datatype),
201+
root, _comm);
202+
}
203+
}
191204
}
192205

193206
void MPITransceiver::send_recv(void* buffer_send,

src/Random.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ ddptensor * Random::rand(DTypeId dtype, const shape_type & shape, const py::obje
6969

7070
void Random::seed(uint64_t s)
7171
{
72-
defer([s](){xt::random::seed(s); return tensor_i::ptr_type();});
72+
defer_lambda([s](){xt::random::seed(s); return tensor_i::ptr_type();});
7373
}
7474

7575
FACTORY_INIT(DeferredRandomOp, F_RANDOM);

src/SetGetItem.cpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ namespace x {
7777
PVSlice g_slc_view(a_ptr->slice(), slice);
7878
PVSlice my_rel_slice(g_slc_view, theTransceiver->rank());
7979
NDSlice my_norm_slice = g_slc_view.map_slice(my_rel_slice.slice_of_rank()); //slice());my_slice);
80-
80+
8181
if(is_spmd()) theTransceiver->barrier();
8282
_set_slice<A>(a_ptr->xarray(), my_rel_slice, b_ptr, my_norm_slice, val_guid);
8383
theTransceiver->barrier();
@@ -129,37 +129,47 @@ namespace x {
129129
// We simply create a local buffer, copy our local data to the right place
130130
// and then call AllGatherV via inplace operation.
131131
template<typename T>
132-
static py::object op(const std::shared_ptr<DPTensorX<T>> & a_ptr)
132+
static py::object op(rank_type root, const std::shared_ptr<DPTensorX<T>> & a_ptr)
133133
{
134134
auto nranks = theTransceiver->nranks();
135135
auto rank = theTransceiver->rank();
136+
bool sendonly = root != REPLICATED && root != rank;
136137
const auto & slc = a_ptr->slice();
138+
auto mysz = slc.slice_of_rank().size();
137139

138140
// create buffer/numpy array
139-
auto res = py::array_t<T>(std::move(slc.shape()));
140-
T * ptr = reinterpret_cast<T*>(res.mutable_data());
141+
T * ptr = nullptr;
142+
py::array res;
143+
if(sendonly) {
144+
if(mysz > 0 && a_ptr->is_sliced()) ptr = new T[mysz];
145+
} else {
146+
res = py::array_t<T>(std::move(slc.shape()));
147+
ptr = reinterpret_cast<T*>(res.mutable_data());
148+
}
141149
int displacements[nranks];
142150
int counts[nranks];
143151
int off = 0;
144152
// for each rank compute counts and displacements
145153
for(auto i=0; i<nranks; ++i) {
146-
uint64_t szi = slc.slice_of_rank(i).size();
154+
uint64_t szi = i == rank ? mysz : slc.slice_of_rank(i).size();
147155
counts[i] = szi;
148156
displacements[i] = off;
149157
// copy our local data
150158
if(i == rank) {
151159
if(a_ptr->is_sliced()) {
152160
// if non-contiguous copy element by element
153161
const auto & av = xt::strided_view(a_ptr->xarray(), a_ptr->lslice());
154-
uint64_t i = off-1;
155-
for(auto v : av) ptr[++i] = v;
162+
uint64_t j = sendonly ? -1 : off - 1;
163+
for(auto v : av) ptr[++j] = v;
156164
} else {
157-
memcpy(&ptr[off], a_ptr->xarray().data(), szi*sizeof(T));
165+
if(sendonly && mysz > 0) ptr = a_ptr->xarray().data();
166+
else memcpy(&ptr[off], a_ptr->xarray().data(), szi*sizeof(T));
158167
}
159168
}
160169
off += szi;
161170
}
162-
theTransceiver->allgather(ptr, counts, displacements, DTYPE<T>::value);
171+
theTransceiver->gather(ptr, counts, displacements, DTYPE<T>::value, root);
172+
if(sendonly && mysz > 0 && a_ptr->is_sliced()) delete [] ptr;
163173
return res;
164174
}
165175
};
@@ -171,12 +181,12 @@ struct DeferredSetItem : public Deferred
171181
id_type _a;
172182
id_type _b;
173183
NDSlice _slc;
174-
184+
175185
DeferredSetItem() = default;
176186
DeferredSetItem(const tensor_i::future_type & a, const tensor_i::future_type & b, const std::vector<py::slice> & v)
177187
: _a(a.id()), _b(b.id()), _slc(v)
178188
{}
179-
189+
180190
void run()
181191
{
182192
const auto a = std::move(Registry::get(_a).get());
@@ -249,10 +259,15 @@ py::object GetItem::get_local(const ddptensor & a, py::handle h)
249259
return TypeDispatch<x::SPMD>(aa, h);
250260
}
251261

252-
py::object GetItem::gather(const ddptensor & a)
262+
py::object GetItem::do_gather(const tensor_i::ptr_type & a, rank_type root)
263+
{
264+
return TypeDispatch<x::SPMD>(a, root);
265+
}
266+
267+
py::object GetItem::gather(const ddptensor & a, rank_type root)
253268
{
254269
const auto aa = std::move(a.get().get());
255-
return TypeDispatch<x::SPMD>(aa);
270+
return do_gather(aa, root);
256271
}
257272

258273
FACTORY_INIT(DeferredGetItem, F_GETITEM);

src/ddptensor.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ using namespace pybind11::literals; // to bring _a
3434
#include "ddptensor/LinAlgOp.hpp"
3535
#include "ddptensor/Service.hpp"
3636
#include "ddptensor/Factory.hpp"
37+
#include "ddptensor/IO.hpp"
3738

3839
// #########################################################################
3940
// The following classes are wrappers bridging pybind11 defs to TypeDispatch
@@ -68,7 +69,7 @@ void fini()
6869
delete theMediator; // stop task is sent in here
6970
theMediator = nullptr;
7071
if(pprocessor) {
71-
if(theTransceiver->nranks() == 1) Deferred::defer(nullptr, false);
72+
if(theTransceiver->nranks() == 1) defer(nullptr);
7273
pprocessor->join();
7374
delete pprocessor;
7475
}
@@ -115,18 +116,22 @@ PYBIND11_MODULE(_ddptensor, m) {
115116
Factory::init<F_SETITEM>();
116117
Factory::init<F_RANDOM>();
117118
Factory::init<F_SERVICE>();
119+
Factory::init<F_TONUMPY>();
118120

119121
m.doc() = "A partitioned and distributed tensor";
120122

121123
def_enums(m);
124+
py::enum_<_RANKS>(m, "_Ranks")
125+
.value("_REPLICATED", REPLICATED);
122126

123127
m.def("fini", &fini)
124128
.def("init", &init)
125129
.def("sync", &sync)
126130
.def("myrank", &myrank)
127131
.def("_get_slice", &GetItem::get_slice)
128132
.def("_get_local", &GetItem::get_local)
129-
.def("_gather", &GetItem::gather);
133+
.def("_gather", &GetItem::gather)
134+
.def("to_numpy", &IO::to_numpy);
130135

131136
py::class_<Creator>(m, "Creator")
132137
.def("create_from_shape", &Creator::create_from_shape)

0 commit comments

Comments
 (0)