Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 217e862

Browse files
committed
Adjust to the new IMEX API; support `__getitem__` resulting in a 0d tensor; let idtr work outside of ddpt
1 parent dd0b3b8 commit 217e862

File tree

14 files changed

+141
-65
lines changed

14 files changed

+141
-65
lines changed

ddptensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def to_numpy(a):
5050
if not op.startswith("__"):
5151
OP = op.upper()
5252
exec(
53-
f"{op} = lambda this, other: dtensor(_cdt.EWBinOp.op(_cdt.{OP}, this._t, other._t if isinstance(other, ddptensor) else other))"
53+
f"{op} = lambda this, other: dtensor(_cdt.EWBinOp.op(_cdt.{OP}, this._t if isinstance(this, ddptensor) else this, other._t if isinstance(other, ddptensor) else other))"
5454
)
5555

5656
for op in api.api_categories["EWUnyOp"]:

ddptensor/ddptensor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@ def _inplace(self, t):
5151
)
5252

5353
def __getitem__(self, key):
54-
return dtensor(self._t.__getitem__(key if isinstance(key, tuple) else (key,)))
54+
key = key if isinstance(key, tuple) else (key,)
55+
key = [x if isinstance(x, slice) else slice(x, x+1, 1) for x in key]
56+
return dtensor(self._t.__getitem__(key))
5557

5658
def __setitem__(self, key, value):
57-
self._t.__setitem__(key if isinstance(key, tuple) else (key,), value._t) # if isinstance(value, dtensor) else value)
59+
key = key if isinstance(key, tuple) else (key,)
60+
key = [x if isinstance(x, slice) else slice(x, x+1, 1) for x in key]
61+
self._t.__setitem__(key, value._t) # if isinstance(value, dtensor) else value)

src/Creator.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,14 +241,14 @@ ddptensor * Creator::arange(uint64_t start, uint64_t end, uint64_t step, DTypeId
241241
return new ddptensor(defer<DeferredArange>(start, end, step, dtype, team));
242242
}
243243

244-
ddptensor * Creator::mk_future(const py::object & b)
244+
std::pair<ddptensor *, bool> Creator::mk_future(const py::object & b)
245245
{
246246
if(py::isinstance<ddptensor>(b)) {
247-
return b.cast<ddptensor*>();
247+
return {b.cast<ddptensor*>(), false};
248248
} else if(py::isinstance<py::float_>(b)) {
249-
return Creator::full({}, b, FLOAT64);
249+
return {Creator::full({}, b, FLOAT64), true};
250250
} else if(py::isinstance<py::int_>(b)) {
251-
return Creator::full({}, b, INT64);
251+
return {Creator::full({}, b, INT64), true};
252252
}
253253
throw std::runtime_error("Invalid right operand to elementwise binary operation");
254254
};

src/Deferred.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ void process_promises()
107107
// create return statement and adjust function type
108108
uint64_t osz = dm.handleResult(builder);
109109
// also request generation of c-wrapper function
110-
function->setAttr(::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), ::mlir::UnitAttr::get(&jit._context));
110+
function->setAttr(::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), builder.getUnitAttr());
111111
function.getFunctionType().dump(); std::cout << std::endl;
112112
// add the function to the module
113113
module.push_back(function);

src/EWBinOp.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ struct DeferredEWBinOp : public Deferred
444444

445445
DeferredEWBinOp() = default;
446446
DeferredEWBinOp(EWBinOpId op, const tensor_i::future_type & a, const tensor_i::future_type & b)
447-
: Deferred(a.dtype(), a.rank(), true),
447+
: Deferred(a.dtype(), std::max(a.rank(), b.rank()), true),
448448
_a(a.id()), _b(b.id()), _op(op)
449449
{}
450450

@@ -462,12 +462,13 @@ struct DeferredEWBinOp : public Deferred
462462
// FIXME the type of the result is based on a only
463463
auto av = dm.getDependent(builder, _a);
464464
auto bv = dm.getDependent(builder, _b);
465-
466-
auto aPtTyp = ::imex::dist::getPTensorType(av);
467-
assert(aPtTyp);
465+
466+
auto aTyp = ::imex::dist::getPTensorType(av);
467+
::mlir::SmallVector<int64_t> shape(rank(), ::mlir::ShapedType::kDynamic);
468+
auto outTyp = ::imex::ptensor::PTensorType::get(shape, aTyp.getElementType());
468469

