77#include < ddptensor/DDPTensorImpl.hpp>
88#include < ddptensor/MPITransceiver.hpp>
99#include < ddptensor/MemRefType.hpp>
10- #include < ddptensor/idtr.hpp>
1110
1211#include < imex/Dialect/PTensor/IR/PTensorDefs.h>
1312
1413#include < cassert>
1514#include < iostream>
1615#include < memory>
1716
17+ constexpr id_t UNKNOWN_GUID = -1 ;
18+
1819using container_type =
1920 std::unordered_map<id_type, std::unique_ptr<DDPTensorImpl>>;
2021
@@ -273,85 +274,52 @@ void bufferizeN(void *cptr, DTypeId dtype, const int64_t *sizes,
273274 });
274275}
275276
276- using MRIdx1d = Unranked1DMemRefType<uint64_t >;
277-
278- extern " C" {
279- // Elementwise inplace allreduce
280- void idtr_reduce_all (void *inout, DTypeId dtype, uint64_t N, ReduceOpId op) {
281- getTransceiver ()->reduce_all (inout, dtype, N, op);
282- }
277+ using MRIdx1d = Unranked1DMemRefType<int64_t >;
283278
284279// FIXME hard-coded for contiguous layout
285- void _idtr_reduce_all (void *data, int64_t sizesRank, int64_t *sizesDesc,
286- int64_t stridesRank, int64_t *stridesDesc, int dtype,
287- int op) {
288- MRIdx1d sizesMR (sizesRank, sizesDesc);
289- MRIdx1d stridesMR (stridesRank, stridesDesc);
290- auto sizes = reinterpret_cast <int64_t *>(sizesMR.data ());
291- auto strides = reinterpret_cast <int64_t *>(stridesMR.data ());
292- auto rank = sizesMR.size ();
293- assert (rank == 0 || (rank == 1 && strides[0 ] == 1 ));
294- idtr_reduce_all (data, mlir2ddpt (static_cast <::imex::ptensor::DType>(dtype)),
295- rank ? sizes[0 ] : 1 ,
296- mlir2ddpt (static_cast <imex::ptensor::ReduceOpId>(op)));
280+ template <typename T>
281+ void _idtr_reduce_all (int64_t dataRank, void *dataDescr, int op) {
282+ UnrankedMemRefType<T> data (dataRank, dataDescr);
283+ auto inout = data.data ();
284+ auto sizes = data.sizes ();
285+ auto strides = data.strides ();
286+ assert (dataRank == 0 || (dataRank == 1 && strides[0 ] == 1 ));
287+ getTransceiver ()->reduce_all (
288+ inout, DTYPE<T>::value, dataRank ? sizes[0 ] : 1 ,
289+ mlir2ddpt (static_cast <imex::ptensor::ReduceOpId>(op)));
297290}
298291
292+ extern " C" {
293+
294+ #define TYPED_REDUCEALL (_sfx, _typ ) \
295+ void _idtr_reduce_all_##_sfx(int64_t dataRank, void *dataDescr, int op) { \
296+ _idtr_reduce_all<_typ>(dataRank, dataDescr, op); \
297+ }
298+
299+ TYPED_REDUCEALL (f64 , double );
300+ TYPED_REDUCEALL (f32 , float );
301+ TYPED_REDUCEALL (i64 , int64_t );
302+ TYPED_REDUCEALL (i32 , int32_t );
303+ TYPED_REDUCEALL (i16 , int16_t );
304+ TYPED_REDUCEALL (i8 , int8_t );
305+ TYPED_REDUCEALL (i1, bool );
306+
307+ } // extern "C"
308+
299309// / @brief reshape tensor
300310// / We assume tensor is partitioned along the first dimension (only) and
301311// / partitions are ordered by ranks
302- // / @param gShapeRank
303- // / @param gShapeDesc
304- // / @param dtype
305- // / @param lDataPtr
306- // / @param lOffsRank
307- // / @param lOffsDesc
308- // / @param lShapeRank
309- // / @param lShapeDesc
310- // / @param lStridesRank
311- // / @param lStridesDesc
312- // / @param oGShapeRank
313- // / @param oGShapeDesc
314- // / @param oOffsRank
315- // / @param oOffsDesc
316- // / @param oShapeRank
317- // / @param oShapeDesc
318- // / @param outPtr
319- // / @param tc
320- void _idtr_reshape (int64_t gShapeRank , int64_t *gShapeDesc , int dtype,
321- void *lDataPtr, int64_t lOffsRank, int64_t *lOffsDesc,
322- int64_t lShapeRank, int64_t *lShapeDesc,
323- int64_t lStridesRank, int64_t *lStridesDesc,
324- int64_t oGShapeRank, int64_t *oGShapeDesc, int64_t oOffsRank,
325- int64_t *oOffsDesc, int64_t oShapeRank, int64_t *oShapeDesc,
326- void *outPtr, Transceiver *tc) {
312+ void _idtr_reshape (DTypeId ddpttype, int64_t lRank, int64_t *gShapePtr ,
313+ void *lDataPtr, int64_t *lShapePtr, int64_t *lStridesPtr,
314+ int64_t *lOffsPtr, int64_t oRank, int64_t *oGShapePtr,
315+ void *oDataPtr, int64_t *oShapePtr, int64_t *oOffsPtr,
316+ Transceiver *tc) {
327317#ifdef NO_TRANSCEIVER
328318 initMPIRuntime ();
329319 tc = getTransceiver ();
330320#endif
331321
332- assert (1 == gShapeRank && 1 == lOffsRank && 1 == lShapeRank &&
333- 1 == lStridesRank && 1 == oGShapeRank && 1 == oOffsRank &&
334- 1 == oShapeRank);
335-
336- MRIdx1d gShapeUMR (gShapeRank , gShapeDesc );
337- MRIdx1d oGShapeUMR (oGShapeRank, oGShapeDesc);
338- auto rank = gShapeUMR .size ();
339- auto oRank = oGShapeUMR.size ();
340-
341- auto gShapePtr = reinterpret_cast <int64_t *>(gShapeUMR .data ());
342- auto lOffsPtr =
343- reinterpret_cast <int64_t *>(MRIdx1d (lOffsRank, lOffsDesc).data ());
344- auto lShapePtr =
345- reinterpret_cast <int64_t *>(MRIdx1d (lShapeRank, lShapeDesc).data ());
346- auto lStridesPtr =
347- reinterpret_cast <int64_t *>(MRIdx1d (lStridesRank, lStridesDesc).data ());
348- auto oGShapePtr = reinterpret_cast <int64_t *>(oGShapeUMR.data ());
349- auto oOffsPtr =
350- reinterpret_cast <int64_t *>(MRIdx1d (oOffsRank, oOffsDesc).data ());
351- auto oShapePtr =
352- reinterpret_cast <int64_t *>(MRIdx1d (oShapeRank, oShapeDesc).data ());
353-
354- assert (std::accumulate (&gShapePtr [0 ], &gShapePtr [rank], 1 ,
322+ assert (std::accumulate (&gShapePtr [0 ], &gShapePtr [lRank], 1 ,
355323 std::multiplies<int64_t >()) ==
356324 std::accumulate (&oGShapePtr[0 ], &oGShapePtr[oRank], 1 ,
357325 std::multiplies<int64_t >()));
@@ -360,9 +328,8 @@ void _idtr_reshape(int64_t gShapeRank, int64_t *gShapeDesc, int dtype,
360328
361329 auto N = tc->nranks ();
362330 auto me = tc->rank ();
363- auto ddpttype = mlir2ddpt (static_cast <::imex::ptensor::DType>(dtype));
364331
365- int64_t cSz = std::accumulate (&lShapePtr[1 ], &lShapePtr[rank ], 1 ,
332+ int64_t cSz = std::accumulate (&lShapePtr[1 ], &lShapePtr[lRank ], 1 ,
366333 std::multiplies<int64_t >());
367334 int64_t mySz = cSz * lShapePtr[0 ];
368335 int64_t myOff = lOffsPtr[0 ] * cSz;
@@ -425,64 +392,76 @@ void _idtr_reshape(int64_t gShapeRank, int64_t *gShapeDesc, int dtype,
425392
426393 Buffer outbuff (totSSz * sizeof_dtype (ddpttype), 2 ); // FIXME debug value
427394 bufferizeN (lDataPtr, ddpttype, lShapePtr, lStridesPtr, lsOffs.data (),
428- lsEnds.data (), rank , N, outbuff.data ());
429- tc->alltoall (outbuff.data (), sszs.data (), soffs.data (), ddpttype, outPtr ,
395+ lsEnds.data (), lRank , N, outbuff.data ());
396+ tc->alltoall (outbuff.data (), sszs.data (), soffs.data (), ddpttype, oDataPtr ,
430397 rszs.data (), roffs.data ());
431398}
432399
433- // / @brief repartition tensor
400+ // / @brief reshape tensor
401+ template <typename T>
402+ void _idtr_reshape (int64_t gShapeRank , void *gShapeDescr , int64_t lOffsRank,
403+ void *lOffsDescr, int64_t lRank, void *lDescr,
404+ int64_t oGShapeRank, void *oGShapeDescr, int64_t oOffsRank,
405+ void *oOffsDescr, int64_t oRank, void *oDescr,
406+ Transceiver *tc) {
407+
408+ auto ddpttype = DTYPE<T>::value;
409+
410+ UnrankedMemRefType<T> lData (lRank, lDescr);
411+ UnrankedMemRefType<T> oData (oRank, oDescr);
412+
413+ _idtr_reshape (ddpttype, lRank, MRIdx1d (gShapeRank , gShapeDescr ).data (),
414+ lData.data (), lData.sizes (), lData.strides (),
415+ MRIdx1d (oOffsRank, oOffsDescr).data (), oRank,
416+ MRIdx1d (oGShapeRank, oGShapeDescr).data (), oData.data (),
417+ oData.sizes (), MRIdx1d (oOffsRank, oOffsDescr).data (), tc);
418+ }
419+
420+ extern " C" {
421+
422+ #define TYPED_RESHAPE (_sfx, _typ ) \
423+ void _idtr_reshape_##_sfx( \
424+ int64_t gShapeRank , void *gShapeDescr , int64_t lOffsRank, \
425+ void *lOffsDescr, int64_t rank, void *lDescr, int64_t oGShapeRank, \
426+ void *oGShapeDescr, int64_t oOffsRank, void *oOffsDescr, int64_t oRank, \
427+ void *oDescr, Transceiver *tc) { \
428+ _idtr_reshape<_typ>(gShapeRank , gShapeDescr , lOffsRank, lOffsDescr, rank, \
429+ lDescr, oGShapeRank, oGShapeDescr, oOffsRank, \
430+ oOffsDescr, oRank, oDescr, tc); \
431+ }
432+
433+ TYPED_RESHAPE (f64 , double );
434+ TYPED_RESHAPE (f32 , float );
435+ TYPED_RESHAPE (i64 , int64_t );
436+ TYPED_RESHAPE (i32 , int32_t );
437+ TYPED_RESHAPE (i16 , int16_t );
438+ TYPED_RESHAPE (i8 , int8_t );
439+ TYPED_RESHAPE (i1, bool );
440+
441+ } // extern "C"
442+
443+ // / @brief repartition tensor using generic and raw pointers
434444// / We assume tensor is partitioned along the first dimension (only) and
435445// / partitions are ordered by ranks
436- // / @param gShapeRank
437- // / @param gShapeDesc
438- // / @param dtype
439- // / @param lDataPtr
440- // / @param lOffsRank
441- // / @param lOffsDesc
442- // / @param lShapeRank
443- // / @param lShapeDesc
444- // / @param lStridesRank
445- // / @param lStridesDesc
446- // / @param offsRank
447- // / @param offsDesc
448- // / @param szsRank
449- // / @param szsDesc
450- // / @param outPtr
451- // / @param tc
452- void _idtr_repartition (int64_t gShapeRank , void *gShapeDesc , int dtype,
453- void *lDataPtr, int64_t lOffsRank, void *lOffsDesc,
454- int64_t lShapeRank, void *lShapeDesc,
455- int64_t lStridesRank, void *lStridesDesc,
456- int64_t offsRank, void *offsDesc, int64_t szsRank,
457- void *szsDesc, void *outPtr, Transceiver *tc) {
446+ void _idtr_repartition (DTypeId ddpttype, int64_t rank, void *lDataPtr,
447+ int64_t *lShapePtr, int64_t *lStridesPtr,
448+ int64_t *lOffsPtr, void *outPtr, int64_t *oShapePtr,
449+ int64_t *oOffsPtr, Transceiver *tc) {
450+
458451#ifdef NO_TRANSCEIVER
459452 initMPIRuntime ();
460453 tc = getTransceiver ();
461454#endif
462455 auto N = tc->nranks ();
463456 auto me = tc->rank ();
464- auto ddpttype = mlir2ddpt (static_cast <::imex::ptensor::DType>(dtype));
465-
466- // Construct unranked memrefs for metadata
467- MRIdx1d gShapeMR (gShapeRank , gShapeDesc );
468- MRIdx1d lOffsMR (lOffsRank, lOffsDesc);
469- MRIdx1d lShapeMR (lShapeRank, lShapeDesc);
470- MRIdx1d lStridesMR (lStridesRank, lStridesDesc);
471- MRIdx1d offsMR (offsRank, offsDesc);
472- MRIdx1d szsMR (szsRank, szsDesc);
473-
474- int64_t rank = gShapeMR .size ();
475- auto lShapePtr = reinterpret_cast <int64_t *>(lShapeMR.data ());
476- auto lStridesPtr = reinterpret_cast <int64_t *>(lStridesMR.data ());
477457
478458 // First we allgather the requested target partitioning
479459
480460 auto myBOff = 2 * rank * me;
481461 ::std::vector<int64_t > buff (2 * rank * N);
482462 for (int64_t i = 0 ; i < rank; ++i) {
483- // assert(offsPtr[i] - lOffs[i] + szsPtr[i] <= gShape[i]);
484- buff[myBOff + i] = offsMR[i];
485- buff[myBOff + i + rank] = szsMR[i];
463+ buff[myBOff + i] = oOffsPtr[i];
464+ buff[myBOff + i + rank] = oShapePtr[i];
486465 }
487466 ::std::vector<int > counts (N, rank * 2 );
488467 ::std::vector<int > dspl (N);
@@ -493,8 +472,8 @@ void _idtr_repartition(int64_t gShapeRank, void *gShapeDesc, int dtype,
493472
494473 // compute overlap of my local data with each requested part
495474
496- auto myOff = static_cast <int64_t >(lOffsMR [0 ]);
497- auto mySz = static_cast <int64_t >(lShapeMR [0 ]);
475+ auto myOff = static_cast <int64_t >(lOffsPtr [0 ]);
476+ auto mySz = static_cast <int64_t >(lShapePtr [0 ]);
498477 auto myEnd = myOff + mySz;
499478 auto myTileSz = std::accumulate (&lShapePtr[1 ], &lShapePtr[rank], 1 ,
500479 std::multiplies<int64_t >());
@@ -537,7 +516,7 @@ void _idtr_repartition(int64_t gShapeRank, void *gShapeDesc, int dtype,
537516 for (auto r = 1 ; r < rank; ++r) {
538517 tStarts[i * rank + r] = buff[2 * rank * i + r];
539518 tSizes[i * rank + r] = buff[2 * rank * i + rank + r];
540- // assert(tSizes[i*rank+r] <= lShapeMR [r]);
519+ // assert(tSizes[i*rank+r] <= lShapePtr [r]);
541520 }
542521 }
543522
@@ -568,6 +547,47 @@ void _idtr_repartition(int64_t gShapeRank, void *gShapeDesc, int dtype,
568547 }
569548}
570549
550+ // / @brief templated wrapper for typed function versions calling
551+ // / _idtr_repartition
552+ template <typename T>
553+ void _idtr_repartition (int64_t gShapeRank , void *gShapeDescr , int64_t lOffsRank,
554+ void *lOffsDescr, int64_t lRank, void *lDescr,
555+ int64_t oOffsRank, void *oOffsDescr, int64_t oRank,
556+ void *oDescr, Transceiver *tc) {
557+
558+ auto ddpttype = DTYPE<T>::value;
559+
560+ // Construct unranked memrefs for metadata
561+ UnrankedMemRefType<T> lData (lRank, lDescr);
562+ UnrankedMemRefType<T> oData (oRank, oDescr);
563+
564+ _idtr_repartition (ddpttype, lRank, lData.data (), lData.sizes (),
565+ lData.strides (), MRIdx1d (lOffsRank, lOffsDescr).data (),
566+ oData.data (), oData.sizes (),
567+ MRIdx1d (oOffsRank, oOffsDescr).data (), tc);
568+ }
569+
570+ extern " C" {
571+ #define TYPED_REPARTITON (_sfx, _typ ) \
572+ void _idtr_repartition_##_sfx( \
573+ int64_t gShapeRank , void *gShapeDescr , int64_t lOffsRank, \
574+ void *lOffsDescr, int64_t rank, void *lDescr, int64_t oOffsRank, \
575+ void *oOffsDescr, int64_t oRank, void *oDescr, Transceiver *tc) { \
576+ _idtr_repartition<_typ>(gShapeRank , gShapeDescr , lOffsRank, lOffsDescr, \
577+ rank, lDescr, oOffsRank, oOffsDescr, oRank, \
578+ oDescr, tc); \
579+ }
580+
581+ TYPED_REPARTITON (f64 , double );
582+ TYPED_REPARTITON (f32 , float );
583+ TYPED_REPARTITON (i64 , int64_t );
584+ TYPED_REPARTITON (i32 , int32_t );
585+ TYPED_REPARTITON (i16 , int16_t );
586+ TYPED_REPARTITON (i8 , int8_t );
587+ TYPED_REPARTITON (i1, bool );
588+
589+ } // extern "C"
590+
571591// debug helper
572592void _idtr_extractslice (int64_t *slcOffs, int64_t *slcSizes,
573593 int64_t *slcStrides, int64_t *tOffs, int64_t *tSizes,
@@ -595,5 +615,6 @@ void _idtr_extractslice(int64_t *slcOffs, int64_t *slcSizes,
595615 << std::endl;
596616}
597617
618+ extern " C" {
598619void _debugFunc () { std::cerr << " _debugfunc\n " ; }
599620} // extern "C"
0 commit comments