Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit f0fd1ab

Browse files
committed
first working hard-coded jit
1 parent 88e72e0 commit f0fd1ab

File tree

3 files changed

+115
-89
lines changed

3 files changed

+115
-89
lines changed

src/ddptensor.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ using namespace pybind11::literals; // to bring _a
3636
#include "ddptensor/Service.hpp"
3737
#include "ddptensor/Factory.hpp"
3838
#include "ddptensor/IO.hpp"
39-
40-
extern void ttt();
39+
#include "ddptensor/jit/mlir.hpp"
4140

4241
// #########################################################################
4342
// The following classes are wrappers bridging pybind11 defs to TypeDispatch
@@ -123,6 +122,8 @@ PYBIND11_MODULE(_ddptensor, m) {
123122
Factory::init<F_SERVICE>();
124123
Factory::init<F_TONUMPY>();
125124

125+
jit::init();
126+
126127
m.doc() = "A partitioned and distributed tensor";
127128

128129
def_enums(m);
@@ -137,7 +138,7 @@ PYBIND11_MODULE(_ddptensor, m) {
137138
.def("_get_local", &GetItem::get_local)
138139
.def("_gather", &GetItem::gather)
139140
.def("to_numpy", &IO::to_numpy)
140-
.def("ttt", &ttt);
141+
.def("ttt", &jit::ttt);
141142

142143
py::class_<Creator>(m, "Creator")
143144
.def("create_from_shape", &Creator::create_from_shape)

src/include/ddptensor/jit/mlir.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// SPDX-License-Identifier: BSD-3-Clause
2+
3+
#pragma once
4+
5+
#include "mlir/IR/BuiltinOps.h"
6+
#include "mlir/IR/MLIRContext.h"
7+
#include "mlir/Pass/PassManager.h"
8+
9+
namespace jit {
10+
11+
// initialize jit
12+
void init();
13+
14+
void ttt();
15+
16+
// A class to manage the MLIR business (compilation and execution).
17+
// Just a stub for now, will need to be extended with parameters and maybe more.
18+
class JIT {
19+
public:
20+
template<typename T, size_t N>
21+
struct MemRefDescriptor {
22+
T *allocated = nullptr;
23+
T *aligned = nullptr;
24+
intptr_t offset = 0;
25+
intptr_t sizes[N] = {0};
26+
intptr_t strides[N] = {0};
27+
};
28+
29+
JIT();
30+
// run
31+
int run(::mlir::ModuleOp &, const std::string &);
32+
33+
::mlir::MLIRContext _context;
34+
::mlir::PassManager _pm;
35+
};
36+
37+
} // namespace jit

src/jit/mlir.cpp

Lines changed: 74 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// SPDX-License-Identifier: BSD-3-Clause
22

3+
#include "ddptensor/jit/mlir.hpp"
4+
35
#include "mlir/IR/MLIRContext.h"
46

57
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -29,6 +31,7 @@
2931
#include <imex/InitIMEXPasses.h>
3032

3133
#include <cstdlib>
34+
#include <iostream>
3235

3336
//#include "llvm/ADT/StringRef.h"
3437
//#include "llvm/IR/Module.h"
@@ -39,6 +42,7 @@
3942
#include "llvm/Support/TargetSelect.h"
4043
//#include "llvm/Support/raw_ostream.h"
4144

45+
namespace jit {
4246

4347
static ::mlir::Type makeSignlessType(::mlir::Type type)
4448
{
@@ -59,37 +63,15 @@ auto createI64(const ::mlir::Location & loc, ::mlir::OpBuilder & builder, int64_
5963
return builder.create<::mlir::arith::ConstantOp>(loc, attr).getResult();
6064
}
6165

62-
int processMLIR(::mlir::ModuleOp &module)
63-
{
64-
const char * pl = getenv("DDPT_PASSES");
65-
// "convert-ptensor-to-linalg,dist-elim,convert-shape-to-std,arith-bufferize,func.func(linalg-init-tensor-to-alloc-tensor,scf-bufferize,shape-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,func.func(lower-affine),fold-memref-subview-ops,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,convert-dtensor-to-llvm,reconcile-unrealized-casts",
66-
if(!pl) pl = "convert-ptensor-to-linalg,dist-elim,convert-shape-to-std,arith-bufferize,func.func(linalg-init-tensor-to-alloc-tensor,scf-bufferize,shape-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,func.func(lower-affine),fold-memref-subview-ops,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts";
67-
::mlir::PassManager pm(module.getContext());
68-
if(::mlir::failed(::mlir::parsePassPipeline(pl, pm))) return 3;
69-
70-
pm.enableStatistics();
71-
pm.enableIRPrinting();
72-
pm.dump();
73-
if (::mlir::failed(pm.run(module))) return 4;
74-
75-
return 0;
76-
}
77-
78-
int runJit(::mlir::ModuleOp & module)
79-
{
80-
// Initialize LLVM targets.
81-
::llvm::InitializeNativeTarget();
82-
::llvm::InitializeNativeTargetAsmPrinter();
83-
//::llvm::initializeLLVMPasses();
84-
85-
// Register the translation from ::mlir to LLVM IR, which must happen before we
86-
// can JIT-compile.
87-
::mlir::registerLLVMDialectTranslation(*module->getContext());
66+
int JIT::run(::mlir::ModuleOp & module, const std::string & fname)
67+
{
68+
if (::mlir::failed(_pm.run(module)))
69+
throw std::runtime_error("failed to run pass manager");
8870

8971
// An optimization pipeline to use within the execution engine.
90-
auto optPipeline = ::mlir::makeOptimizingTransformer(0, // /*optLevel=*/enableOpt ? 3 : 0,
91-
/*sizeLevel=*/0,
92-
/*targetMachine=*/nullptr);
72+
auto optPipeline = ::mlir::makeOptimizingTransformer(/*optLevel=*/0,
73+
/*sizeLevel=*/0,
74+
/*targetMachine=*/nullptr);
9375

9476
// Create an ::mlir execution engine. The execution engine eagerly JIT-compiles
9577
// the module.
@@ -99,38 +81,77 @@ int runJit(::mlir::ModuleOp & module)
9981
assert(maybeEngine && "failed to construct an execution engine");
10082
auto &engine = maybeEngine.get();
10183

84+
85+
const char * fn = getenv("DDPT_FN");
86+
if(!fn) fn = fname.c_str();
87+
88+
MemRefDescriptor<int64_t, 1> result;
89+
auto r_ptr = &result;
90+
// int64_t arg = 7;
10291
// Invoke the JIT-compiled function.
103-
auto invocationResult = engine->invokePacked("ttt_"); //, {{}, {}});
104-
if (invocationResult) {
92+
if(engine->invoke(fn, ::mlir::ExecutionEngine::result(r_ptr))) {
10593
::llvm::errs() << "JIT invocation failed\n";
106-
return -1;
94+
throw std::runtime_error("JIT invocation failed");
10795
}
96+
std::cout << "aptr=" << result.allocated << " dptr=" << result.aligned << " offset=" << result.offset << std::endl;
97+
std::cout << ((int64_t*)result.aligned)[result.offset] << std::endl;
10898

10999
return 0;
110100
}
111101

112-
void ttt()
102+
static const char * pass_pipeline =
103+
getenv("DDPT_PASSES")
104+
? getenv("DDPT_PASSES")
105+
: "convert-ptensor-to-linalg,dist-elim,convert-shape-to-std,arith-bufferize,func.func(linalg-init-tensor-to-alloc-tensor,scf-bufferize,shape-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,func.func(lower-affine),fold-memref-subview-ops,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts";
106+
107+
JIT::JIT()
108+
: _context(::mlir::MLIRContext::Threading::DISABLED),
109+
_pm(&_context)
113110
{
114-
std::string fname("_mlir_ttt_");
111+
// Register the translation from ::mlir to LLVM IR, which must happen before we
112+
// can JIT-compile.
113+
::mlir::registerLLVMDialectTranslation(_context);
114+
// load the dialects we use
115+
_context.getOrLoadDialect<::mlir::arith::ArithmeticDialect>();
116+
_context.getOrLoadDialect<::mlir::func::FuncDialect>();
117+
_context.getOrLoadDialect<::imex::ptensor::PTensorDialect>();
118+
_context.getOrLoadDialect<::imex::dist::DistDialect>();
119+
// create the pass pipeline from string
120+
if(::mlir::failed(::mlir::parsePassPipeline(pass_pipeline, _pm)))
121+
throw std::runtime_error("failed to parse pass pipeline");
122+
// some verbosity
123+
_pm.enableStatistics();
124+
_pm.enableIRPrinting();
125+
_pm.dump();
126+
}
115127

128+
void init()
129+
{
116130
::mlir::registerAllPasses();
117131
::imex::registerAllPasses();
118132

119133
// ::mlir::DialectRegistry registry;
120134
// ::mlir::registerAllDialects(registry);
121135
// ::imex::registerAllDialects(registry);
122136

123-
::mlir::MLIRContext context(::mlir::MLIRContext::Threading::DISABLED);
137+
// Initialize LLVM targets.
138+
::llvm::InitializeNativeTarget();
139+
::llvm::InitializeNativeTargetAsmPrinter();
140+
//::llvm::initializeLLVMPasses();
141+
}
124142

125-
context.getOrLoadDialect<::mlir::arith::ArithmeticDialect>();
126-
// context.getOrLoadDialect<::mlir::tensor::TensorDialect>();
127-
// context.getOrLoadDialect<::mlir::linalg::LinalgDialect>();
128-
context.getOrLoadDialect<::mlir::func::FuncDialect>();
129-
// context.getOrLoadDialect<::mlir::shape::ShapeDialect>();
130-
context.getOrLoadDialect<::imex::ptensor::PTensorDialect>();
131-
context.getOrLoadDialect<::imex::dist::DistDialect>();
143+
// mock function for POC testing
144+
// delayed execution will do something like the below:
145+
// * create module
146+
// * create a function and define its types (input and return types)
147+
// * create the function body and return op
148+
// * add function to module
149+
// * compile & run the module
150+
void ttt()
151+
{
152+
JIT jit;
132153

133-
::mlir::OpBuilder builder(&context);
154+
::mlir::OpBuilder builder(&jit._context);
134155
auto loc = builder.getUnknownLoc();
135156
auto module = builder.create<::mlir::ModuleOp>(loc);
136157

@@ -140,7 +161,11 @@ void ttt()
140161
auto artype = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get(shape, dtype), true);
141162
auto rrtype = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get(llvm::SmallVector<int64_t>(), dtype), true);
142163
auto funcType = builder.getFunctionType({}, rrtype);
164+
165+
std::string fname("tttt");
143166
auto function = builder.create<::mlir::func::FuncOp>(loc, fname, funcType);
167+
// request generation of c-wrapper function
168+
function->setAttr(::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), ::mlir::UnitAttr::get(&jit._context));
144169

145170
// Create an ::mlir function for the given prototype.
146171
//::mlir::func::FuncOp function(fproto);
@@ -162,54 +187,17 @@ void ttt()
162187
auto c1 = createI64(loc, builder, 1);
163188
auto c100 = createI64(loc, builder, 100);
164189

190+
// return np.sum(np.arange(1,10,1)+np.arange(1,100,10)) -> 495
165191
auto rangea = builder.create<::imex::ptensor::ARangeOp>(loc, artype, c0, c10, c1, true);
166192
auto rangeb = builder.create<::imex::ptensor::ARangeOp>(loc, artype, c0, c100, c10, true);
167193
auto added = builder.create<::imex::ptensor::EWBinOp>(loc, artype, builder.getI32IntegerAttr(::imex::ptensor::ADD), rangea, rangeb);
168194
auto reduced = builder.create<::imex::ptensor::ReductionOp>(loc, rrtype, builder.getI32IntegerAttr(::imex::ptensor::SUM), added);
169195
auto ret = builder.create<::mlir::func::ReturnOp>(loc, reduced.getResult());
170-
196+
// add the function to the module
171197
module.push_back(function);
172-
module.dump();
173198

174-
if(processMLIR(module)) throw std::runtime_error("failed to process mlir");
175-
module.dump();
176-
177-
if(runJit(module)) throw std::runtime_error("failed to run jit");
178-
179-
#if 0
180-
std::vector<int> shape = {16, 16};
181-
auto elemType = builder.getF64Type();
182-
auto signlessElemType = makeSignlessType(elemType);
183-
auto indexType = builder.getIndexType();
184-
auto count = shape.size();
185-
::llvm::SmallVector<::mlir::Value> shapeVal(count);
186-
::llvm::SmallVector<int64_t> staticShape(count); // ::mlir::ShapedType::kDynamicSize);
187-
188-
for(auto it : ::llvm::enumerate(shape)) {
189-
auto i = it.index();
190-
auto elem = it.value();
191-
auto elemVal = getInt(loc, builder, elem);
192-
staticShape[i] = elem;
193-
shapeVal[i] = elemVal;
194-
}
195-
196-
::mlir::Value init;
197-
if(true) { //initVal.is_none()) {
198-
init = builder.create<::mlir::linalg::InitTensorOp>(loc, shapeVal, signlessElemType);
199-
}// else {
200-
// auto val = doCast(builder, loc, ctx.context.unwrapVal(loc, builder, initVal), signlessElemType);
201-
// ::llvm::SmallVector<int64_t> shape(count, ::mlir::ShapedType::kDynamicSize);
202-
// auto type = ::mlir::RankedTensorType::get(shape, signlessElemType);
203-
// auto body = [&](::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange /*indices*/) {
204-
// builder.create<::mlir::tensor::YieldOp>(loc, val);
205-
// };
206-
// init = builder.create<::mlir::tensor::GenerateOp>(loc, type, shapeVal, body);
207-
// }
208-
if (::llvm::any_of(staticShape, [](auto val) { return val >= 0; })) {
209-
auto newType = ::mlir::RankedTensorType::get(staticShape, signlessElemType);
210-
init = builder.create<::mlir::tensor::CastOp>(loc, newType, init);
211-
}
212-
auto resTensorTypeSigness = init.getType().cast<::mlir::RankedTensorType>();
213-
auto resTensorType = ::mlir::RankedTensorType::get(resTensorTypeSigness.getShape(), elemType, resTensorTypeSigness.getEncoding());
214-
#endif // 0
199+
// finally compile and run the module
200+
if(jit.run(module, fname)) throw std::runtime_error("failed running jit");
215201
}
202+
203+
} // namespace jit

0 commit comments

Comments
 (0)