@@ -77,7 +77,7 @@ namespace x {
7777 PVSlice g_slc_view (a_ptr->slice (), slice);
7878 PVSlice my_rel_slice (g_slc_view, theTransceiver->rank ());
7979 NDSlice my_norm_slice = g_slc_view.map_slice (my_rel_slice.slice_of_rank ()); // slice());my_slice);
80-
80+
8181 if (is_spmd ()) theTransceiver->barrier ();
8282 _set_slice<A>(a_ptr->xarray (), my_rel_slice, b_ptr, my_norm_slice, val_guid);
8383 theTransceiver->barrier ();
@@ -129,37 +129,47 @@ namespace x {
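129129 // Receiving ranks (the root, or every rank if root is REPLICATED) create a full-size local buffer
130130 // and copy their local data to the right place; send-only ranks pass only their own local data.
130130 // All ranks then call the transceiver's gather with the given root (in place on receiving ranks).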
131131 template <typename T>
132- static py::object op (const std::shared_ptr<DPTensorX<T>> & a_ptr)
132+ static py::object op (rank_type root, const std::shared_ptr<DPTensorX<T>> & a_ptr)
133133 {
134134 auto nranks = theTransceiver->nranks ();
135135 auto rank = theTransceiver->rank ();
136+ bool sendonly = root != REPLICATED && root != rank;
136137 const auto & slc = a_ptr->slice ();
138+ auto mysz = slc.slice_of_rank ().size ();
137139
138140 // create buffer/numpy array
139- auto res = py::array_t <T>(std::move (slc.shape ()));
140- T * ptr = reinterpret_cast <T*>(res.mutable_data ());
141+ T * ptr = nullptr ;
142+ py::array res;
143+ if (sendonly) {
144+ if (mysz > 0 && a_ptr->is_sliced ()) ptr = new T[mysz];
145+ } else {
146+ res = py::array_t <T>(std::move (slc.shape ()));
147+ ptr = reinterpret_cast <T*>(res.mutable_data ());
148+ }
141149 int displacements[nranks];
142150 int counts[nranks];
143151 int off = 0 ;
144152 // for each rank compute counts and displacements
145153 for (auto i=0 ; i<nranks; ++i) {
146- uint64_t szi = slc.slice_of_rank (i).size ();
154+ uint64_t szi = i == rank ? mysz : slc.slice_of_rank (i).size ();
147155 counts[i] = szi;
148156 displacements[i] = off;
149157 // copy our local data
150158 if (i == rank) {
151159 if (a_ptr->is_sliced ()) {
152160 // if non-contiguous copy element by element
153161 const auto & av = xt::strided_view (a_ptr->xarray (), a_ptr->lslice ());
154- uint64_t i = off- 1 ;
155- for (auto v : av) ptr[++i ] = v;
162+ uint64_t j = sendonly ? - 1 : off - 1 ;
163+ for (auto v : av) ptr[++j ] = v;
156164 } else {
157- memcpy (&ptr[off], a_ptr->xarray ().data (), szi*sizeof (T));
165+ if (sendonly && mysz > 0 ) ptr = a_ptr->xarray ().data ();
166+ else memcpy (&ptr[off], a_ptr->xarray ().data (), szi*sizeof (T));
158167 }
159168 }
160169 off += szi;
161170 }
162- theTransceiver->allgather (ptr, counts, displacements, DTYPE<T>::value);
171+ theTransceiver->gather (ptr, counts, displacements, DTYPE<T>::value, root);
172+ if (sendonly && mysz > 0 && a_ptr->is_sliced ()) delete [] ptr;
163173 return res;
164174 }
165175 };
@@ -171,12 +181,12 @@ struct DeferredSetItem : public Deferred
171181 id_type _a;
172182 id_type _b;
173183 NDSlice _slc;
174-
184+
175185 DeferredSetItem () = default ;
176186 DeferredSetItem (const tensor_i::future_type & a, const tensor_i::future_type & b, const std::vector<py::slice> & v)
177187 : _a(a.id()), _b(b.id()), _slc(v)
178188 {}
179-
189+
180190 void run ()
181191 {
182192 const auto a = std::move (Registry::get (_a).get ());
@@ -249,10 +259,15 @@ py::object GetItem::get_local(const ddptensor & a, py::handle h)
249259 return TypeDispatch<x::SPMD>(aa, h);
250260}
251261
252- py::object GetItem::gather (const ddptensor & a)
262+ py::object GetItem::do_gather (const tensor_i::ptr_type & a, rank_type root)
263+ {
264+ return TypeDispatch<x::SPMD>(a, root);
265+ }
266+
267+ py::object GetItem::gather (const ddptensor & a, rank_type root)
253268{
254269 const auto aa = std::move (a.get ().get ());
255- return TypeDispatch<x::SPMD> (aa);
270+ return do_gather (aa, root );
256271}
257272
258273FACTORY_INIT (DeferredGetItem, F_GETITEM);
0 commit comments