@@ -273,46 +273,84 @@ void bufferizeN(void *cptr, DTypeId dtype, const int64_t *sizes,
273273 });
274274}
275275
// Shorthand for Unranked1DMemRefType<uint64_t>. The functions below construct
// it from a (rank, descriptor) argument pair and read it via data()/size().
276+ using MRIdx1d = Unranked1DMemRefType<uint64_t >;
277+
276278extern " C" {
277279// Elementwise inplace allreduce
// Thin wrapper: forwards all four arguments (inout, dtype, N, op) unchanged to
// reduce_all() on the object returned by getTransceiver().
278280void idtr_reduce_all (void *inout, DTypeId dtype, uint64_t N, ReduceOpId op) {
279281 getTransceiver ()->reduce_all (inout, dtype, N, op);
280282}
281283
282284// FIXME hard-coded for contiguous layout
// NOTE(review): this span shows both the removed (-) and the added (+) lines of
// the same definition. The added variant receives sizes and strides as
// (rank, descriptor) pairs instead of a rank plus raw pointers.
283- void _idtr_reduce_all (uint64_t rank, void *data, const int64_t *sizes,
284- const int64_t *strides, int dtype, int op) {
285- assert (rank == 0 || strides[rank - 1 ] == 1 );
285+ void _idtr_reduce_all (void *data, int64_t sizesRank, int64_t *sizesDesc,
286+ int64_t stridesRank, int64_t *stridesDesc, int dtype,
287+ int op) {
// Wrap each (rank, descriptor) pair; data() yields the underlying array, which
// is reinterpreted as int64_t.
288+ MRIdx1d sizesMR (sizesRank, sizesDesc);
289+ MRIdx1d stridesMR (stridesRank, stridesDesc);
290+ auto sizes = reinterpret_cast <int64_t *>(sizesMR.data ());
291+ auto strides = reinterpret_cast <int64_t *>(stridesMR.data ());
// rank is taken from the number of entries reported by sizesMR.size().
292+ auto rank = sizesMR.size ();
// Only rank 0, or rank 1 with a stride of exactly 1, passes this assert.
293+ assert (rank == 0 || (rank == 1 && strides[0 ] == 1 ));
// dtype and op are cast to the imex::ptensor enums and mapped with mlir2ddpt().
// The element count passed on is sizes[0], or 1 when rank is 0 (the removed
// line passed rank itself instead).
286294 idtr_reduce_all (data, mlir2ddpt (static_cast <::imex::ptensor::DType>(dtype)),
287- rank ? rank : 1 ,
295+ rank ? sizes[ 0 ] : 1 ,
288296 mlir2ddpt (static_cast <imex::ptensor::ReduceOpId>(op)));
289297}
290298
291299// / @brief reshape tensor
292300// / We assume tensor is partitioned along the first dimension (only) and
293301// / partitions are ordered by ranks
294- // / @param rank
295- // / @param gShapePtr
302+ // / @param gShapeRank
303+ // / @param gShapeDesc
296304// / @param dtype
297305// / @param lDataPtr
298- // / @param lOffsPtr
299- // / @param lShapePtr
300- // / @param lStridesPtr
301- // / @param oRank
302- // / @param oGShapePtr
303- // / @param oOffsPtr
304- // / @param oShapePtr
306+ // / @param lOffsRank
307+ // / @param lOffsDesc
308+ // / @param lShapeRank
309+ // / @param lShapeDesc
310+ // / @param lStridesRank
311+ // / @param lStridesDesc
312+ // / @param oGShapeRank
313+ // / @param oGShapeDesc
314+ // / @param oOffsRank
315+ // / @param oOffsDesc
316+ // / @param oShapeRank
317+ // / @param oShapeDesc
305318// / @param outPtr
306319// / @param tc
307- void _idtr_reshape (int64_t rank, int64_t *gShapePtr , int dtype, void *lDataPtr,
308- int64_t *lOffsPtr, int64_t *lShapePtr, int64_t *lStridesPtr,
309- int64_t oRank, int64_t *oGShapePtr, int64_t *oOffsPtr,
310- int64_t *oShapePtr, void *outPtr, Transceiver *tc) {
320+ void _idtr_reshape (int64_t gShapeRank , int64_t *gShapeDesc , int dtype,
321+ void *lDataPtr, int64_t lOffsRank, int64_t *lOffsDesc,
322+ int64_t lShapeRank, int64_t *lShapeDesc,
323+ int64_t lStridesRank, int64_t *lStridesDesc,
324+ int64_t oGShapeRank, int64_t *oGShapeDesc, int64_t oOffsRank,
325+ int64_t *oOffsDesc, int64_t oShapeRank, int64_t *oShapeDesc,
326+ void *outPtr, Transceiver *tc) {
311327#ifdef NO_TRANSCEIVER
312328 initMPIRuntime ();
313329 tc = getTransceiver ();
314330#endif
315331
332+ assert (1 == gShapeRank && 1 == lOffsRank && 1 == lShapeRank &&
333+ 1 == lStridesRank && 1 == oGShapeRank && 1 == oOffsRank &&
334+ 1 == oShapeRank);
335+
336+ MRIdx1d gShapeUMR (gShapeRank , gShapeDesc );
337+ MRIdx1d oGShapeUMR (oGShapeRank, oGShapeDesc);
338+ auto rank = gShapeUMR .size ();
339+ auto oRank = oGShapeUMR.size ();
340+
341+ auto gShapePtr = reinterpret_cast <int64_t *>(gShapeUMR .data ());
342+ auto lOffsPtr =
343+ reinterpret_cast <int64_t *>(MRIdx1d (lOffsRank, lOffsDesc).data ());
344+ auto lShapePtr =
345+ reinterpret_cast <int64_t *>(MRIdx1d (lShapeRank, lShapeDesc).data ());
346+ auto lStridesPtr =
347+ reinterpret_cast <int64_t *>(MRIdx1d (lStridesRank, lStridesDesc).data ());
348+ auto oGShapePtr = reinterpret_cast <int64_t *>(oGShapeUMR.data ());
349+ auto oOffsPtr =
350+ reinterpret_cast <int64_t *>(MRIdx1d (oOffsRank, oOffsDesc).data ());
351+ auto oShapePtr =
352+ reinterpret_cast <int64_t *>(MRIdx1d (oShapeRank, oShapeDesc).data ());
353+
316354 assert (std::accumulate (&gShapePtr [0 ], &gShapePtr [rank], 1 ,
317355 std::multiplies<int64_t >()) ==
318356 std::accumulate (&oGShapePtr[0 ], &oGShapePtr[oRank], 1 ,
@@ -392,12 +430,9 @@ void _idtr_reshape(int64_t rank, int64_t *gShapePtr, int dtype, void *lDataPtr,
392430 rszs.data (), roffs.data ());
393431}
394432
395- using MRIdx1d = Unranked1DMemRefType<uint64_t >;
396-
397433// / @brief repartition tensor
398434// / We assume tensor is partitioned along the first dimension (only) and
399435// / partitions are ordered by ranks
400- // / @param rank
401436// / @param gShapeRank
402437// / @param gShapeDesc
403438// / @param dtype
@@ -414,9 +449,9 @@ using MRIdx1d = Unranked1DMemRefType<uint64_t>;
414449// / @param szsDesc
415450// / @param outPtr
416451// / @param tc
417- void _idtr_repartition (int64_t rank, int64_t gShapeRank , void *gShapeDesc ,
418- int dtype, void *lDataPtr, int64_t lOffsRank,
419- void *lOffsDesc, int64_t lShapeRank, void *lShapeDesc,
452+ void _idtr_repartition (int64_t gShapeRank , void *gShapeDesc , int dtype ,
453+ void *lDataPtr, int64_t lOffsRank, void *lOffsDesc ,
454+ int64_t lShapeRank, void *lShapeDesc,
420455 int64_t lStridesRank, void *lStridesDesc,
421456 int64_t offsRank, void *offsDesc, int64_t szsRank,
422457 void *szsDesc, void *outPtr, Transceiver *tc) {
@@ -436,6 +471,7 @@ void _idtr_repartition(int64_t rank, int64_t gShapeRank, void *gShapeDesc,
436471 MRIdx1d offsMR (offsRank, offsDesc);
437472 MRIdx1d szsMR (szsRank, szsDesc);
438473
474+ int64_t rank = gShapeMR .size ();
439475 auto lShapePtr = reinterpret_cast <int64_t *>(lShapeMR.data ());
440476 auto lStridesPtr = reinterpret_cast <int64_t *>(lStridesMR.data ());
441477
@@ -520,10 +556,11 @@ void _idtr_repartition(int64_t rank, int64_t gShapeRank, void *gShapeDesc,
520556 // Finally communicate elements
521557 if (needsBufferize) {
522558 // create send buffer if strided
523- Buffer buff (totSSz * sizeof_dtype (ddpttype), 2 );
559+ Buffer tmpbuff;
560+ tmpbuff.resize (totSSz * sizeof_dtype (ddpttype));
524561 bufferize (lDataPtr, ddpttype, lShapePtr, lStridesPtr, tStarts.data (),
525- tSizes.data (), rank, N, buff .data ());
526- tc->alltoall (buff .data (), sszs.data (), soffs.data (), ddpttype, outPtr,
562+ tSizes.data (), rank, N, tmpbuff .data ());
563+ tc->alltoall (tmpbuff .data (), sszs.data (), soffs.data (), ddpttype, outPtr,
527564 rszs.data (), roffs.data ());
528565 } else {
529566 tc->alltoall (lDataPtr, sszs.data (), soffs.data (), ddpttype, outPtr,
0 commit comments