Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 6f1cb8e

Browse files
committed
collaborating with separate dist/ptensor passes in imex
1 parent 34db4a5 commit 6f1cb8e

File tree

11 files changed

+156
-279
lines changed

11 files changed

+156
-279
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ set(DDPTSrcs
104104
set(IDTRSrcs
105105
${PROJECT_SOURCE_DIR}/src/idtr.cpp
106106
${PROJECT_SOURCE_DIR}/src/CollComm.cpp
107+
${PROJECT_SOURCE_DIR}/src/DDPTensorImpl.cpp
107108
${PROJECT_SOURCE_DIR}/src/Deferred.cpp
108109
${PROJECT_SOURCE_DIR}/src/Factory.cpp
109110
${PROJECT_SOURCE_DIR}/src/Mediator.cpp

src/Creator.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <imex/internal/PassUtils.h>
99

1010
#include <mlir/IR/Builders.h>
11-
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
11+
#include <mlir/Dialect/Arith/IR/Arith.h>
1212
#include <mlir/Dialect/Shape/IR/Shape.h>
1313
#include <mlir/Dialect/Tensor/IR/Tensor.h>
1414
#include <mlir/Dialect/Linalg/IR/Linalg.h>
@@ -171,15 +171,17 @@ struct DeferredArange : public Deferred
171171
auto stop = ::imex::createInt(loc, builder, _end);
172172
auto step = ::imex::createInt(loc, builder, _step);
173173
auto dtype = builder.getI64Type(); // FIXME
174-
auto artype = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get({-1}, dtype), false, true);
174+
auto artype = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get({-1}, dtype), false);
175175
auto dmy = ::imex::createInt<1>(loc, builder, 0);
176-
auto team = ::imex::createInt(loc, builder, 1);
176+
auto team = ::imex::createInt(loc, builder, reinterpret_cast<uint64_t>(getTransceiver()));
177177
dm.addVal(this->guid(),
178178
builder.create<::imex::ptensor::ARangeOp>(loc, artype, start, stop, step, dmy, team),
179-
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides) {
179+
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
180+
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
180181
assert(rank == 1);
181182
assert(strides[0] == 1);
182-
this->set_value(std::move(mk_tnsr(_dtype, rank, allocated, aligned, offset, sizes, strides)));
183+
this->set_value(std::move(mk_tnsr(_dtype, rank, allocated, aligned, offset, sizes, strides,
184+
gs_allocated, gs_aligned, lo_allocated, lo_aligned)));
183185
});
184186
return false;
185187
}

src/EWBinOp.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,10 @@ struct DeferredEWBinOp : public Deferred
464464
auto bv = dm.getDependent(builder, _b);
465465
dm.addVal(this->guid(),
466466
builder.create<::imex::ptensor::EWBinOp>(loc, av.getType(), builder.getI32IntegerAttr(ddpt2mlir(_op)), av, bv),
467-
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides) {
468-
this->set_value(std::move(mk_tnsr(_dtype, rank, allocated, aligned, offset, sizes, strides)));
467+
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
468+
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
469+
this->set_value(std::move(mk_tnsr(_dtype, rank, allocated, aligned, offset, sizes, strides,
470+
gs_allocated, gs_aligned, lo_allocated, lo_aligned)));
469471
});
470472
return false;
471473
}

src/ReduceOp.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "ddptensor/Factory.hpp"
88

99
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
10+
#include <imex/Dialect/Dist/IR/DistOps.h>
1011
#include <mlir/IR/Builders.h>
1112
#include <mlir/Dialect/Shape/IR/Shape.h>
1213

@@ -121,18 +122,25 @@ struct DeferredReduceOp : public Deferred
121122
{
122123
// FIXME reduction over individual dimensions is not supported
123124
auto av = dm.getDependent(builder, _a);
124-
auto aPtTyp = av.getType().dyn_cast<::imex::ptensor::PTensorType>();
125-
assert(aPtTyp);
125+
auto aDtTyp = av.getType().dyn_cast<::imex::dist::DistTensorType>();
126+
::mlir::Type dtype;
127+
if(aDtTyp) {
128+
dtype = aDtTyp.getPTensorType().getRtensor().getElementType();
129+
} else {
130+
auto aPtTyp = av.getType().dyn_cast<::imex::ptensor::PTensorType>();
131+
dtype = aPtTyp.getRtensor().getElementType();
132+
}
126133
// return type 0d with same dtype as input
127-
auto dtype = aPtTyp.getRtensor().getElementType();
128-
auto retPtTyp = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get({}, dtype), false, true);
134+
auto retPtTyp = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get({}, dtype), false);
129135
// reduction op
130136
auto mop = ddpt2mlir(_op);
131137
auto op = builder.getIntegerAttr(builder.getIntegerType(sizeof(mop)*8), mop);
132138
dm.addVal(this->guid(),
133139
builder.create<::imex::ptensor::ReductionOp>(loc, retPtTyp, op, av),
134-
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides) {
135-
this->set_value(std::move(mk_tnsr(_dtype, rank, allocated, aligned, offset, sizes, strides)));
140+
[this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
141+
uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
142+
this->set_value(std::move(mk_tnsr(_dtype, rank, allocated, aligned, offset, sizes, strides,
143+
gs_allocated, gs_aligned, lo_allocated, lo_aligned)));
136144
});
137145
return false;
138146
}

src/idtr.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ T * mr_to_ptr(void * ptr, intptr_t offset)
2929

3030
extern "C" {
3131

32+
// Return number of ranks/processes in given team/communicator
33+
uint64_t idtr_nprocs(int64_t team)
34+
{
35+
return getTransceiver()->nranks();
36+
}
37+
#pragma weak _idtr_nprocs = idtr_nprocs
38+
39+
// Return rank in given team/communicator
40+
uint64_t idtr_prank(int64_t team)
41+
{
42+
return getTransceiver()->rank();
43+
}
44+
#pragma weak _idtr_prank = idtr_prank
45+
3246
// Register a global tensor of given shape.
3347
// Returns guid.
3448
// The runtime does not own or manage any memory.

src/include/ddptensor/CppTypes.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,7 @@ enum FactoryId : int {
148148
F_TONUMPY,
149149
FACTORY_LAST
150150
};
151+
152+
// size of memreftype in number of intptr_t's
153+
inline uint64_t memref_sz(int rank) { return 3 + 2 * rank; }
154+
inline uint64_t dtensor_sz(int rank) { return 2 * memref_sz(1) + memref_sz(rank) + 1; };

0 commit comments

Comments
 (0)