44 Intel Distributed Runtime for MLIR
55*/
66
7- #include < ddptensor/idtr.hpp>
8- // #include <ddptensor/jit/mlir.hpp>
97#include < ddptensor/DDPTensorImpl.hpp>
108#include < ddptensor/MPITransceiver.hpp>
9+ #include < ddptensor/MemRefType.hpp>
10+ #include < ddptensor/idtr.hpp>
1111
1212#include < imex/Dialect/PTensor/IR/PTensorDefs.h>
1313
@@ -392,24 +392,34 @@ void _idtr_reshape(int64_t rank, int64_t *gShapePtr, int dtype, void *lDataPtr,
392392 rszs.data (), roffs.data ());
393393}
394394
395+ using MRIdx1d = Unranked1DMemRefType<uint64_t >;
396+
395397// / @brief repartition tensor
396398// / We assume tensor is partitioned along the first dimension (only) and
397399// / partitions are ordered by ranks
398400// / @param rank
399- // / @param gShapePtr
401+ // / @param gShapeRank
402+ // / @param gShapeDesc
400403// / @param dtype
401404// / @param lDataPtr
402- // / @param lOffsPtr
403- // / @param lShapePtr
404- // / @param lStridesPtr
405- // / @param offsPtr
406- // / @param szsPtr
405+ // / @param lOffsRank
406+ // / @param lOffsDesc
407+ // / @param lShapeRank
408+ // / @param lShapeDesc
409+ // / @param lStridesRank
410+ // / @param lStridesDesc
411+ // / @param offsRank
412+ // / @param offsDesc
413+ // / @param szsRank
414+ // / @param szsDesc
407415// / @param outPtr
408416// / @param tc
409- void _idtr_repartition (int64_t rank, int64_t *gShapePtr , int dtype,
410- void *lDataPtr, int64_t *lOffsPtr, int64_t *lShapePtr,
411- int64_t *lStridesPtr, int64_t *offsPtr, int64_t *szsPtr,
412- void *outPtr, Transceiver *tc) {
417+ void _idtr_repartition (int64_t rank, int64_t gShapeRank , void *gShapeDesc ,
418+ int dtype, void *lDataPtr, int64_t lOffsRank,
419+ void *lOffsDesc, int64_t lShapeRank, void *lShapeDesc,
420+ int64_t lStridesRank, void *lStridesDesc,
421+ int64_t offsRank, void *offsDesc, int64_t szsRank,
422+ void *szsDesc, void *outPtr, Transceiver *tc) {
413423#ifdef NO_TRANSCEIVER
414424 initMPIRuntime ();
415425 tc = getTransceiver ();
@@ -418,14 +428,25 @@ void _idtr_repartition(int64_t rank, int64_t *gShapePtr, int dtype,
418428 auto me = tc->rank ();
419429 auto ddpttype = mlir2ddpt (static_cast <::imex::ptensor::DType>(dtype));
420430
431+ // Construct unranked memrefs for metadata
432+ MRIdx1d gShapeMR (gShapeRank , gShapeDesc );
433+ MRIdx1d lOffsMR (lOffsRank, lOffsDesc);
434+ MRIdx1d lShapeMR (lShapeRank, lShapeDesc);
435+ MRIdx1d lStridesMR (lStridesRank, lStridesDesc);
436+ MRIdx1d offsMR (offsRank, offsDesc);
437+ MRIdx1d szsMR (szsRank, szsDesc);
438+
439+ auto lShapePtr = reinterpret_cast <int64_t *>(lShapeMR.data ());
440+ auto lStridesPtr = reinterpret_cast <int64_t *>(lStridesMR.data ());
441+
421442 // First we allgather the requested target partitioning
422443
423444 auto myBOff = 2 * rank * me;
424445 ::std::vector<int64_t > buff (2 * rank * N);
425446 for (int64_t i = 0 ; i < rank; ++i) {
426- // assert(offsPtr[i] - lOffsPtr [i] + szsPtr[i] <= gShapePtr [i]);
427- buff[myBOff + i] = offsPtr [i];
428- buff[myBOff + i + rank] = szsPtr [i];
447+ // assert(offsPtr[i] - lOffs [i] + szsPtr[i] <= gShape [i]);
448+ buff[myBOff + i] = offsMR [i];
449+ buff[myBOff + i + rank] = szsMR [i];
429450 }
430451 ::std::vector<int > counts (N, rank * 2 );
431452 ::std::vector<int > dspl (N);
@@ -436,8 +457,8 @@ void _idtr_repartition(int64_t rank, int64_t *gShapePtr, int dtype,
436457
437458 // compute overlap of my local data with each requested part
438459
439- auto myOff = lOffsPtr [0 ];
440- auto mySz = lShapePtr [0 ];
460+ auto myOff = static_cast < int64_t >(lOffsMR [0 ]) ;
461+ auto mySz = static_cast < int64_t >(lShapeMR [0 ]) ;
441462 auto myEnd = myOff + mySz;
442463 auto myTileSz = std::accumulate (&lShapePtr[1 ], &lShapePtr[rank], 1 ,
443464 std::multiplies<int64_t >());
@@ -480,7 +501,7 @@ void _idtr_repartition(int64_t rank, int64_t *gShapePtr, int dtype,
480501 for (auto r = 1 ; r < rank; ++r) {
481502 tStarts[i * rank + r] = buff[2 * rank * i + r];
482503 tSizes[i * rank + r] = buff[2 * rank * i + rank + r];
483- // assert(tSizes[i*rank+r] <= lShapePtr [r]);
504+ // assert(tSizes[i*rank+r] <= lShapeMR [r]);
484505 }
485506 }
486507
0 commit comments