@@ -167,6 +167,21 @@ template <typename T> struct wrap_array {
167167 }
168168};
169169
170+ py::object wrap (DDPTensorImpl::ptr_type tnsr, const py::handle &handle) {
171+ auto tmp_shp = tnsr->local_shape ();
172+ auto tmp_str = tnsr->local_strides ();
173+ auto nd = tnsr->ndims ();
174+ auto eSz = sizeof_dtype (tnsr->dtype ());
175+ std::vector<ssize_t > strides (nd);
176+ for (auto i = 0 ; i < nd; ++i) {
177+ strides[i] = eSz * tmp_str[i];
178+ }
179+
180+ return dispatch<wrap_array>(tnsr->dtype (),
181+ std::vector<ssize_t >(tmp_shp, &tmp_shp[nd]),
182+ strides, tnsr->data (), handle);
183+ }
184+
170185// ***************************************************************************
171186
172187struct DeferredGetLocal
@@ -182,17 +197,7 @@ struct DeferredGetLocal
182197 auto aa = std::move (Registry::get (_a).get ());
183198 auto a_ptr = std::dynamic_pointer_cast<DDPTensorImpl>(aa);
184199 assert (a_ptr);
185- auto tmp_shp = a_ptr->local_shape ();
186- auto tmp_str = a_ptr->local_strides ();
187- auto nd = a_ptr->ndims ();
188- auto eSz = sizeof_dtype (a_ptr->dtype ());
189- std::vector<ssize_t > strides (nd);
190- for (auto i = 0 ; i < nd; ++i) {
191- strides[i] = eSz * tmp_str[i];
192- }
193- auto res = dispatch<wrap_array>(a_ptr->dtype (),
194- std::vector<ssize_t >(tmp_shp, &tmp_shp[nd]),
195- strides, a_ptr->data (), _handle);
200+ auto res = wrap (a_ptr, _handle);
196201 set_value (res);
197202 }
198203
@@ -317,6 +322,58 @@ struct DeferredSetItem : public Deferred {
317322
318323// ***************************************************************************
319324
325+ struct DeferredMap : public Deferred {
326+ id_type _a;
327+ py::object _func;
328+
329+ DeferredMap () = default ;
330+ DeferredMap (const tensor_i::future_type &a, py::object &func)
331+ : Deferred(a.id(), a.dtype(), a.rank(), a.balanced()), _a(a.id()),
332+ _func (func) {}
333+
334+ void run () override {
335+ auto aa = std::move (Registry::get (_a).get ());
336+ auto a_ptr = std::dynamic_pointer_cast<DDPTensorImpl>(aa);
337+ assert (a_ptr);
338+ auto nd = a_ptr->ndims ();
339+ auto lOffs = a_ptr->local_offsets ();
340+ std::vector<int64_t > lIdx (nd);
341+ std::vector<int64_t > gIdx (nd);
342+
343+ dispatch (a_ptr->dtype (), a_ptr->data (), [&](auto *ptr) {
344+ forall (
345+ 0 , ptr, a_ptr->local_shape (), a_ptr->local_strides (), nd, lIdx,
346+ [&](const std::vector<int64_t > &idx, auto *elPtr) {
347+ for (auto i = 0 ; i < nd; ++i) {
348+ gIdx [i] = idx[i] + lOffs[i];
349+ }
350+ auto pyIdx = _make_tuple (gIdx );
351+ *elPtr =
352+ _func (*pyIdx)
353+ .cast <
354+ typename std::remove_pointer<decltype (elPtr)>::type>();
355+ });
356+ });
357+
358+ this ->set_value (aa);
359+ };
360+
361+ bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
362+ jit::DepManager &dm) override {
363+ return true ;
364+ }
365+
366+ FactoryId factory () const { return F_MAP; }
367+
368+ template <typename S> void serialize (S &ser) {
369+ assert (false );
370+ ser.template value <sizeof (_a)>(_a);
371+ // nope ser.template value<sizeof(_func)>(_func);
372+ }
373+ };
374+
375+ // ***************************************************************************
376+
320377struct DeferredGetItem : public Deferred {
321378 id_type _a;
322379 NDSlice _slc;
@@ -407,14 +464,18 @@ GetItem::py_future_type GetItem::gather(const ddptensor &a, rank_type root) {
407464
408465ddptensor *SetItem::__setitem__ (ddptensor &a, const std::vector<py::slice> &v,
409466 const py::object &b) {
410-
411467 auto bb = Creator::mk_future (b);
412468 a.put (defer<DeferredSetItem>(a.get (), bb.first ->get (), v));
413469 if (bb.second )
414470 delete bb.first ;
415471 return &a;
416472}
417473
474+ ddptensor *SetItem::map (ddptensor &a, py::object &b) {
475+ a.put (defer<DeferredMap>(a.get (), b));
476+ return &a;
477+ }
478+
418479py::object GetItem::get_slice (const ddptensor &a,
419480 const std::vector<py::slice> &v) {
420481 const auto aa = std::move (a.get ());
@@ -423,5 +484,6 @@ py::object GetItem::get_slice(const ddptensor &a,
423484
424485FACTORY_INIT (DeferredGetItem, F_GETITEM);
425486FACTORY_INIT (DeferredSetItem, F_SETITEM);
487+ FACTORY_INIT (DeferredMap, F_MAP);
426488FACTORY_INIT (DeferredGather, F_GATHER);
427- FACTORY_INIT (DeferredGather , F_GETLOCAL);
489+ FACTORY_INIT (DeferredGetLocal , F_GETLOCAL);
0 commit comments