Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 426d5b4

Browse files
fschlimbtkarna
andauthored
Unitsize (#37)
* updates for latest imex * generating correct GShape in DistTensorType args and for subview/getitem --------- Co-authored-by: Tuomas Karna <tuomas.karna@intel.com>
1 parent 8c6e63d commit 426d5b4

File tree

6 files changed

+20
-19
lines changed

6 files changed

+20
-19
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ target_link_libraries(_ddpt_rt PRIVATE
186186
MLIRFuncToLLVM
187187
MLIRFuncTransforms
188188
MLIRLinalgDialect
189-
MLIRLinalgToLLVM
190189
MLIRLinalgTransforms
191190
MLIRLLVMDialect
192191
MLIRMathDialect

examples/stencil-2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def main():
181181
for s in range(-r, r + 1):
182182
for t in range(-r, r + 1):
183183
B[r:b, r:b] += W[r + t, r + s] * A[r + t : b + t, r + s : b + s]
184-
A = A + 1.0
184+
A[:, :] = A + 1.0
185185

186186
np.sync()
187187
t1 = timer()

imex_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
cfe20696de8126dee8fcd596ee8e4cc8522977f7
1+
e2ecbb3c08e4a87256cdf11d8057fe765d3a04b8

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def build_cmake(self, ext):
4040
build_args = [
4141
"--config",
4242
config,
43-
# '--', '-j8'
43+
"-j8"
44+
# '--', '-j4'
4445
]
4546

4647
os.chdir(str(build_temp))

src/SetGetItem.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,10 @@ struct DeferredGetItem : public Deferred {
280280
// get params and extract offsets/sizes/strides
281281
const auto dtype = this->dtype();
282282
auto av = dm.getDependent(builder, _a);
283-
auto &offs = _slc.offsets();
284-
auto &sizes = _slc.sizes();
285-
auto &strides = _slc.strides();
283+
const auto &offs = _slc.offsets();
284+
const auto &sizes =
285+
shape(); // we already converted ALL_SIZE as much as posible
286+
const auto &strides = _slc.strides();
286287
auto nd = offs.size();
287288
// convert C++ slices into vectors of MLIR Values
288289
std::vector<::mlir::OpFoldResult> offsV(nd);
@@ -291,7 +292,7 @@ struct DeferredGetItem : public Deferred {
291292
for (auto i = 0; i < nd; ++i) {
292293
offsV[i] = ::imex::createIndex(loc, builder, offs[i]);
293294
stridesV[i] = ::imex::createIndex(loc, builder, strides[i]);
294-
if (sizes[i] == ALL_SIZE) {
295+
if (sizes[i] < 0) {
295296
sizesV[i] =
296297
builder.create<::imex::ptensor::DimOp>(loc, av, i).getResult();
297298
} else {

src/jit/mlir.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,10 @@ static ::mlir::Type makeSignlessType(::mlir::Type type) {
117117

118118
// convert ddpt's DTYpeId into MLIR type
119119
static ::mlir::Type getTType(::mlir::OpBuilder &builder, DTypeId dtype,
120-
::mlir::SmallVector<int64_t> &lhShape,
121-
::mlir::SmallVector<int64_t> &ownShape,
122-
::mlir::SmallVector<int64_t> &rhShape,
120+
const ::mlir::SmallVector<int64_t> &gShape,
121+
const ::mlir::SmallVector<int64_t> &lhShape,
122+
const ::mlir::SmallVector<int64_t> &ownShape,
123+
const ::mlir::SmallVector<int64_t> &rhShape,
123124
uint64_t team, bool balanced) {
124125
::mlir::Type etyp;
125126

@@ -154,9 +155,7 @@ static ::mlir::Type getTType(::mlir::OpBuilder &builder, DTypeId dtype,
154155
};
155156

156157
if (team) {
157-
if (ownShape.size()) {
158-
auto gShape = ownShape;
159-
gShape[0] += lhShape[0] + rhShape[0];
158+
if (gShape.size()) {
160159
return ::imex::dist::DistTensorType::get(gShape, etyp,
161160
{lhShape, ownShape, rhShape});
162161
} else {
@@ -183,8 +182,10 @@ ::mlir::Value DepManager::getDependent(::mlir::OpBuilder &builder,
183182
ownShape[i] = impl->local_shape()[i];
184183
rhShape[i] = impl->rh_shape()[i];
185184
}
186-
auto typ = getTType(builder, fut.dtype(), lhShape, ownShape, rhShape,
187-
fut.team(), fut.balanced());
185+
auto typ = getTType(
186+
builder, fut.dtype(),
187+
::mlir::SmallVector<int64_t>(impl->shape(), impl->shape() + rank),
188+
lhShape, ownShape, rhShape, fut.team(), fut.balanced());
188189
_func.insertArgument(idx, typ, {}, loc);
189190
auto val = _func.getArgument(idx);
190191
_args.push_back({guid, std::move(fut)});
@@ -516,7 +517,8 @@ JIT::JIT()
516517
crunner = crunner ? crunner : "libmlir_c_runner_utils.so";
517518
const char *idtr = getenv("DDPT_IDTR_SO");
518519
idtr = idtr ? idtr : "libidtr.so";
519-
_sharedLibPaths = {idtr, crunner};
520+
_sharedLibPaths = {idtr, crunner,
521+
"/home/fschlimb/llvm/lib/libmlir_runner_utils.so"};
520522

521523
// detect target architecture
522524
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
@@ -562,14 +564,12 @@ void init() {
562564
::mlir::registerConvertFuncToLLVMPass();
563565
::mlir::bufferization::registerBufferizationPasses();
564566
::mlir::arith::registerArithPasses();
565-
::mlir::registerAffinePasses();
566567
::mlir::registerCanonicalizerPass();
567568
::mlir::registerConvertAffineToStandardPass();
568569
::mlir::registerFinalizeMemRefToLLVMConversionPass();
569570
::mlir::registerArithToLLVMConversionPass();
570571
::mlir::registerConvertMathToLLVMPass();
571572
::mlir::registerConvertControlFlowToLLVMPass();
572-
::mlir::registerConvertLinalgToLLVMPass();
573573
::mlir::registerConvertOpenMPToLLVMPass();
574574
::mlir::memref::registerMemRefPasses();
575575
::mlir::registerReconcileUnrealizedCastsPass();

0 commit comments

Comments
 (0)