66*/
77
88#include " ddptensor/SetGetItem.hpp"
9+ #include " ddptensor/CollComm.hpp"
910#include " ddptensor/Creator.hpp"
1011#include " ddptensor/DDPTensorImpl.hpp"
1112#include " ddptensor/Factory.hpp"
1213#include " ddptensor/Mediator.hpp"
1314#include " ddptensor/NDSlice.hpp"
15+ #include " ddptensor/Transceiver.hpp"
1416#include " ddptensor/TypeDispatch.hpp"
1517#include " ddptensor/UtilsAndTypes.hpp"
1618
1921#include < imex/Utils/PassUtils.h>
2022#include < mlir/IR/Builders.h>
2123
24+ #include < pybind11/numpy.h>
25+ #include < pybind11/pybind11.h>
26+ namespace py = pybind11;
27+
2228#if 0
2329namespace x {
2430
@@ -140,59 +146,75 @@ namespace x {
140146 T * data = a_ptr->xarray().data();
141147 return py::array(std::move(slc.shape()), std::move(strides), data + off, handle);
142148 }
143-
144- // gather
145- // We simply create a local buffer, copy our local data to the right place
146- // and then call AllGatherV via inplace operation.
147- template<typename T>
148- static py::object op(rank_type root, const std::shared_ptr<DPTensorX<T>> & a_ptr)
149- {
150- auto nranks = getTransceiver()->nranks();
151- auto rank = getTransceiver()->rank();
152- bool sendonly = root != REPLICATED && root != rank;
153- const auto & slc = a_ptr->slice();
154- auto mysz = slc.local_slice().size();
155-
156- // create buffer/numpy array
157- T * ptr = nullptr;
158- py::array res;
159- if(sendonly) {
160- if(mysz > 0 && a_ptr->is_sliced()) ptr = new T[mysz];
161- } else {
162- res = py::array_t<T>(slc.shape());
163- ptr = reinterpret_cast<T*>(res.mutable_data());
164- }
165- int displacements[nranks];
166- int counts[nranks];
167- int off = 0;
168- // for each rank compute counts and displacements
169- for(auto i=0; i<nranks; ++i) {
170- uint64_t szi = i == rank ? mysz : slc.local_slice(i).size();
171- counts[i] = szi;
172- displacements[i] = off;
173- // copy our local data
174- if(i == rank) {
175- if(a_ptr->is_sliced()) {
176- // if non-contiguous copy element by element
177- const auto & av = xt::strided_view(a_ptr->xarray(), a_ptr->lslice());
178- uint64_t j = sendonly ? -1 : off - 1;
179- for(auto v : av) ptr[++j] = v;
180- } else {
181- if(sendonly && mysz > 0) ptr = a_ptr->xarray().data();
182- else memcpy(&ptr[off], a_ptr->xarray().data(), szi*sizeof(T));
183- }
184- }
185- off += szi;
186- }
187- getTransceiver()->gather(ptr, counts, displacements, DTYPE<T>::value, root);
188- if(sendonly && mysz > 0 && a_ptr->is_sliced()) delete [] ptr;
189- return res;
190- }
191149 };
192150
193151} // namespace x
194152#endif // if 0
195153
154+ // ***************************************************************************
155+
156+ struct DeferredGather
157+ : public DeferredT<GetItem::py_promise_type, GetItem::py_future_type> {
158+ id_type _a;
159+ rank_type _root;
160+
161+ DeferredGather () = default ;
162+ DeferredGather (const tensor_i::future_type &a, rank_type root)
163+ : _a(a.id()), _root(root) {}
164+
165+ template <typename T> struct mk_array {
166+ template <typename C> static py::object op (C &&shp, void *&outPtr) {
167+ auto ary = py::array_t <T>(std::forward<C>(shp));
168+ outPtr = ary.mutable_data ();
169+ return ary;
170+ }
171+ };
172+
173+ void run () override {
174+ // gather
175+ // We simply create a local buffer, copy our local data to the right place
176+ // and then call AllGatherV via inplace operation.
177+ auto trscvr = getTransceiver ();
178+ auto myrank = trscvr->rank ();
179+ auto aa = std::move (Registry::get (_a).get ());
180+ auto a_ptr = std::dynamic_pointer_cast<DDPTensorImpl>(aa);
181+ assert (a_ptr);
182+ bool sendonly = _root != REPLICATED && _root != myrank;
183+
184+ void *outPtr = nullptr ;
185+ py::object res;
186+ if (!sendonly) {
187+ auto tmp = a_ptr->shape ();
188+ // std::vector<ssize_t> shp(tmp, &tmp[a_ptr->ndims()]);
189+ res = dispatch<mk_array>(a_ptr->dtype (),
190+ std::vector<ssize_t >(tmp, &tmp[a_ptr->ndims ()]),
191+ outPtr);
192+ // (void*)nullptr, [&shp, &res, &outPtr](auto * ptr) {
193+ // auto ary = py::array_t<double>({4,4});
194+ // res = ary;
195+ // outPtr = ary.mutable_data();
196+ // });
197+ }
198+
199+ gather_tensor (a_ptr, _root, outPtr);
200+
201+ set_value (res);
202+ }
203+
204+ bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
205+ jit::DepManager &dm) override {
206+ return true ;
207+ }
208+
209+ FactoryId factory () const { return F_GATHER; }
210+
211+ template <typename S> void serialize (S &ser) {
212+ ser.template value <sizeof (_a)>(_a);
213+ }
214+ };
215+
216+ // ***************************************************************************
217+
196218struct DeferredSetItem : public Deferred {
197219 id_type _a;
198220 id_type _b;
@@ -252,15 +274,7 @@ struct DeferredSetItem : public Deferred {
252274 }
253275};
254276
255- ddptensor *SetItem::__setitem__ (ddptensor &a, const std::vector<py::slice> &v,
256- const py::object &b) {
257-
258- auto bb = Creator::mk_future (b);
259- auto res = new ddptensor (defer<DeferredSetItem>(a.get (), bb.first ->get (), v));
260- if (bb.second )
261- delete bb.first ;
262- return res;
263- }
277+ // ***************************************************************************
264278
265279struct DeferredGetItem : public Deferred {
266280 id_type _a;
@@ -331,11 +345,27 @@ struct DeferredGetItem : public Deferred {
331345 }
332346};
333347
348+ // ***************************************************************************
349+
334350ddptensor *GetItem::__getitem__ (const ddptensor &a,
335351 const std::vector<py::slice> &v) {
336352 return new ddptensor (defer<DeferredGetItem>(a.get (), v));
337353}
338354
355+ GetItem::py_future_type GetItem::gather (const ddptensor &a, rank_type root) {
356+ return defer<DeferredGather>(a.get (), root);
357+ }
358+
359+ ddptensor *SetItem::__setitem__ (ddptensor &a, const std::vector<py::slice> &v,
360+ const py::object &b) {
361+
362+ auto bb = Creator::mk_future (b);
363+ auto res = new ddptensor (defer<DeferredSetItem>(a.get (), bb.first ->get (), v));
364+ if (bb.second )
365+ delete bb.first ;
366+ return res;
367+ }
368+
339369py::object GetItem::get_slice (const ddptensor &a,
340370 const std::vector<py::slice> &v) {
341371 const auto aa = std::move (a.get ());
@@ -347,14 +377,6 @@ py::object GetItem::get_local(const ddptensor &a, py::handle h) {
347377 return {}; // FIXME TypeDispatch<x::SPMD>(aa, h);
348378}
349379
350- py::object GetItem::do_gather (const tensor_i::ptr_type &a, rank_type root) {
351- return {}; // FIXME TypeDispatch<x::SPMD>(a, root);
352- }
353-
354- py::object GetItem::gather (const ddptensor &a, rank_type root) {
355- const auto aa = std::move (a.get ().get ());
356- return do_gather (aa, root);
357- }
358-
359380FACTORY_INIT (DeferredGetItem, F_GETITEM);
360381FACTORY_INIT (DeferredSetItem, F_SETITEM);
382+ FACTORY_INIT (DeferredGather, F_GATHER);
0 commit comments