|
| 1 | +// SPDX-License-Identifier: BSD-3-Clause |
| 2 | + |
| 3 | +#pragma once |
| 4 | + |
#include <cassert>
#include <memory>
#include <vector>

#include "UtilsAndTypes.hpp"
#include "x.hpp"
| 7 | + |
| 8 | +struct CollComm |
| 9 | +{ |
| 10 | + // We assume we split in first dimension. |
| 11 | + // We also assume partitions are assigned to ranks in sequence from 0-N. |
| 12 | + // With this we know that our buffers (old and new) get data in the |
| 13 | + // same order. The only thing which might have changed is the tile-size. |
| 14 | + // Actually, the tile-size might change only if old or new shape does not evenly |
| 15 | + // distribute data (e.g. last partition is smaller). |
| 16 | + // In theory we could re-shape in-place when the norm-tile-size does not change. |
| 17 | + // This is not implemented: we need an extra mechanism to work with reshape-views or alike. |
| 18 | + template<typename T, typename U> |
| 19 | + static tensor_i::ptr_type coll_copy(std::shared_ptr<x::DPTensorX<T>> b_ptr, const std::shared_ptr<x::DPTensorX<U>> & a_ptr) |
| 20 | + { |
| 21 | + assert(! a_ptr->is_sliced() && ! b_ptr->is_sliced()); |
| 22 | + |
| 23 | + auto o_slc = a_ptr->slice(); |
| 24 | + // norm tile-size of orig array |
| 25 | + auto o_ntsz = o_slc.tile_size(0); |
| 26 | + // tilesize of my local partition of orig array |
| 27 | + auto o_tsz = o_slc.tile_size(); |
| 28 | + // linearized local slice of orig array |
| 29 | + auto o_llslc = Slice(o_ntsz * theTransceiver->rank(), o_ntsz * theTransceiver->rank() + o_tsz); |
| 30 | + |
| 31 | + PVSlice n_slc = b_ptr->slice(); |
| 32 | + // norm tile-size of new (reshaped) array |
| 33 | + auto n_ntsz = n_slc.tile_size(0); |
| 34 | + // tilesize of my local partition of new (reshaped) array |
| 35 | + auto n_tsz = n_slc.tile_size(); |
| 36 | + // linearized/flattened/1d local slice of new (reshaped) array |
| 37 | + auto n_llslc = Slice(n_ntsz * theTransceiver->rank(), n_ntsz * theTransceiver->rank() + n_tsz); |
| 38 | + |
| 39 | + auto nr = theTransceiver->nranks(); |
| 40 | + // We need a few C-arrays for MPI (counts and displacements in send/recv buffers) |
| 41 | + int counts_send[nr] = {0}; |
| 42 | + int disp_send[nr] = {0}; |
| 43 | + int counts_recv[nr] = {0}; |
| 44 | + int disp_recv[nr] = {0}; |
| 45 | + |
| 46 | + for(auto r=0; r<nr; ++r) { |
| 47 | + // determine what I receive from rank r |
| 48 | + // e.g. which parts of my new slice overlap with rank r's old slice |
| 49 | + // Get local slice of rank r of orig array |
| 50 | + auto o_rslc = o_slc.local_slice_of_rank(r); |
| 51 | + // Flatten to 1d |
| 52 | + auto o_lrslc = Slice(o_ntsz * r, o_ntsz * r + o_rslc.size()); |
| 53 | + // Determine overlap with local partition of linearized new array |
| 54 | + auto roverlap = n_llslc.overlap(o_lrslc); |
| 55 | + // number of elements to be received from rank r |
| 56 | + counts_recv[r] = roverlap.size(); |
| 57 | + // displacement in new array where elements from rank r get copied to |
| 58 | + disp_recv[r] = roverlap._start - n_llslc._start; |
| 59 | + |
| 60 | + // determine what I send to rank r |
| 61 | + // e.g. which parts of my old slice overlap with rank r's new slice |
| 62 | + // Get local slice of rank r of new array |
| 63 | + auto n_rslc = n_slc.local_slice_of_rank(r); |
| 64 | + // Flatten to 1d |
| 65 | + auto n_lrslc = Slice(n_ntsz * r, n_ntsz * r + n_rslc.size()); |
| 66 | + // Determine overlap with local partition of linearized orig array |
| 67 | + auto soverlap = o_llslc.overlap(n_lrslc); |
| 68 | + // number of elements to be send to rank r |
| 69 | + counts_send[r] = soverlap.size(); |
| 70 | + // displacement in orig array where elements from rank r get copied from |
| 71 | + disp_send[r] = soverlap._start - o_llslc._start; |
| 72 | + } |
| 73 | + |
| 74 | + // Now we can send/recv directly to/from xarray buffers. |
| 75 | + theTransceiver->alltoall(a_ptr->xarray().data(), // buffer_send |
| 76 | + counts_send, |
| 77 | + disp_send, |
| 78 | + DTYPE<T>::value, |
| 79 | + b_ptr->xarray().data(), // buffer_recv |
| 80 | + counts_recv, |
| 81 | + disp_recv, |
| 82 | + DTYPE<T>::value); |
| 83 | + |
| 84 | + return b_ptr; |
| 85 | + } |
| 86 | +}; |
0 commit comments