Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 96d6adc

Browse files
authored
accepting unranked memrefs in type-specialized runtime callbacks (#20)
Type-specialized runtime callbacks now accept unranked memrefs instead of raw pointers.
1 parent 50082ce commit 96d6adc

File tree

6 files changed

+168
-156
lines changed

6 files changed

+168
-156
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,10 @@ jobs:
126126
run: |
127127
. $GITHUB_WORKSPACE/third_party/install/miniconda/etc/profile.d/conda.sh
128128
. $GITHUB_WORKSPACE/third_party/install/miniconda/bin/activate ddpt
129-
export DDPT_IDTR_SO=build/lib.linux-x86_64-cpython-*/ddptensor/libidtr.so
129+
DDPT_ROOT=`pip show -f ddptensor | grep Location | awk '{print $2}'`
130+
export DDPT_IDTR_SO=${DDPT_ROOT}/ddptensor/libidtr.so
130131
export DDPT_CRUNNER_SO="$GITHUB_WORKSPACE"/third_party/install/llvm-mlir/lib/libmlir_c_runner_utils.so
131132
pytest test
132-
# DDPT_FORCE_DIST=1 pytest test
133+
DDPT_FORCE_DIST=1 pytest test
133134
- run: |
134135
echo "This job's status is ${{ job.status }}."

imex_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
524df1c49aed52259f9ef8cea018c123b0bcada3
1+
399d083586cac8b2cb71f7a8da514ac2d816aac4

src/idtr.cpp

Lines changed: 135 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
#include <ddptensor/DDPTensorImpl.hpp>
88
#include <ddptensor/MPITransceiver.hpp>
99
#include <ddptensor/MemRefType.hpp>
10-
#include <ddptensor/idtr.hpp>
1110

1211
#include <imex/Dialect/PTensor/IR/PTensorDefs.h>
1312

1413
#include <cassert>
1514
#include <iostream>
1615
#include <memory>
1716

17+
constexpr id_t UNKNOWN_GUID = -1;
18+
1819
using container_type =
1920
std::unordered_map<id_type, std::unique_ptr<DDPTensorImpl>>;
2021

@@ -273,85 +274,52 @@ void bufferizeN(void *cptr, DTypeId dtype, const int64_t *sizes,
273274
});
274275
}
275276

