Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 50082ce

Browse files
authored
adjusting to new callback signature for 1dmemrefs (#19)
* adjusting to new callback signature for 1dmemrefs; fixing gathering 0d array * fixing handling of temporary base tensors
1 parent 7b33242 commit 50082ce

File tree

8 files changed

+109
-60
lines changed

8 files changed

+109
-60
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ set(CMAKE_C_EXTENSIONS OFF)
3939
set(CMAKE_CXX_EXTENSIONS OFF)
4040
set(CMAKE_CXX_STANDARD_REQUIRED ON)
4141

42+
# Expected IMEX SHA
43+
file(STRINGS ${CMAKE_CURRENT_SOURCE_DIR}/imex_version.txt EXPECTED_IMEX_SHA)
44+
message(STATUS "Expected IMEX sha: \"${EXPECTED_IMEX_SHA}\"")
45+
4246
# Common installation directories
4347
#include(GNUInstallDirs)
4448

imex_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
94418f99f4b58eb57cc661057956eb36e2fce66b
1+
524df1c49aed52259f9ef8cea018c123b0bcada3

src/CollComm.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
void bufferize(DDPTensorImpl::ptr_type a_ptr, void *outPtr) {
66
dispatch(a_ptr->dtype(), a_ptr->data(), [&a_ptr, outPtr](auto *ptr) {
77
auto buff = static_cast<decltype(ptr)>(outPtr);
8-
9-
forall(0, ptr, a_ptr->local_shape(), a_ptr->local_strides(), a_ptr->ndims(),
10-
[&buff](const auto *in) {
11-
*buff = *in;
12-
++buff;
13-
});
8+
auto shp = a_ptr->local_shape();
9+
if (shp) {
10+
forall(0, ptr, shp, a_ptr->local_strides(), a_ptr->ndims(),
11+
[&buff](const auto *in) {
12+
*buff = *in;
13+
++buff;
14+
});
15+
} else {
16+
buff[0] = ptr[0];
17+
}
1418
});
1519
}
1620

@@ -20,7 +24,7 @@ void gather_tensor(DDPTensorImpl::ptr_type a_ptr, rank_type root,
2024
void *outPtr) {
2125
auto trscvr = a_ptr->transceiver();
2226

23-
if (!trscvr) {
27+
if (!trscvr || a_ptr->owner() == REPLICATED) {
2428
bufferize(a_ptr, outPtr);
2529
return;
2630
}

src/ManipOp.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,21 @@ struct DeferredReshape : public Deferred {
4444
auto op =
4545
builder.create<::imex::ptensor::ReshapeOp>(loc, outTyp, av, shp, copyA);
4646

47-
auto future_a = Registry::get(_a);
48-
4947
dm.addVal(this->guid(), op,
50-
[this, future_a](Transceiver *transceiver, uint64_t rank,
51-
void *allocated, void *aligned, intptr_t offset,
52-
const intptr_t *sizes, const intptr_t *strides,
53-
int64_t *gs_allocated, int64_t *gs_aligned,
54-
uint64_t *lo_allocated, uint64_t *lo_aligned,
55-
uint64_t balanced) {
48+
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
49+
void *aligned, intptr_t offset, const intptr_t *sizes,
50+
const intptr_t *strides, int64_t *gs_allocated,
51+
int64_t *gs_aligned, uint64_t *lo_allocated,
52+
uint64_t *lo_aligned, uint64_t balanced) {
5653
auto t =
5754
mk_tnsr(transceiver, _dtype, rank, allocated, aligned,
5855
offset, sizes, strides, gs_allocated, gs_aligned,
5956
lo_allocated, lo_aligned, balanced);
6057
if (_copy != COPY_ALWAYS) {
6158
assert(!"copy-free reshape not supported");
62-
t->set_base(future_a.get());
59+
if (Registry::has(_a)) {
60+
t->set_base(Registry::get(_a).get());
61+
} // else _a is a temporary and was dropped
6362
}
6463
this->set_value(std::move(t));
6564
});

src/SetGetItem.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,12 @@ struct DeferredGather
113113
py::object res;
114114
if (!sendonly || !trscvr) {
115115
auto tmp = a_ptr->shape();
116-
res = dispatch<mk_array>(a_ptr->dtype(),
117-
std::vector<ssize_t>(tmp, &tmp[a_ptr->ndims()]),
118-
outPtr);
116+
std::vector<ssize_t> tmpv(tmp, &tmp[a_ptr->ndims()]);
117+
// numpy treats 0d arrays as empty arrays, not as a scalar as we do
118+
if (tmpv.empty()) {
119+
tmpv.emplace_back(1);
120+
}
121+
res = dispatch<mk_array>(a_ptr->dtype(), std::move(tmpv), outPtr);
119122
}
120123

121124
gather_tensor(a_ptr, _root, outPtr);
@@ -293,19 +296,19 @@ struct DeferredGetItem : public Deferred {
293296
auto res = builder.create<::imex::ptensor::SubviewOp>(
294297
loc, outTyp, av, offsV, sizesV, stridesV);
295298

296-
auto future_a = Registry::get(_a);
297-
298299
dm.addVal(
299300
this->guid(), res,
300-
[this, dtype, future_a](
301-
Transceiver *transceiver, uint64_t rank, void *allocated,
302-
void *aligned, intptr_t offset, const intptr_t *sizes,
303-
const intptr_t *strides, int64_t *gs_allocated, int64_t *gs_aligned,
304-
uint64_t *lo_allocated, uint64_t *lo_aligned, uint64_t balanced) {
301+
[this, dtype](Transceiver *transceiver, uint64_t rank, void *allocated,
302+
void *aligned, intptr_t offset, const intptr_t *sizes,
303+
const intptr_t *strides, int64_t *gs_allocated,
304+
int64_t *gs_aligned, uint64_t *lo_allocated,
305+
uint64_t *lo_aligned, uint64_t balanced) {
305306
auto t = mk_tnsr(transceiver, dtype, rank, allocated, aligned, offset,
306307
sizes, strides, gs_allocated, gs_aligned,
307308
lo_allocated, lo_aligned, balanced);
308-
t->set_base(future_a.get());
309+
if (Registry::has(_a)) {
310+
t->set_base(Registry::get(_a).get());
311+
} // else _a is a temporary and was dropped
309312
this->set_value(std::move(t));
310313
});
311314
return false;

src/idtr.cpp

Lines changed: 63 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -273,46 +273,84 @@ void bufferizeN(void *cptr, DTypeId dtype, const int64_t *sizes,
273273
});
274274
}
275275

276+
using MRIdx1d = Unranked1DMemRefType<uint64_t>;
277+
276278
extern "C" {
277279
// Elementwise inplace allreduce
278280
void idtr_reduce_all(void *inout, DTypeId dtype, uint64_t N, ReduceOpId op) {
279281
getTransceiver()->reduce_all(inout, dtype, N, op);
280282
}
281283

282284
// FIXME hard-coded for contiguous layout
283-
void _idtr_reduce_all(uint64_t rank, void *data, const int64_t *sizes,
284-
const int64_t *strides, int dtype, int op) {
285-
assert(rank == 0 || strides[rank - 1] == 1);
285+
void _idtr_reduce_all(void *data, int64_t sizesRank, int64_t *sizesDesc,
286+
int64_t stridesRank, int64_t *stridesDesc, int dtype,
287+
int op) {
288+
MRIdx1d sizesMR(sizesRank, sizesDesc);
289+
MRIdx1d stridesMR(stridesRank, stridesDesc);
290+
auto sizes = reinterpret_cast<int64_t *>(sizesMR.data());
291+
auto strides = reinterpret_cast<int64_t *>(stridesMR.data());
292+
auto rank = sizesMR.size();
293+
assert(rank == 0 || (rank == 1 && strides[0] == 1));
286294
idtr_reduce_all(data, mlir2ddpt(static_cast<::imex::ptensor::DType>(dtype)),
287-
rank ? rank : 1,
295+
rank ? sizes[0] : 1,
288296
mlir2ddpt(static_cast<imex::ptensor::ReduceOpId>(op)));
289297
}
290298

291299
/// @brief reshape tensor
292300
/// We assume the tensor is partitioned along the first dimension (only) and
293301
/// partitions are ordered by ranks
294-
/// @param rank
295-
/// @param gShapePtr
302+
/// @param gShapeRank
303+
/// @param gShapeDesc
296304
/// @param dtype
297305
/// @param lDataPtr
298-
/// @param lOffsPtr
299-
/// @param lShapePtr
300-
/// @param lStridesPtr
301-
/// @param oRank
302-
/// @param oGShapePtr
303-
/// @param oOffsPtr
304-
/// @param oShapePtr
306+
/// @param lOffsRank
307+
/// @param lOffsDesc
308+
/// @param lShapeRank
309+
/// @param lShapeDesc
310+
/// @param lStridesRank
311+
/// @param lStridesDesc
312+
/// @param oGShapeRank
313+
/// @param oGShapeDesc
314+
/// @param oOffsRank
315+
/// @param oOffsDesc
316+
/// @param oShapeRank
317+
/// @param oShapeDesc
305318
/// @param outPtr
306319
/// @param tc
307-
void _idtr_reshape(int64_t rank, int64_t *gShapePtr, int dtype, void *lDataPtr,
308-
int64_t *lOffsPtr, int64_t *lShapePtr, int64_t *lStridesPtr,
309-
int64_t oRank, int64_t *oGShapePtr, int64_t *oOffsPtr,
310-
int64_t *oShapePtr, void *outPtr, Transceiver *tc) {
320+
void _idtr_reshape(int64_t gShapeRank, int64_t *gShapeDesc, int dtype,
321+
void *lDataPtr, int64_t lOffsRank, int64_t *lOffsDesc,
322+
int64_t lShapeRank, int64_t *lShapeDesc,
323+
int64_t lStridesRank, int64_t *lStridesDesc,
324+
int64_t oGShapeRank, int64_t *oGShapeDesc, int64_t oOffsRank,
325+
int64_t *oOffsDesc, int64_t oShapeRank, int64_t *oShapeDesc,
326+
void *outPtr, Transceiver *tc) {
311327
#ifdef NO_TRANSCEIVER
312328
initMPIRuntime();
313329
tc = getTransceiver();
314330
#endif
315331

332+
assert(1 == gShapeRank && 1 == lOffsRank && 1 == lShapeRank &&
333+
1 == lStridesRank && 1 == oGShapeRank && 1 == oOffsRank &&
334+
1 == oShapeRank);
335+
336+
MRIdx1d gShapeUMR(gShapeRank, gShapeDesc);
337+
MRIdx1d oGShapeUMR(oGShapeRank, oGShapeDesc);
338+
auto rank = gShapeUMR.size();
339+
auto oRank = oGShapeUMR.size();
340+
341+
auto gShapePtr = reinterpret_cast<int64_t *>(gShapeUMR.data());
342+
auto lOffsPtr =
343+
reinterpret_cast<int64_t *>(MRIdx1d(lOffsRank, lOffsDesc).data());
344+
auto lShapePtr =
345+
reinterpret_cast<int64_t *>(MRIdx1d(lShapeRank, lShapeDesc).data());
346+
auto lStridesPtr =
347+
reinterpret_cast<int64_t *>(MRIdx1d(lStridesRank, lStridesDesc).data());
348+
auto oGShapePtr = reinterpret_cast<int64_t *>(oGShapeUMR.data());
349+
auto oOffsPtr =
350+
reinterpret_cast<int64_t *>(MRIdx1d(oOffsRank, oOffsDesc).data());
351+
auto oShapePtr =
352+
reinterpret_cast<int64_t *>(MRIdx1d(oShapeRank, oShapeDesc).data());
353+
316354
assert(std::accumulate(&gShapePtr[0], &gShapePtr[rank], 1,
317355
std::multiplies<int64_t>()) ==
318356
std::accumulate(&oGShapePtr[0], &oGShapePtr[oRank], 1,
@@ -392,12 +430,9 @@ void _idtr_reshape(int64_t rank, int64_t *gShapePtr, int dtype, void *lDataPtr,
392430
rszs.data(), roffs.data());
393431
}
394432

395-
using MRIdx1d = Unranked1DMemRefType<uint64_t>;
396-
397433
/// @brief repartition tensor
398434
/// We assume the tensor is partitioned along the first dimension (only) and
399435
/// partitions are ordered by ranks
400-
/// @param rank
401436
/// @param gShapeRank
402437
/// @param gShapeDesc
403438
/// @param dtype
@@ -414,9 +449,9 @@ using MRIdx1d = Unranked1DMemRefType<uint64_t>;
414449
/// @param szsDesc
415450
/// @param outPtr
416451
/// @param tc
417-
void _idtr_repartition(int64_t rank, int64_t gShapeRank, void *gShapeDesc,
418-
int dtype, void *lDataPtr, int64_t lOffsRank,
419-
void *lOffsDesc, int64_t lShapeRank, void *lShapeDesc,
452+
void _idtr_repartition(int64_t gShapeRank, void *gShapeDesc, int dtype,
453+
void *lDataPtr, int64_t lOffsRank, void *lOffsDesc,
454+
int64_t lShapeRank, void *lShapeDesc,
420455
int64_t lStridesRank, void *lStridesDesc,
421456
int64_t offsRank, void *offsDesc, int64_t szsRank,
422457
void *szsDesc, void *outPtr, Transceiver *tc) {
@@ -436,6 +471,7 @@ void _idtr_repartition(int64_t rank, int64_t gShapeRank, void *gShapeDesc,
436471
MRIdx1d offsMR(offsRank, offsDesc);
437472
MRIdx1d szsMR(szsRank, szsDesc);
438473

474+
int64_t rank = gShapeMR.size();
439475
auto lShapePtr = reinterpret_cast<int64_t *>(lShapeMR.data());
440476
auto lStridesPtr = reinterpret_cast<int64_t *>(lStridesMR.data());
441477

@@ -520,10 +556,11 @@ void _idtr_repartition(int64_t rank, int64_t gShapeRank, void *gShapeDesc,
520556
// Finally communicate elements
521557
if (needsBufferize) {
522558
// create send buffer if strided
523-
Buffer buff(totSSz * sizeof_dtype(ddpttype), 2);
559+
Buffer tmpbuff;
560+
tmpbuff.resize(totSSz * sizeof_dtype(ddpttype));
524561
bufferize(lDataPtr, ddpttype, lShapePtr, lStridesPtr, tStarts.data(),
525-
tSizes.data(), rank, N, buff.data());
526-
tc->alltoall(buff.data(), sszs.data(), soffs.data(), ddpttype, outPtr,
562+
tSizes.data(), rank, N, tmpbuff.data());
563+
tc->alltoall(tmpbuff.data(), sszs.data(), soffs.data(), ddpttype, outPtr,
527564
rszs.data(), roffs.data());
528565
} else {
529566
tc->alltoall(lDataPtr, sszs.data(), soffs.data(), ddpttype, outPtr,

src/include/ddptensor/MemRefType.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ template <typename T, size_t N> struct MemRefDescriptor {
1212
};
1313

1414
template <typename T> struct Unranked1DMemRefType {
15-
int64_t rank;
1615
MemRefDescriptor<T, 1> *descriptor;
1716

18-
Unranked1DMemRefType(int64_t _rank, void *p)
19-
: rank(_rank), descriptor(static_cast<MemRefDescriptor<T, 1> *>(p)) {
17+
Unranked1DMemRefType(int64_t rank, void *p)
18+
: descriptor(static_cast<MemRefDescriptor<T, 1> *>(p)) {
2019
assert(rank == 1);
2120
};
2221

@@ -25,4 +24,5 @@ template <typename T> struct Unranked1DMemRefType {
2524
return *(d->aligned + d->offset + idx * d->strides[0]);
2625
};
2726
T *data() { return descriptor->aligned; };
27+
int64_t size() { return descriptor->sizes[0]; };
2828
};

test/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44

55
def runAndCompare(func, do_gather=True):
6-
a = func(ddptensor)
7-
if do_gather:
8-
a = ddptensor.spmd.gather(a)
6+
aa = func(ddptensor)
7+
a = ddptensor.spmd.gather(aa) if do_gather else aa
98
b = func(numpy)
109
if isinstance(b, numpy.ndarray):
10+
print(aa)
11+
print(a)
12+
print(b)
1113
return a.shape == b.shape and numpy.allclose(a, b, rtol=1e-8, atol=1e-8)
1214
return float(a) == float(b)
1315

0 commit comments

Comments
 (0)