469470
dm.addVal(this->guid(),
470-
builder.create<::imex::ptensor::EWBinOp>(loc, aPtTyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), av, bv),
471+
builder.create<::imex::ptensor::EWBinOp>(loc, outTyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), av, bv),
471472
[this](Transceiver * transceiver, uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
472473
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned, uint64_t balanced) {
473474
this->set_value(std::move(mk_tnsr(transceiver, _dtype, rank, allocated, aligned, offset, sizes, strides,
@@ -490,13 +491,17 @@ struct DeferredEWBinOp : public Deferred
490491
}
491492
};
492493

493-
ddptensor * EWBinOp::op(EWBinOpId op, const ddptensor & a, const py::object & b)
494+
ddptensor * EWBinOp::op(EWBinOpId op, const py::object & a, const py::object & b)
494495
{
495496
auto bb = Creator::mk_future(b);
497+
auto aa = Creator::mk_future(a);
496498
if(op == __MATMUL__) {
497-
return LinAlgOp::vecdot(a, *bb, 0);
499+
return LinAlgOp::vecdot(*aa.first, *bb.first, 0);
498500
}
499-
return new ddptensor(defer<DeferredEWBinOp>(op, a.get(), bb->get()));
501+
auto res = new ddptensor(defer<DeferredEWBinOp>(op, aa.first->get(), bb.first->get()));
502+
if(aa.second) delete aa.first;
503+
if(bb.second) delete bb.first;
504+
return res;
500505
}
501506

502507
FACTORY_INIT(DeferredEWBinOp, F_EWBINOP);

src/IEWBinOp.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,9 @@ struct DeferredIEWBinOp : public Deferred
112112
ddptensor * IEWBinOp::op(IEWBinOpId op, ddptensor & a, const py::object & b)
113113
{
114114
auto bb = Creator::mk_future(b);
115-
return new ddptensor(defer<DeferredIEWBinOp>(op, a.get(), bb->get()));
115+
auto res = new ddptensor(defer<DeferredIEWBinOp>(op, a.get(), bb.first->get()));
116+
if(bb.second) delete bb.first;
117+
return res;
116118
}
117119

118120
FACTORY_INIT(DeferredIEWBinOp, F_IEWBINOP);

src/ReduceOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ struct DeferredReduceOp : public Deferred
126126
assert(aPtTyp);
127127
::mlir::Type dtype = aPtTyp.getElementType();
128128
// return type 0d with same dtype as input
129-
auto retPtTyp = ::imex::ptensor::PTensorType::get(builder.getContext(), 0, dtype, false);
129+
auto retPtTyp = ::imex::ptensor::PTensorType::get({::mlir::ShapedType::kDynamic}, dtype);
130130
// reduction op
131131
auto mop = ddpt2mlir(_op);
132132
auto op = builder.getIntegerAttr(builder.getIntegerType(sizeof(mop)*8), mop);

src/SetGetItem.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -280,22 +280,28 @@ struct DeferredGetItem : public Deferred
280280
auto & strides = _slc.strides();
281281
auto nd = offs.size();
282282
// convert C++ slices into vectors of MLIR Values
283-
std::vector<::mlir::Value> offsV(nd);
284-
std::vector<::mlir::Value> sizesV(nd);
285-
std::vector<::mlir::Value> stridesV(nd);
283+
std::vector<::mlir::OpFoldResult> offsV(nd);
284+
std::vector<::mlir::OpFoldResult> sizesV(nd);
285+
std::vector<::mlir::OpFoldResult> stridesV(nd);
286+
::mlir::SmallVector<int64_t> shape(nd, ::mlir::ShapedType::kDynamic);
286287
for(auto i = 0; i<nd; ++i) {
287288
offsV[i] = ::imex::createIndex(loc, builder, offs[i]);
288-
sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]);
289289
stridesV[i] = ::imex::createIndex(loc, builder, strides[i]);
290+
if(sizes[i] == 1) {
291+
sizesV[i] = builder.getIndexAttr(sizes[i]);
292+
shape[i] = sizes[i];
293+
} else {
294+
sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]);
295+
}
290296
}
297+
298+
auto oTyp = ::imex::dist::getPTensorType(av);
299+
// auto outnd = nd == 0 || _slc.size() == 1 ? 0 : nd;
300+
auto outTyp = ::imex::ptensor::PTensorType::get(shape, oTyp.getElementType());
291301
// now we can create the PTensor op using the above Values
302+
auto res = builder.create<::imex::ptensor::SubviewOp>(loc, outTyp, av, offsV, sizesV, stridesV);
292303
dm.addVal(this->guid(),
293-
builder.create<::imex::ptensor::ExtractSliceOp>(loc,
294-
::imex::dist::getPTensorType(av),
295-
av,
296-
offsV,
297-
sizesV,
298-
stridesV),
304+
res,
299305
[this, dtype](Transceiver * transceiver, uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
300306
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned, uint64_t balanced) {
301307
this->set_value(std::move(mk_tnsr(transceiver, dtype, rank, allocated, aligned, offset, sizes, strides,

src/idtr.cpp

Lines changed: 87 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,33 @@ T * mr_to_ptr(void * ptr, intptr_t offset)
3131

3232
extern "C" {
3333

34+
#define NO_TRANSCEIVER
35+
#ifdef NO_TRANSCEIVER
36+
static void initMPIRuntime() {
37+
if(getTransceiver() == nullptr)
38+
init_transceiver(new MPITransceiver(false));
39+
}
40+
#endif
41+
3442
// Return number of ranks/processes in given team/communicator
35-
uint64_t idtr_nprocs(int64_t team)
43+
uint64_t idtr_nprocs(Transceiver * tc)
3644
{
37-
return getTransceiver()->nranks();
45+
#ifdef NO_TRANSCEIVER
46+
initMPIRuntime();
47+
tc = getTransceiver();
48+
#endif
49+
return tc->nranks();
3850
}
3951
#pragma weak _idtr_nprocs = idtr_nprocs
4052

4153
// Return rank in given team/communicator
42-
uint64_t idtr_prank(int64_t team)
54+
uint64_t idtr_prank(Transceiver * tc)
4355
{
44-
return getTransceiver()->rank();
56+
#ifdef NO_TRANSCEIVER
57+
initMPIRuntime();
58+
tc = getTransceiver();
59+
#endif
60+
return tc->rank();
4561
}
4662
#pragma weak _idtr_prank = idtr_prank
4763

@@ -173,7 +189,9 @@ void forall(uint64_t d, const T * cptr, const int64_t * sizes, const int64_t * s
173189
}
174190
} else {
175191
for(auto i=0; i<sz; ++i) {
192+
const T * tmp = cptr;
176193
forall(d+1, cptr, sizes, strides, nd, op);
194+
cptr = tmp + strides[d];
177195
}
178196
}
179197
}
@@ -190,20 +208,26 @@ bool is_contiguous(const int64_t * sizes, const int64_t * strides, uint64_t nd)
190208
return true;
191209
}
192210

193-
void * bufferize(void * cptr, DTypeId dtype, const int64_t * sizes, const int64_t * strides, uint64_t nd, void * out)
194-
{
195-
if(is_contiguous(sizes, strides, nd)) {
196-
return cptr;
197-
} else {
198-
dispatch(dtype, cptr, [sizes, strides, nd, out](auto * ptr) {
199-
auto buff = static_cast<decltype(ptr)>(out);
200-
forall(0, ptr, sizes, strides, nd, [&buff](const auto * in) {
201-
*buff = *in;
202-
++buff;
203-
});
204-
});
205-
return out;
206-
}
211+
void bufferize(void * cptr, DTypeId dtype, const int64_t * sizes, const int64_t * strides, const int64_t * tStarts, const int64_t * tSizes, uint64_t nd, uint64_t N, void * out)
212+
{
213+
dispatch(dtype, cptr, [sizes, strides, tStarts, tSizes, nd, N, out](auto * ptr) {
214+
auto buff = static_cast<decltype(ptr)>(out);
215+
216+
for(auto i=0; i<N; ++i) {
217+
auto szs = &tSizes[i*nd];
218+
if(szs[0] > 0) {
219+
auto sts = &tStarts[i*nd];
220+
uint64_t off = 0;
221+
for(int64_t r=0; r<nd; ++r) {
222+
off += sts[r] * strides[r];
223+
}
224+
forall(0, &ptr[off], szs, strides, nd, [&buff](const auto * in) {
225+
*buff = *in;
226+
++buff;
227+
});
228+
}
229+
}
230+
});
207231
}
208232

209233
extern "C" {
@@ -223,6 +247,7 @@ void _idtr_reduce_all(uint64_t rank, void * data, const int64_t * sizes, const i
223247
mlir2ddpt(static_cast<imex::ptensor::ReduceOpId>(op)));
224248
}
225249

250+
#if 0
226251
void _idtr_rebalance(uint64_t rank, const int64_t * gShape, const int64_t * lOffs,
227252
void * data, const int64_t * sizes, const int64_t * strides, int dtype,
228253
uint64_t outRank, void * out, const int64_t * outSizes, const int64_t * outStrides)
@@ -269,7 +294,7 @@ void _idtr_rebalance(uint64_t rank, const int64_t * gShape, const int64_t * lOff
269294
// Finally communicate elements
270295
getTransceiver()->alltoall(ptr, sszs.data(), soffs.data(), ddpttype, out, rszs.data(), roffs.data());
271296
}
272-
297+
#endif
273298

274299
/// @brief repartition tensor
275300
/// We assume tensor is partitioned along the first dimension (only) and partitions are ordered by ranks
@@ -288,18 +313,20 @@ void _idtr_repartition(int64_t rank, int64_t * gShapePtr, int dtype,
288313
void * lDataPtr, int64_t * lOffsPtr, int64_t * lShapePtr, int64_t * lStridesPtr,
289314
int64_t * offsPtr, int64_t * szsPtr, void * outPtr, Transceiver * tc)
290315
{
291-
assert(is_contiguous(lShapePtr, lStridesPtr, rank));
292-
316+
#ifdef NO_TRANSCEIVER
317+
initMPIRuntime();
318+
tc = getTransceiver();
319+
#endif
293320
auto N = tc->nranks();
294321
auto me = tc->rank();
295322
auto ddpttype = mlir2ddpt(static_cast<::imex::ptensor::DType>(dtype));
296-
auto nSz = std::accumulate(&lShapePtr[1], &lShapePtr[rank], 1, std::multiplies<int64_t>());
297323

298324
// First we allgather the requested target partitioning
299325

300326
auto myBOff = 2 * rank * me;
301327
::std::vector<int64_t> buff(2*rank*N);
302328
for(int64_t i=0; i<rank; ++i) {
329+
// assert(offsPtr[i] - lOffsPtr[i] + szsPtr[i] <= gShapePtr[i]);
303330
buff[myBOff+i] = offsPtr[i];
304331
buff[myBOff+i+rank] = szsPtr[i];
305332
}
@@ -315,24 +342,44 @@ void _idtr_repartition(int64_t rank, int64_t * gShapePtr, int dtype,
315342
auto myOff = lOffsPtr[0];
316343
auto mySz = lShapePtr[0];
317344
auto myEnd = myOff + mySz;
345+
auto myTileSz = std::accumulate(&lShapePtr[1], &lShapePtr[rank], 1, std::multiplies<int64_t>());
318346

319347
std::vector<int> soffs(N);
320348
std::vector<int> sszs(N, 0);
349+
std::vector<int64_t> tStarts(N*rank, 0);
350+
std::vector<int64_t> tSizes(N*rank, 0);
351+
std::vector<int64_t> nSizes(N);
352+
int64_t totSSz = 0;
353+
bool needsBufferize = !is_contiguous(lShapePtr, lStridesPtr, rank);
321354

322355
for(auto i=0; i<N; ++i) {
356+
nSizes[i] = std::accumulate(&buff[2*rank*i+rank+1], &buff[2*rank*i+rank+rank], 1, std::multiplies<int64_t>());
357+
if(nSizes[i] != myTileSz) needsBufferize = true;
358+
}
359+
for(auto i=0; i<N; ++i) {
360+
auto nSz = nSizes[i];
323361
auto tOff = buff[2*rank*i];
324362
auto tSz = buff[2*rank*i+rank];
325363
auto tEnd = tOff + tSz;
364+
326365
if(tEnd > myOff && tOff < myEnd) {
327366
// We have a target partition which is inside my local data
328367
// we now compute what data goes to this target partition
329368
auto start = std::max(myOff, tOff);
330369
auto end = std::min(myEnd, tEnd);
331-
soffs[i] = (int)(start - myOff) * nSz;
370+
tStarts[i*rank] = start - myOff;
371+
tSizes[i*rank] = end - start;
372+
soffs[i] = needsBufferize ? (i ? soffs[i-1] + sszs[i-1] : 0) : (int)(start - myOff) * myTileSz;
332373
sszs[i] = (int)(end - start) * nSz;
333374
} else {
334375
soffs[i] = i ? soffs[i-1] + sszs[i-1] : 0;
335376
}
377+
totSSz += sszs[i];
378+
for(auto r=1; r<rank; ++r) {
379+
tStarts[i*rank+r] = buff[2*rank*i+r];
380+
tSizes[i*rank+r] = buff[2*rank*i+rank+r];
381+
// assert(tSizes[i*rank+r] <= lShapePtr[r]);
382+
}
336383
}
337384

338385
// send our send sizes to others and receive theirs
@@ -348,7 +395,15 @@ void _idtr_repartition(int64_t rank, int64_t * gShapePtr, int dtype,
348395
}
349396

350397
// Finally communicate elements
351-
getTransceiver()->alltoall(lDataPtr, sszs.data(), soffs.data(), ddpttype, outPtr, rszs.data(), roffs.data());
398+
if(needsBufferize) {
399+
// create send buffer if strided
400+
Buffer buff(totSSz * sizeof_dtype(ddpttype), 2);
401+
bufferize(lDataPtr, ddpttype, lShapePtr, lStridesPtr, tStarts.data(), tSizes.data(), rank, N, buff.data());
402+
getTransceiver()->alltoall(buff.data(), sszs.data(), soffs.data(), ddpttype, outPtr, rszs.data(), roffs.data());
403+
std::cerr << "yey\n";
404+
} else {
405+
getTransceiver()->alltoall(lDataPtr, sszs.data(), soffs.data(), ddpttype, outPtr, rszs.data(), roffs.data());
406+
}
352407
}
353408

354409
void _idtr_extractslice(int64_t * slcOffs,
@@ -360,13 +415,13 @@ void _idtr_extractslice(int64_t * slcOffs,
360415
int64_t * lSlcSizes,
361416
int64_t * gSlcOffsets)
362417
{
363-
std::cerr << "slcOffs: " << slcOffs[0] << " " << slcOffs[1] << std::endl;
364-
std::cerr << "slcSizes: " << slcSizes[0] << " " << slcSizes[1] << std::endl;
365-
std::cerr << "slcStrides: " << slcStrides[0] << " " << slcStrides[1] << std::endl;
366-
std::cerr << "tOffs: " << tOffs[0] << " " << tOffs[1] << std::endl;
367-
std::cerr << "tSizes: " << tSizes[0] << " " << tSizes[1] << std::endl;
368-
std::cerr << "lSlcOffsets: " << lSlcOffsets[0] << " " << lSlcOffsets[1] << std::endl;
369-
std::cerr << "lSlcSizes: " << lSlcSizes[0] << " " << lSlcSizes[1] << std::endl;
370-
std::cerr << "gSlcOffsets: " << gSlcOffsets[0] << " " << gSlcOffsets[1] << std::endl;
418+
if(slcOffs) std::cerr << "slcOffs: " << slcOffs[0] << " " << slcOffs[1] << std::endl;
419+
if(slcSizes) std::cerr << "slcSizes: " << slcSizes[0] << " " << slcSizes[1] << std::endl;
420+
if(slcStrides) std::cerr << "slcStrides: " << slcStrides[0] << " " << slcStrides[1] << std::endl;
421+
if(tOffs) std::cerr << "tOffs: " << tOffs[0] << " " << tOffs[1] << std::endl;
422+
if(tSizes) std::cerr << "tSizes: " << tSizes[0] << " " << tSizes[1] << std::endl;
423+
if(lSlcOffsets) std::cerr << "lSlcOffsets: " << lSlcOffsets[0] << " " << lSlcOffsets[1] << std::endl;
424+
if(lSlcSizes) std::cerr << "lSlcSizes: " << lSlcSizes[0] << " " << lSlcSizes[1] << std::endl;
425+
if(gSlcOffsets) std::cerr << "gSlcOffsets: " << gSlcOffsets[0] << " " << gSlcOffsets[1] << std::endl;
371426
}
372427
} // extern "C"

src/include/ddptensor/Creator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ struct Creator
1111
static ddptensor * create_from_shape(CreatorId op, const shape_type & shape, DTypeId dtype=FLOAT64);
1212
static ddptensor * full(const shape_type & shape, const py::object & val, DTypeId dtype=FLOAT64);
1313
static ddptensor * arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype=INT64, uint64_t team=0);
14-
static ddptensor * mk_future(const py::object & b);
14+
static std::pair<ddptensor *, bool> mk_future(const py::object & b);
1515
};

0 commit comments

Comments (0)