@@ -224,13 +224,17 @@ class dtensor_impl : public tensor_i
224224 }
225225
226226 // The API works on tensor_i, so we need to downcast to the actual (templated) type.
227- const dtensor_impl<T> * cast (const ptr_type & b) const
227+ static dtensor_impl<T> * cast (ptr_type & b)
228228 {
229229 // FIXME: use attribute/vfunction + reinterpret_cast for better performance
230- auto ptr = dynamic_cast <const dtensor_impl<T>*>(b.get ());
230+ auto ptr = dynamic_cast <dtensor_impl<T>*>(b.get ());
231231 // if(ptr == nullptr) throw(std::runtime_error("Incompatible tensor types."));
232232 return ptr; // nullptr if the dynamic_cast fails (the check above is commented out), so callers must handle it
233233 }
234+ static const dtensor_impl<T> * cast (const ptr_type & b) // const overload: same downcast, but yields a pointer-to-const
235+ {
236+ return cast (const_cast <ptr_type &>(b)); // constness is dropped only to reuse the non-const overload; the result is const-qualified again on return
237+ }
234238
235239 ptr_type _ew_op (const char * op, const char * mod, py::args args, const py::kwargs & kwargs)
236240 {
@@ -331,42 +335,28 @@ class dtensor_impl : public tensor_i
331335 }
332336 }
333337
334- // FIXME We use a generic SPMD/PGAS mechanism to pull elements from remote
335- // on all procs simultaneously. Since __setitem__ is collective we could
336- // implement a probaly more efficient mechanism which pushes data and/or using RMA.
337- void __setitem__ (const NDSlice & slice, const ptr_type & val)
338+ // copy data from val into (*dest)[slice]
339+ // this is a non-collective call.
340+ static void _set_slice (const dtensor_impl<T> * val, const NDSlice & val_slice, dtensor_impl<T> * dest, const NDSlice & dest_slice)
338341 {
339- std::cerr << " __setitem__ " << slice << " " << val-> pvslice (). slice () << std::endl;
340- auto nd = shape ().size ();
341- if (owner () == REPLICATED && nd > 0 )
342+ std::cerr << " _set_slice " << val_slice << " " << dest_slice << std::endl;
343+ auto nd = dest-> shape ().size ();
344+ if (dest-> owner () == REPLICATED && nd > 0 )
342345 std::cerr << " Warning: __setitem__ on replicated data updates local tile only" << std::endl;
343- if (nd != slice .ndims ())
346+ if (nd != dest_slice .ndims ())
344347 throw std::runtime_error (" Index dimensionality must match array dimensionality" );
348+ if (val_slice.size () != dest_slice.size ())
349+ throw std::runtime_error (" Input and output slices must be of same size" );
345350
346- auto slc_sz = slice.size ();
347- auto val_sz = VPROD (val->shape ());
348- if (slc_sz != val_sz)
349- throw std::runtime_error (" Given tensor does not match: it has different size than 'slice'" );
350-
351- NDSlice norm_slice = pvslice ().normalized_slice ();
352- std::cerr << " norm_slice: " << norm_slice << std::endl;
353351 // Use given slice to create a global view into orig array
354- PVSlice g_slc_view (pvslice (), slice );
352+ PVSlice g_slc_view (dest-> pvslice (), dest_slice );
355353 std::cerr << " g_slice: " << g_slc_view.slice () << std::endl;
356- PVSlice my_view (g_slc_view, theTransceiver->rank ());
357- NDSlice my_slice = my_view.slice ();
358- std::cerr << " my_slice: " << my_slice << std::endl;
359- NDSlice my_norm_slice = g_slc_view.map_slice (my_slice);
360- std::cerr << " my_norm_slice: " << my_norm_slice << std::endl;
361-
362354 // Create a view into val
363- PVSlice needed_val_view (val->pvslice (), my_norm_slice );
355+ PVSlice needed_val_view (val->pvslice (), val_slice );
364356 std::cerr << " needed_val_view: " << needed_val_view.slice () << " (was " << val->pvslice ().slice () << " )" << std::endl;
365357
366358 // Get the pointer to the local buffer
367- auto ns = get_array_impl (_pyarray);
368- // auto my_binfo = _pyarray.cast<py::buffer>().request();
369- // T * my_buffer = reinterpret_cast<T*>(my_binfo.ptr);
359+ auto ns = get_array_impl (dest->_pyarray );
370360
371361 // we can now compute which ranks actually hold which piece of the data from val that we need locally
372362 for (rank_type i=0 ; i<theTransceiver->nranks (); ++i ) {
@@ -377,7 +367,7 @@ class dtensor_impl : public tensor_i
377367 std::cerr << i << " curr_needed_val_slice: " << curr_needed_val_slice << std::endl;
378368 NDSlice curr_local_val_slice = val_local_view.map_slice (curr_needed_val_slice);
379369 std::cerr << i << " curr_local_val_slice: " << curr_local_val_slice << std::endl;
380- NDSlice curr_needed_norm_slice = val-> pvslice () .map_slice (curr_needed_val_slice);
370+ NDSlice curr_needed_norm_slice = needed_val_view .map_slice (curr_needed_val_slice);
381371 std::cerr << i << " curr_needed_norm_slice: " << curr_needed_norm_slice << std::endl;
382372 PVSlice my_curr_needed_view = PVSlice (g_slc_view, curr_needed_norm_slice);
383373 std::cerr << i << " my_curr_needed_slice: " << my_curr_needed_view.slice () << std::endl;
@@ -387,23 +377,39 @@ class dtensor_impl : public tensor_i
387377 py::tuple tpl = _make_tuple (my_curr_local_slice); // my_curr_view.slice());
388378 if (i == theTransceiver->rank ()) {
389379 // copy locally
390- auto rhs = cast ( val) ->_pyarray .attr (" __getitem__" )(_make_tuple (curr_local_val_slice));
380+ auto rhs = val->_pyarray .attr (" __getitem__" )(_make_tuple (curr_local_val_slice));
391381 std::cerr << py::str (rhs).cast <std::string>() << std::endl;
392- _pyarray.attr (" __setitem__" )(tpl, rhs);
382+ dest-> _pyarray .attr (" __setitem__" )(tpl, rhs);
393383 } else {
394384 // pull slice directly into new array
395385 auto obj = ns.attr (" empty" )(_make_tuple (curr_local_val_slice.shape ()));
396386 auto binfo = obj.cast <py::buffer>().request ();
397387 theMediator->pull (i, val, curr_local_val_slice, binfo.ptr );
398- _pyarray.attr (" __setitem__" )(tpl, obj);
388+ dest-> _pyarray .attr (" __setitem__" )(tpl, obj);
399389 }
400390 }
401391 }
402392 }
403393
394+ // FIXME We use a generic SPMD/PGAS mechanism to pull elements from remote
395+ // on all procs simultaneously. Since __setitem__ is collective we could
396+ // implement a probably more efficient mechanism which pushes data and/or uses RMA.
397+ void __setitem__ (const NDSlice & slice, const ptr_type & val)
398+ {
399+ // Use the given slice to create a global view into the original array
400+ PVSlice g_slc_view (this ->pvslice (), slice);
401+ std::cerr << " g_slice: " << g_slc_view.slice () << std::endl;
402+ NDSlice my_slice = g_slc_view.slice_of_rank (theTransceiver->rank ()); // presumably the part of the view held by this rank — TODO confirm
403+ std::cerr << " my_slice: " << my_slice << std::endl;
404+ NDSlice my_norm_slice = g_slc_view.map_slice (my_slice); // presumably my_slice expressed relative to the view — TODO confirm
405+ std::cerr << " my_norm_slice: " << my_norm_slice << std::endl;
406+
407+ _set_slice (cast (val), my_norm_slice, this , my_slice); // copy val[my_norm_slice] into (*this)[my_slice]; NOTE(review): cast(val) is nullptr when val is not a dtensor_impl<T> and _set_slice dereferences it unchecked — confirm
408+ }
409+
404410 void bufferize (const NDSlice & slice, Buffer & buff)
405411 {
406- PVSlice my_local_view = PVSlice (tile_shape ()); // pvslice().view_normalized_by_rank(theTransceiver->rank());
412+ PVSlice my_local_view = PVSlice (tile_shape ());
407413 PVSlice lview = PVSlice (my_local_view, slice);
408414 NDSlice lslice = lview.slice ();
409415
@@ -422,6 +428,14 @@ class dtensor_impl : public tensor_i
422428 }
423429 }
424430
431+ py::object get_slice (const NDSlice & slice) const // copies (*this)[slice] into a newly created tensor and returns that tensor's _pyarray
432+ {
433+ auto shp = slice.shape ();
434+ auto out = create_dtensor (PVSlice (shp, NOSPLIT), shp, DTYPE<T>::value, " empty" ); // NOSPLIT: presumably a non-partitioned tensor of the slice's shape — TODO confirm
435+ _set_slice (this , slice, cast (out), {shp}); // source is this[slice]; {shp} constructs the destination NDSlice, presumably covering all of out
436+ return cast (out)->_pyarray ;
437+ }
438+
425439 std::string __repr__ () const
426440 {
427441 return " dtensor(shape=" + to_string (shape (), ' x' ) + " , n_tiles="
0 commit comments