@@ -31,17 +31,33 @@ T * mr_to_ptr(void * ptr, intptr_t offset)
3131
3232extern " C" {
3333
34+ #define NO_TRANSCEIVER
35+ #ifdef NO_TRANSCEIVER
36+ static void initMPIRuntime () {
37+ if (getTransceiver () == nullptr )
38+ init_transceiver (new MPITransceiver (false ));
39+ }
40+ #endif
41+
3442// Return number of ranks/processes in given team/communicator
35- uint64_t idtr_nprocs (int64_t team )
43+ uint64_t idtr_nprocs (Transceiver * tc )
3644{
37- return getTransceiver ()->nranks ();
45+ #ifdef NO_TRANSCEIVER
46+ initMPIRuntime ();
47+ tc = getTransceiver ();
48+ #endif
49+ return tc->nranks ();
3850}
3951#pragma weak _idtr_nprocs = idtr_nprocs
4052
4153// Return rank in given team/communicator
42- uint64_t idtr_prank (int64_t team )
54+ uint64_t idtr_prank (Transceiver * tc )
4355{
44- return getTransceiver ()->rank ();
56+ #ifdef NO_TRANSCEIVER
57+ initMPIRuntime ();
58+ tc = getTransceiver ();
59+ #endif
60+ return tc->rank ();
4561}
4662#pragma weak _idtr_prank = idtr_prank
4763
@@ -173,7 +189,9 @@ void forall(uint64_t d, const T * cptr, const int64_t * sizes, const int64_t * s
173189 }
174190 } else {
175191 for (auto i=0 ; i<sz; ++i) {
192+ const T * tmp = cptr;
176193 forall (d+1 , cptr, sizes, strides, nd, op);
194+ cptr = tmp + strides[d];
177195 }
178196 }
179197}
@@ -190,20 +208,26 @@ bool is_contiguous(const int64_t * sizes, const int64_t * strides, uint64_t nd)
190208 return true ;
191209}
192210
193- void * bufferize (void * cptr, DTypeId dtype, const int64_t * sizes, const int64_t * strides, uint64_t nd, void * out)
194- {
195- if (is_contiguous (sizes, strides, nd)) {
196- return cptr;
197- } else {
198- dispatch (dtype, cptr, [sizes, strides, nd, out](auto * ptr) {
199- auto buff = static_cast <decltype (ptr)>(out);
200- forall (0 , ptr, sizes, strides, nd, [&buff](const auto * in) {
201- *buff = *in;
202- ++buff;
203- });
204- });
205- return out;
206- }
211+ void bufferize (void * cptr, DTypeId dtype, const int64_t * sizes, const int64_t * strides, const int64_t * tStarts, const int64_t * tSizes, uint64_t nd, uint64_t N, void * out)
212+ {
213+ dispatch (dtype, cptr, [sizes, strides, tStarts, tSizes, nd, N, out](auto * ptr) {
214+ auto buff = static_cast <decltype (ptr)>(out);
215+
216+ for (auto i=0 ; i<N; ++i) {
217+ auto szs = &tSizes[i*nd];
218+ if (szs[0 ] > 0 ) {
219+ auto sts = &tStarts[i*nd];
220+ uint64_t off = 0 ;
221+ for (int64_t r=0 ; r<nd; ++r) {
222+ off += sts[r] * strides[r];
223+ }
224+ forall (0 , &ptr[off], szs, strides, nd, [&buff](const auto * in) {
225+ *buff = *in;
226+ ++buff;
227+ });
228+ }
229+ }
230+ });
207231}
208232
209233extern " C" {
@@ -223,6 +247,7 @@ void _idtr_reduce_all(uint64_t rank, void * data, const int64_t * sizes, const i
223247 mlir2ddpt (static_cast <imex::ptensor::ReduceOpId>(op)));
224248}
225249
250+ #if 0
226251void _idtr_rebalance(uint64_t rank, const int64_t * gShape, const int64_t * lOffs,
227252 void * data, const int64_t * sizes, const int64_t * strides, int dtype,
228253 uint64_t outRank, void * out, const int64_t * outSizes, const int64_t * outStrides)
@@ -269,7 +294,7 @@ void _idtr_rebalance(uint64_t rank, const int64_t * gShape, const int64_t * lOff
269294 // Finally communicate elements
270295 getTransceiver()->alltoall(ptr, sszs.data(), soffs.data(), ddpttype, out, rszs.data(), roffs.data());
271296}
272-
297+ # endif
273298
/// @brief repartition tensor
/// We assume the tensor is partitioned along the first dimension (only) and that partitions are ordered by rank
@@ -288,18 +313,20 @@ void _idtr_repartition(int64_t rank, int64_t * gShapePtr, int dtype,
288313 void * lDataPtr, int64_t * lOffsPtr, int64_t * lShapePtr, int64_t * lStridesPtr,
289314 int64_t * offsPtr, int64_t * szsPtr, void * outPtr, Transceiver * tc)
290315{
291- assert (is_contiguous (lShapePtr, lStridesPtr, rank));
292-
316+ #ifdef NO_TRANSCEIVER
317+ initMPIRuntime ();
318+ tc = getTransceiver ();
319+ #endif
293320 auto N = tc->nranks ();
294321 auto me = tc->rank ();
295322 auto ddpttype = mlir2ddpt (static_cast <::imex::ptensor::DType>(dtype));
296- auto nSz = std::accumulate (&lShapePtr[1 ], &lShapePtr[rank], 1 , std::multiplies<int64_t >());
297323
298324 // First we allgather the requested target partitioning
299325
300326 auto myBOff = 2 * rank * me;
301327 ::std::vector<int64_t > buff (2 *rank*N);
302328 for (int64_t i=0 ; i<rank; ++i) {
329+ // assert(offsPtr[i] - lOffsPtr[i] + szsPtr[i] <= gShapePtr[i]);
303330 buff[myBOff+i] = offsPtr[i];
304331 buff[myBOff+i+rank] = szsPtr[i];
305332 }
@@ -315,24 +342,44 @@ void _idtr_repartition(int64_t rank, int64_t * gShapePtr, int dtype,
315342 auto myOff = lOffsPtr[0 ];
316343 auto mySz = lShapePtr[0 ];
317344 auto myEnd = myOff + mySz;
345+ auto myTileSz = std::accumulate (&lShapePtr[1 ], &lShapePtr[rank], 1 , std::multiplies<int64_t >());
318346
319347 std::vector<int > soffs (N);
320348 std::vector<int > sszs (N, 0 );
349+ std::vector<int64_t > tStarts (N*rank, 0 );
350+ std::vector<int64_t > tSizes (N*rank, 0 );
351+ std::vector<int64_t > nSizes (N);
352+ int64_t totSSz = 0 ;
353+ bool needsBufferize = !is_contiguous (lShapePtr, lStridesPtr, rank);
321354
322355 for (auto i=0 ; i<N; ++i) {
356+ nSizes[i] = std::accumulate (&buff[2 *rank*i+rank+1 ], &buff[2 *rank*i+rank+rank], 1 , std::multiplies<int64_t >());
357+ if (nSizes[i] != myTileSz) needsBufferize = true ;
358+ }
359+ for (auto i=0 ; i<N; ++i) {
360+ auto nSz = nSizes[i];
323361 auto tOff = buff[2 *rank*i];
324362 auto tSz = buff[2 *rank*i+rank];
325363 auto tEnd = tOff + tSz;
364+
326365 if (tEnd > myOff && tOff < myEnd) {
327366 // We have a target partition which is inside my local data
328367 // we now compute what data goes to this target partition
329368 auto start = std::max (myOff, tOff);
330369 auto end = std::min (myEnd, tEnd);
331- soffs[i] = (int )(start - myOff) * nSz;
370+ tStarts[i*rank] = start - myOff;
371+ tSizes[i*rank] = end - start;
372+ soffs[i] = needsBufferize ? (i ? soffs[i-1 ] + sszs[i-1 ] : 0 ) : (int )(start - myOff) * myTileSz;
332373 sszs[i] = (int )(end - start) * nSz;
333374 } else {
334375 soffs[i] = i ? soffs[i-1 ] + sszs[i-1 ] : 0 ;
335376 }
377+ totSSz += sszs[i];
378+ for (auto r=1 ; r<rank; ++r) {
379+ tStarts[i*rank+r] = buff[2 *rank*i+r];
380+ tSizes[i*rank+r] = buff[2 *rank*i+rank+r];
381+ // assert(tSizes[i*rank+r] <= lShapePtr[r]);
382+ }
336383 }
337384
338385 // send our send sizes to others and receive theirs
@@ -348,7 +395,15 @@ void _idtr_repartition(int64_t rank, int64_t * gShapePtr, int dtype,
348395 }
349396
350397 // Finally communicate elements
351- getTransceiver ()->alltoall (lDataPtr, sszs.data (), soffs.data (), ddpttype, outPtr, rszs.data (), roffs.data ());
398+ if (needsBufferize) {
399+ // create send buffer if strided
400+ Buffer buff (totSSz * sizeof_dtype (ddpttype), 2 );
401+ bufferize (lDataPtr, ddpttype, lShapePtr, lStridesPtr, tStarts.data (), tSizes.data (), rank, N, buff.data ());
402+ getTransceiver ()->alltoall (buff.data (), sszs.data (), soffs.data (), ddpttype, outPtr, rszs.data (), roffs.data ());
403+ std::cerr << " yey\n " ;
404+ } else {
405+ getTransceiver ()->alltoall (lDataPtr, sszs.data (), soffs.data (), ddpttype, outPtr, rszs.data (), roffs.data ());
406+ }
352407}
353408
354409void _idtr_extractslice (int64_t * slcOffs,
@@ -360,13 +415,13 @@ void _idtr_extractslice(int64_t * slcOffs,
360415 int64_t * lSlcSizes,
361416 int64_t * gSlcOffsets )
362417{
363- std::cerr << " slcOffs: " << slcOffs[0 ] << " " << slcOffs[1 ] << std::endl;
364- std::cerr << " slcSizes: " << slcSizes[0 ] << " " << slcSizes[1 ] << std::endl;
365- std::cerr << " slcStrides: " << slcStrides[0 ] << " " << slcStrides[1 ] << std::endl;
366- std::cerr << " tOffs: " << tOffs[0 ] << " " << tOffs[1 ] << std::endl;
367- std::cerr << " tSizes: " << tSizes[0 ] << " " << tSizes[1 ] << std::endl;
368- std::cerr << " lSlcOffsets: " << lSlcOffsets[0 ] << " " << lSlcOffsets[1 ] << std::endl;
369- std::cerr << " lSlcSizes: " << lSlcSizes[0 ] << " " << lSlcSizes[1 ] << std::endl;
370- std::cerr << " gSlcOffsets: " << gSlcOffsets [0 ] << " " << gSlcOffsets [1 ] << std::endl;
418+ if (slcOffs) std::cerr << " slcOffs: " << slcOffs[0 ] << " " << slcOffs[1 ] << std::endl;
419+ if (slcSizes) std::cerr << " slcSizes: " << slcSizes[0 ] << " " << slcSizes[1 ] << std::endl;
420+ if (slcStrides) std::cerr << " slcStrides: " << slcStrides[0 ] << " " << slcStrides[1 ] << std::endl;
421+ if (tOffs) std::cerr << " tOffs: " << tOffs[0 ] << " " << tOffs[1 ] << std::endl;
422+ if (tSizes) std::cerr << " tSizes: " << tSizes[0 ] << " " << tSizes[1 ] << std::endl;
423+ if (lSlcOffsets) std::cerr << " lSlcOffsets: " << lSlcOffsets[0 ] << " " << lSlcOffsets[1 ] << std::endl;
424+ if (lSlcSizes) std::cerr << " lSlcSizes: " << lSlcSizes[0 ] << " " << lSlcSizes[1 ] << std::endl;
425+ if ( gSlcOffsets ) std::cerr << " gSlcOffsets: " << gSlcOffsets [0 ] << " " << gSlcOffsets [1 ] << std::endl;
371426}
372427} // extern "C"
0 commit comments