@@ -151,6 +151,63 @@ namespace x {
151151} // namespace x
152152#endif // if 0
153153
154+ template <typename T> struct mk_array {
155+ template <typename C> static py::object op (C &&shp, void *&outPtr) {
156+ auto ary = py::array_t <T>(std::forward<C>(shp));
157+ outPtr = ary.mutable_data ();
158+ return ary;
159+ }
160+ };
161+
162+ template <typename T> struct wrap_array {
163+ template <typename C, typename S>
164+ static py::object op (C &&shp, S &&str, void *data, const py::handle &handle) {
165+ return py::array (std::forward<C>(shp), std::forward<S>(str),
166+ reinterpret_cast <T *>(data), handle);
167+ }
168+ };
169+
170+ // ***************************************************************************
171+
172+ struct DeferredGetLocal
173+ : public DeferredT<GetItem::py_promise_type, GetItem::py_future_type> {
174+ id_type _a;
175+ py::handle _handle;
176+
177+ DeferredGetLocal () = default ;
178+ DeferredGetLocal (const tensor_i::future_type &a, py::handle &handle)
179+ : _a(a.id()), _handle(handle) {}
180+
181+ void run () override {
182+ auto aa = std::move (Registry::get (_a).get ());
183+ auto a_ptr = std::dynamic_pointer_cast<DDPTensorImpl>(aa);
184+ assert (a_ptr);
185+ auto tmp_shp = a_ptr->local_shape ();
186+ auto tmp_str = a_ptr->local_strides ();
187+ auto nd = a_ptr->ndims ();
188+ auto eSz = sizeof_dtype (a_ptr->dtype ());
189+ std::vector<ssize_t > strides (nd);
190+ for (auto i = 0 ; i < nd; ++i) {
191+ strides[i] = eSz * tmp_str[i];
192+ }
193+ auto res = dispatch<wrap_array>(a_ptr->dtype (),
194+ std::vector<ssize_t >(tmp_shp, &tmp_shp[nd]),
195+ strides, a_ptr->data (), _handle);
196+ set_value (res);
197+ }
198+
199+ bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
200+ jit::DepManager &dm) override {
201+ return true ;
202+ }
203+
204+ FactoryId factory () const { return F_GETLOCAL; }
205+
206+ template <typename S> void serialize (S &ser) {
207+ ser.template value <sizeof (_a)>(_a);
208+ }
209+ };
210+
154211// ***************************************************************************
155212
156213struct DeferredGather
@@ -162,14 +219,6 @@ struct DeferredGather
162219 DeferredGather (const tensor_i::future_type &a, rank_type root)
163220 : _a(a.id()), _root(root) {}
164221
165- template <typename T> struct mk_array {
166- template <typename C> static py::object op (C &&shp, void *&outPtr) {
167- auto ary = py::array_t <T>(std::forward<C>(shp));
168- outPtr = ary.mutable_data ();
169- return ary;
170- }
171- };
172-
173222 void run () override {
174223 // gather
175224 // We simply create a local buffer, copy our local data to the right place
@@ -352,6 +401,10 @@ ddptensor *GetItem::__getitem__(const ddptensor &a,
352401 return new ddptensor (defer<DeferredGetItem>(a.get (), v));
353402}
354403
404+ GetItem::py_future_type GetItem::get_local (const ddptensor &a, py::handle h) {
405+ return defer<DeferredGetLocal>(a.get (), h);
406+ }
407+
355408GetItem::py_future_type GetItem::gather (const ddptensor &a, rank_type root) {
356409 return defer<DeferredGather>(a.get (), root);
357410}
@@ -372,11 +425,7 @@ py::object GetItem::get_slice(const ddptensor &a,
372425 return {}; // FIXME TypeDispatch<x::SPMD>(aa.get(), NDSlice(v), aa.id());
373426}
374427
375- py::object GetItem::get_local (const ddptensor &a, py::handle h) {
376- const auto aa = std::move (a.get ().get ());
377- return {}; // FIXME TypeDispatch<x::SPMD>(aa, h);
378- }
379-
380428FACTORY_INIT (DeferredGetItem, F_GETITEM);
381429FACTORY_INIT (DeferredSetItem, F_SETITEM);
382430FACTORY_INIT (DeferredGather, F_GATHER);
431+ FACTORY_INIT (DeferredGather, F_GETLOCAL);
0 commit comments