276-
using MRIdx1d = Unranked1DMemRefType<uint64_t>;
277-
278-
extern "C" {
279-
// Elementwise inplace allreduce
280-
void idtr_reduce_all(void *inout, DTypeId dtype, uint64_t N, ReduceOpId op) {
281-
getTransceiver()->reduce_all(inout, dtype, N, op);
282-
}
277+
using MRIdx1d = Unranked1DMemRefType<int64_t>;
283278

284279
// FIXME hard-coded for contiguous layout
285-
void _idtr_reduce_all(void *data, int64_t sizesRank, int64_t *sizesDesc,
286-
int64_t stridesRank, int64_t *stridesDesc, int dtype,
287-
int op) {
288-
MRIdx1d sizesMR(sizesRank, sizesDesc);
289-
MRIdx1d stridesMR(stridesRank, stridesDesc);
290-
auto sizes = reinterpret_cast<int64_t *>(sizesMR.data());
291-
auto strides = reinterpret_cast<int64_t *>(stridesMR.data());
292-
auto rank = sizesMR.size();
293-
assert(rank == 0 || (rank == 1 && strides[0] == 1));
294-
idtr_reduce_all(data, mlir2ddpt(static_cast<::imex::ptensor::DType>(dtype)),
295-
rank ? sizes[0] : 1,
296-
mlir2ddpt(static_cast<imex::ptensor::ReduceOpId>(op)));
280+
template <typename T>
281+
void _idtr_reduce_all(int64_t dataRank, void *dataDescr, int op) {
282+
UnrankedMemRefType<T> data(dataRank, dataDescr);
283+
auto inout = data.data();
284+
auto sizes = data.sizes();
285+
auto strides = data.strides();
286+
assert(dataRank == 0 || (dataRank == 1 && strides[0] == 1));
287+
getTransceiver()->reduce_all(
288+
inout, DTYPE<T>::value, dataRank ? sizes[0] : 1,
289+
mlir2ddpt(static_cast<imex::ptensor::ReduceOpId>(op)));
297290
}
298291

292+
extern "C" {
293+
294+
#define TYPED_REDUCEALL(_sfx, _typ) \
295+
void _idtr_reduce_all_##_sfx(int64_t dataRank, void *dataDescr, int op) { \
296+
_idtr_reduce_all<_typ>(dataRank, dataDescr, op); \
297+
}
298+
299+
TYPED_REDUCEALL(f64, double);
300+
TYPED_REDUCEALL(f32, float);
301+
TYPED_REDUCEALL(i64, int64_t);
302+
TYPED_REDUCEALL(i32, int32_t);
303+
TYPED_REDUCEALL(i16, int16_t);
304+
TYPED_REDUCEALL(i8, int8_t);
305+
TYPED_REDUCEALL(i1, bool);
306+
307+
} // extern "C"
308+
299309
/// @brief reshape tensor
300310
/// We assume tensor is partitioned along the first dimension (only) and
301311
/// partitions are ordered by ranks
302-
/// @param gShapeRank
303-
/// @param gShapeDesc
304-
/// @param dtype
305-
/// @param lDataPtr
306-
/// @param lOffsRank
307-
/// @param lOffsDesc
308-
/// @param lShapeRank
309-
/// @param lShapeDesc
310-
/// @param lStridesRank
311-
/// @param lStridesDesc
312-
/// @param oGShapeRank
313-
/// @param oGShapeDesc
314-
/// @param oOffsRank
315-
/// @param oOffsDesc
316-
/// @param oShapeRank
317-
/// @param oShapeDesc
318-
/// @param outPtr
319-
/// @param tc
320-
void _idtr_reshape(int64_t gShapeRank, int64_t *gShapeDesc, int dtype,
321-
void *lDataPtr, int64_t lOffsRank, int64_t *lOffsDesc,
322-
int64_t lShapeRank, int64_t *lShapeDesc,
323-
int64_t lStridesRank, int64_t *lStridesDesc,
324-
int64_t oGShapeRank, int64_t *oGShapeDesc, int64_t oOffsRank,
325-
int64_t *oOffsDesc, int64_t oShapeRank, int64_t *oShapeDesc,
326-
void *outPtr, Transceiver *tc) {
312+
void _idtr_reshape(DTypeId ddpttype, int64_t lRank, int64_t *gShapePtr,
313+
void *lDataPtr, int64_t *lShapePtr, int64_t *lStridesPtr,
314+
int64_t *lOffsPtr, int64_t oRank, int64_t *oGShapePtr,
315+
void *oDataPtr, int64_t *oShapePtr, int64_t *oOffsPtr,
316+
Transceiver *tc) {
327317
#ifdef NO_TRANSCEIVER
328318
initMPIRuntime();
329319
tc = getTransceiver();
330320
#endif
331321

332-
assert(1 == gShapeRank && 1 == lOffsRank && 1 == lShapeRank &&
333-
1 == lStridesRank && 1 == oGShapeRank && 1 == oOffsRank &&
334-
1 == oShapeRank);
335-
336-
MRIdx1d gShapeUMR(gShapeRank, gShapeDesc);
337-
MRIdx1d oGShapeUMR(oGShapeRank, oGShapeDesc);
338-
auto rank = gShapeUMR.size();
339-
auto oRank = oGShapeUMR.size();
340-
341-
auto gShapePtr = reinterpret_cast<int64_t *>(gShapeUMR.data());
342-
auto lOffsPtr =
343-
reinterpret_cast<int64_t *>(MRIdx1d(lOffsRank, lOffsDesc).data());
344-
auto lShapePtr =
345-
reinterpret_cast<int64_t *>(MRIdx1d(lShapeRank, lShapeDesc).data());
346-
auto lStridesPtr =
347-
reinterpret_cast<int64_t *>(MRIdx1d(lStridesRank, lStridesDesc).data());
348-
auto oGShapePtr = reinterpret_cast<int64_t *>(oGShapeUMR.data());
349-
auto oOffsPtr =
350-
reinterpret_cast<int64_t *>(MRIdx1d(oOffsRank, oOffsDesc).data());
351-
auto oShapePtr =
352-
reinterpret_cast<int64_t *>(MRIdx1d(oShapeRank, oShapeDesc).data());
353-
354-
assert(std::accumulate(&gShapePtr[0], &gShapePtr[rank], 1,
322+
assert(std::accumulate(&gShapePtr[0], &gShapePtr[lRank], 1,
355323
std::multiplies<int64_t>()) ==
356324
std::accumulate(&oGShapePtr[0], &oGShapePtr[oRank], 1,
357325
std::multiplies<int64_t>()));
@@ -360,9 +328,8 @@ void _idtr_reshape(int64_t gShapeRank, int64_t *gShapeDesc, int dtype,
360328

361329
auto N = tc->nranks();
362330
auto me = tc->rank();
363-
auto ddpttype = mlir2ddpt(static_cast<::imex::ptensor::DType>(dtype));
364331

365-
int64_t cSz = std::accumulate(&lShapePtr[1], &lShapePtr[rank], 1,
332+
int64_t cSz = std::accumulate(&lShapePtr[1], &lShapePtr[lRank], 1,
366333
std::multiplies<int64_t>());
367334
int64_t mySz = cSz * lShapePtr[0];
368335
int64_t myOff = lOffsPtr[0] * cSz;
@@ -425,64 +392,76 @@ void _idtr_reshape(int64_t gShapeRank, int64_t *gShapeDesc, int dtype,
425392

426393
Buffer outbuff(totSSz * sizeof_dtype(ddpttype), 2); // FIXME debug value
427394
bufferizeN(lDataPtr, ddpttype, lShapePtr, lStridesPtr, lsOffs.data(),
428-
lsEnds.data(), rank, N, outbuff.data());
429-
tc->alltoall(outbuff.data(), sszs.data(), soffs.data(), ddpttype, outPtr,
395+
lsEnds.data(), lRank, N, outbuff.data());
396+
tc->alltoall(outbuff.data(), sszs.data(), soffs.data(), ddpttype, oDataPtr,
430397
rszs.data(), roffs.data());
431398
}
432399

433-
/// @brief repartition tensor
400+
/// @brief reshape tensor
401+
template <typename T>
402+
void _idtr_reshape(int64_t gShapeRank, void *gShapeDescr, int64_t lOffsRank,
403+
void *lOffsDescr, int64_t lRank, void *lDescr,
404+
int64_t oGShapeRank, void *oGShapeDescr, int64_t oOffsRank,
405+
void *oOffsDescr, int64_t oRank, void *oDescr,
406+
Transceiver *tc) {
407+
408+
auto ddpttype = DTYPE<T>::value;
409+
410+
UnrankedMemRefType<T> lData(lRank, lDescr);
411+
UnrankedMemRefType<T> oData(oRank, oDescr);
412+
413+
_idtr_reshape(ddpttype, lRank, MRIdx1d(gShapeRank, gShapeDescr).data(),
414+
lData.data(), lData.sizes(), lData.strides(),
415+
MRIdx1d(oOffsRank, oOffsDescr).data(), oRank,
416+
MRIdx1d(oGShapeRank, oGShapeDescr).data(), oData.data(),
417+
oData.sizes(), MRIdx1d(oOffsRank, oOffsDescr).data(), tc);
418+
}
419+
420+
extern "C" {
421+
422+
#define TYPED_RESHAPE(_sfx, _typ) \
423+
void _idtr_reshape_##_sfx( \
424+
int64_t gShapeRank, void *gShapeDescr, int64_t lOffsRank, \
425+
void *lOffsDescr, int64_t rank, void *lDescr, int64_t oGShapeRank, \
426+
void *oGShapeDescr, int64_t oOffsRank, void *oOffsDescr, int64_t oRank, \
427+
void *oDescr, Transceiver *tc) { \
428+
_idtr_reshape<_typ>(gShapeRank, gShapeDescr, lOffsRank, lOffsDescr, rank, \
429+
lDescr, oGShapeRank, oGShapeDescr, oOffsRank, \
430+
oOffsDescr, oRank, oDescr, tc); \
431+
}
432+
433+
TYPED_RESHAPE(f64, double);
434+
TYPED_RESHAPE(f32, float);
435+
TYPED_RESHAPE(i64, int64_t);
436+
TYPED_RESHAPE(i32, int32_t);
437+
TYPED_RESHAPE(i16, int16_t);
438+
TYPED_RESHAPE(i8, int8_t);
439+
TYPED_RESHAPE(i1, bool);
440+
441+
} // extern "C"
442+
443+
/// @brief repartition tensor using generic and raw pointers
434444
/// We assume tensor is partitioned along the first dimension (only) and
435445
/// partitions are ordered by ranks
436-
/// @param gShapeRank
437-
/// @param gShapeDesc
438-
/// @param dtype
439-
/// @param lDataPtr
440-
/// @param lOffsRank
441-
/// @param lOffsDesc
442-
/// @param lShapeRank
443-
/// @param lShapeDesc
444-
/// @param lStridesRank
445-
/// @param lStridesDesc
446-
/// @param offsRank
447-
/// @param offsDesc
448-
/// @param szsRank
449-
/// @param szsDesc
450-
/// @param outPtr
451-
/// @param tc
452-
void _idtr_repartition(int64_t gShapeRank, void *gShapeDesc, int dtype,
453-
void *lDataPtr, int64_t lOffsRank, void *lOffsDesc,
454-
int64_t lShapeRank, void *lShapeDesc,
455-
int64_t lStridesRank, void *lStridesDesc,
456-
int64_t offsRank, void *offsDesc, int64_t szsRank,
457-
void *szsDesc, void *outPtr, Transceiver *tc) {
446+
void _idtr_repartition(DTypeId ddpttype, int64_t rank, void *lDataPtr,
447+
int64_t *lShapePtr, int64_t *lStridesPtr,
448+
int64_t *lOffsPtr, void *outPtr, int64_t *oShapePtr,
449+
int64_t *oOffsPtr, Transceiver *tc) {
450+
458451
#ifdef NO_TRANSCEIVER
459452
initMPIRuntime();
460453
tc = getTransceiver();
461454
#endif
462455
auto N = tc->nranks();
463456
auto me = tc->rank();
464-
auto ddpttype = mlir2ddpt(static_cast<::imex::ptensor::DType>(dtype));
465-
466-
// Construct unranked memrefs for metadata
467-
MRIdx1d gShapeMR(gShapeRank, gShapeDesc);
468-
MRIdx1d lOffsMR(lOffsRank, lOffsDesc);
469-
MRIdx1d lShapeMR(lShapeRank, lShapeDesc);
470-
MRIdx1d lStridesMR(lStridesRank, lStridesDesc);
471-
MRIdx1d offsMR(offsRank, offsDesc);
472-
MRIdx1d szsMR(szsRank, szsDesc);
473-
474-
int64_t rank = gShapeMR.size();
475-
auto lShapePtr = reinterpret_cast<int64_t *>(lShapeMR.data());
476-
auto lStridesPtr = reinterpret_cast<int64_t *>(lStridesMR.data());
477457

478458
// First we allgather the requested target partitioning
479459

480460
auto myBOff = 2 * rank * me;
481461
::std::vector<int64_t> buff(2 * rank * N);
482462
for (int64_t i = 0; i < rank; ++i) {
483-
// assert(offsPtr[i] - lOffs[i] + szsPtr[i] <= gShape[i]);
484-
buff[myBOff + i] = offsMR[i];
485-
buff[myBOff + i + rank] = szsMR[i];
463+
buff[myBOff + i] = oOffsPtr[i];
464+
buff[myBOff + i + rank] = oShapePtr[i];
486465
}
487466
::std::vector<int> counts(N, rank * 2);
488467
::std::vector<int> dspl(N);
@@ -493,8 +472,8 @@ void _idtr_repartition(int64_t gShapeRank, void *gShapeDesc, int dtype,
493472

494473
// compute overlap of my local data with each requested part
495474

496-
auto myOff = static_cast<int64_t>(lOffsMR[0]);
497-
auto mySz = static_cast<int64_t>(lShapeMR[0]);
475+
auto myOff = static_cast<int64_t>(lOffsPtr[0]);
476+
auto mySz = static_cast<int64_t>(lShapePtr[0]);
498477
auto myEnd = myOff + mySz;
499478
auto myTileSz = std::accumulate(&lShapePtr[1], &lShapePtr[rank], 1,
500479
std::multiplies<int64_t>());
@@ -537,7 +516,7 @@ void _idtr_repartition(int64_t gShapeRank, void *gShapeDesc, int dtype,
537516
for (auto r = 1; r < rank; ++r) {
538517
tStarts[i * rank + r] = buff[2 * rank * i + r];
539518
tSizes[i * rank + r] = buff[2 * rank * i + rank + r];
540-
// assert(tSizes[i*rank+r] <= lShapeMR[r]);
519+
// assert(tSizes[i*rank+r] <= lShapePtr[r]);
541520
}
542521
}
543522

@@ -568,6 +547,47 @@ void _idtr_repartition(int64_t gShapeRank, void *gShapeDesc, int dtype,
568547
}
569548
}
570549

550+
/// @brief templated wrapper for typed function versions calling
551+
/// _idtr_repartition
552+
template <typename T>
553+
void _idtr_repartition(int64_t gShapeRank, void *gShapeDescr, int64_t lOffsRank,
554+
void *lOffsDescr, int64_t lRank, void *lDescr,
555+
int64_t oOffsRank, void *oOffsDescr, int64_t oRank,
556+
void *oDescr, Transceiver *tc) {
557+
558+
auto ddpttype = DTYPE<T>::value;
559+
560+
// Construct unranked memrefs for metadata
561+
UnrankedMemRefType<T> lData(lRank, lDescr);
562+
UnrankedMemRefType<T> oData(oRank, oDescr);
563+
564+
_idtr_repartition(ddpttype, lRank, lData.data(), lData.sizes(),
565+
lData.strides(), MRIdx1d(lOffsRank, lOffsDescr).data(),
566+
oData.data(), oData.sizes(),
567+
MRIdx1d(oOffsRank, oOffsDescr).data(), tc);
568+
}
569+
570+
extern "C" {
571+
#define TYPED_REPARTITON(_sfx, _typ) \
572+
void _idtr_repartition_##_sfx( \
573+
int64_t gShapeRank, void *gShapeDescr, int64_t lOffsRank, \
574+
void *lOffsDescr, int64_t rank, void *lDescr, int64_t oOffsRank, \
575+
void *oOffsDescr, int64_t oRank, void *oDescr, Transceiver *tc) { \
576+
_idtr_repartition<_typ>(gShapeRank, gShapeDescr, lOffsRank, lOffsDescr, \
577+
rank, lDescr, oOffsRank, oOffsDescr, oRank, \
578+
oDescr, tc); \
579+
}
580+
581+
TYPED_REPARTITON(f64, double);
582+
TYPED_REPARTITON(f32, float);
583+
TYPED_REPARTITON(i64, int64_t);
584+
TYPED_REPARTITON(i32, int32_t);
585+
TYPED_REPARTITON(i16, int16_t);
586+
TYPED_REPARTITON(i8, int8_t);
587+
TYPED_REPARTITON(i1, bool);
588+
589+
} // extern "C"
590+
571591
// debug helper
572592
void _idtr_extractslice(int64_t *slcOffs, int64_t *slcSizes,
573593
int64_t *slcStrides, int64_t *tOffs, int64_t *tSizes,
@@ -595,5 +615,6 @@ void _idtr_extractslice(int64_t *slcOffs, int64_t *slcSizes,
595615
<< std::endl;
596616
}
597617

618+
extern "C" {
598619
void _debugFunc() { std::cerr << "_debugfunc\n"; }
599620
} // extern "C"

src/include/ddptensor/MemRefType.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,23 @@ template <typename T, size_t N> struct MemRefDescriptor {
1111
intptr_t strides[N] = {0};
1212
};
1313

14+
// Use with care.
15+
template <typename T> class UnrankedMemRefType {
16+
int64_t _rank;
17+
intptr_t *_descriptor;
18+
19+
public:
20+
UnrankedMemRefType(int64_t rank, void *p)
21+
: _rank(rank), _descriptor(reinterpret_cast<intptr_t *>(p)){};
22+
23+
T *data() { return reinterpret_cast<T *>(_descriptor[1]); };
24+
int64_t rank() const { return _rank; }
25+
int64_t *sizes() { return reinterpret_cast<int64_t *>(&_descriptor[3]); };
26+
int64_t *strides() {
27+
return reinterpret_cast<int64_t *>(&_descriptor[3 + _rank]);
28+
};
29+
};
30+
1431
template <typename T> struct Unranked1DMemRefType {
1532
MemRefDescriptor<T, 1> *descriptor;
1633

0 commit comments

Comments
 (0)