11// SPDX-License-Identifier: BSD-3-Clause
22
3+ #include " ddptensor/jit/mlir.hpp"
4+
35#include " mlir/IR/MLIRContext.h"
46
57#include " mlir/Dialect/Arithmetic/IR/Arithmetic.h"
2931#include < imex/InitIMEXPasses.h>
3032
3133#include < cstdlib>
34+ #include < iostream>
3235
3336// #include "llvm/ADT/StringRef.h"
3437// #include "llvm/IR/Module.h"
3942#include " llvm/Support/TargetSelect.h"
4043// #include "llvm/Support/raw_ostream.h"
4144
45+ namespace jit {
4246
4347static ::mlir::Type makeSignlessType (::mlir::Type type)
4448{
@@ -59,37 +63,15 @@ auto createI64(const ::mlir::Location & loc, ::mlir::OpBuilder & builder, int64_
5963 return builder.create <::mlir::arith::ConstantOp>(loc, attr).getResult ();
6064}
6165
62- int processMLIR (::mlir::ModuleOp &module )
63- {
64- const char * pl = getenv (" DDPT_PASSES" );
65- // "convert-ptensor-to-linalg,dist-elim,convert-shape-to-std,arith-bufferize,func.func(linalg-init-tensor-to-alloc-tensor,scf-bufferize,shape-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,func.func(lower-affine),fold-memref-subview-ops,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,convert-dtensor-to-llvm,reconcile-unrealized-casts",
66- if (!pl) pl = " convert-ptensor-to-linalg,dist-elim,convert-shape-to-std,arith-bufferize,func.func(linalg-init-tensor-to-alloc-tensor,scf-bufferize,shape-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,func.func(lower-affine),fold-memref-subview-ops,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts" ;
67- ::mlir::PassManager pm (module .getContext ());
68- if (::mlir::failed (::mlir::parsePassPipeline (pl, pm))) return 3 ;
69-
70- pm.enableStatistics ();
71- pm.enableIRPrinting ();
72- pm.dump ();
73- if (::mlir::failed (pm.run (module ))) return 4 ;
74-
75- return 0 ;
76- }
77-
78- int runJit (::mlir::ModuleOp & module )
79- {
80- // Initialize LLVM targets.
81- ::llvm::InitializeNativeTarget ();
82- ::llvm::InitializeNativeTargetAsmPrinter ();
83- // ::llvm::initializeLLVMPasses();
84-
85- // Register the translation from ::mlir to LLVM IR, which must happen before we
86- // can JIT-compile.
87- ::mlir::registerLLVMDialectTranslation (*module ->getContext ());
66+ int JIT::run (::mlir::ModuleOp & module , const std::string & fname)
67+ {
68+ if (::mlir::failed (_pm.run (module )))
69+ throw std::runtime_error (" failed to run pass manager" );
8870
8971 // An optimization pipeline to use within the execution engine.
90- auto optPipeline = ::mlir::makeOptimizingTransformer(0 , // / *optLevel=*/enableOpt ? 3 : 0,
91- /* sizeLevel=*/ 0 ,
92- /* targetMachine=*/ nullptr );
72+ auto optPipeline = ::mlir::makeOptimizingTransformer (/ * optLevel=*/ 0 ,
73+ /* sizeLevel=*/ 0 ,
74+ /* targetMachine=*/ nullptr );
9375
9476 // Create an ::mlir execution engine. The execution engine eagerly JIT-compiles
9577 // the module.
@@ -99,38 +81,77 @@ int runJit(::mlir::ModuleOp & module)
9981 assert (maybeEngine && " failed to construct an execution engine" );
10082 auto &engine = maybeEngine.get ();
10183
84+
85+ const char * fn = getenv (" DDPT_FN" );
86+ if (!fn) fn = fname.c_str ();
87+
88+ MemRefDescriptor<int64_t , 1 > result;
89+ auto r_ptr = &result;
90+ // int64_t arg = 7;
10291 // Invoke the JIT-compiled function.
103- auto invocationResult = engine->invokePacked (" ttt_" ); // , {{}, {}});
104- if (invocationResult) {
92+ if (engine->invoke (fn, ::mlir::ExecutionEngine::result (r_ptr))) {
10593 ::llvm::errs () << "JIT invocation failed\n";
106- return - 1 ;
94+ throw std::runtime_error ( " JIT invocation failed " ) ;
10795 }
96+ std::cout << " aptr=" << result.allocated << " dptr=" << result.aligned << " offset=" << result.offset << std::endl;
97+ std::cout << ((int64_t *)result.aligned )[result.offset ] << std::endl;
10898
10999 return 0 ;
110100}
111101
112- void ttt ()
// Pass pipeline used to lower ptensor IR down to the LLVM dialect.
// Can be overridden at runtime through the DDPT_PASSES environment
// variable; evaluated once at static-initialization time.
static const char *pass_pipeline = []() -> const char * {
    // single getenv lookup instead of calling it twice
    const char *env = getenv("DDPT_PASSES");
    return env != nullptr
               ? env
               : "convert-ptensor-to-linalg,dist-elim,convert-shape-to-std,arith-bufferize,func.func(linalg-init-tensor-to-alloc-tensor,scf-bufferize,shape-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,func.func(lower-affine),fold-memref-subview-ops,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts";
}();
106+
// Construct a JIT instance: a single-threaded MLIR context plus the
// pass manager used to lower modules before execution.
// Throws std::runtime_error if the configured pass pipeline string
// (pass_pipeline, possibly from DDPT_PASSES) cannot be parsed.
JIT::JIT()
    : _context(::mlir::MLIRContext::Threading::DISABLED),
      _pm(&_context)
{
    // Register the translation from ::mlir to LLVM IR, which must happen
    // before we can JIT-compile.
    ::mlir::registerLLVMDialectTranslation(_context);
    // load the dialects we use
    _context.getOrLoadDialect<::mlir::arith::ArithmeticDialect>();
    _context.getOrLoadDialect<::mlir::func::FuncDialect>();
    _context.getOrLoadDialect<::imex::ptensor::PTensorDialect>();
    _context.getOrLoadDialect<::imex::dist::DistDialect>();
    // create the pass pipeline from string
    if (::mlir::failed(::mlir::parsePassPipeline(pass_pipeline, _pm)))
        throw std::runtime_error("failed to parse pass pipeline");
    // some verbosity: print statistics/IR while passes run and dump the
    // assembled pipeline once at construction
    _pm.enableStatistics();
    _pm.enableIRPrinting();
    _pm.dump();
}
115127
// One-time process-global initialization: registers all MLIR and IMEX
// passes and initializes the native LLVM target/asm printer needed for
// JIT compilation.
// NOTE(review): presumably must run before the first JIT is constructed
// and the first module is compiled — TODO confirm against callers.
void init()
{
    ::mlir::registerAllPasses();
    ::imex::registerAllPasses();

    // ::mlir::DialectRegistry registry;
    // ::mlir::registerAllDialects(registry);
    // ::imex::registerAllDialects(registry);

    // Initialize LLVM targets.
    ::llvm::InitializeNativeTarget();
    ::llvm::InitializeNativeTargetAsmPrinter();
    // ::llvm::initializeLLVMPasses();
}
124142
125- context.getOrLoadDialect <::mlir::arith::ArithmeticDialect>();
126- // context.getOrLoadDialect<::mlir::tensor::TensorDialect>();
127- // context.getOrLoadDialect<::mlir::linalg::LinalgDialect>();
128- context.getOrLoadDialect <::mlir::func::FuncDialect>();
129- // context.getOrLoadDialect<::mlir::shape::ShapeDialect>();
130- context.getOrLoadDialect <::imex::ptensor::PTensorDialect>();
131- context.getOrLoadDialect <::imex::dist::DistDialect>();
143+ // mock function for POC testing
144+ // delayed execution will do something like the below:
145+ // * create module
146+ // * create a function and define its types (input and return types)
147+ // * create the function body and return op
148+ // * add function to module
149+ // * compile & run the module
150+ void ttt ()
151+ {
152+ JIT jit;
132153
133- ::mlir::OpBuilder builder (&context );
154+ ::mlir::OpBuilder builder (&jit. _context );
134155 auto loc = builder.getUnknownLoc ();
135156 auto module = builder.create <::mlir::ModuleOp>(loc);
136157
@@ -140,7 +161,11 @@ void ttt()
140161 auto artype = ::imex::ptensor::PTensorType::get (builder.getContext (), ::mlir::RankedTensorType::get (shape, dtype), true );
141162 auto rrtype = ::imex::ptensor::PTensorType::get (builder.getContext (), ::mlir::RankedTensorType::get (llvm::SmallVector<int64_t >(), dtype), true );
142163 auto funcType = builder.getFunctionType ({}, rrtype);
164+
165+ std::string fname (" tttt" );
143166 auto function = builder.create <::mlir::func::FuncOp>(loc, fname, funcType);
167+ // request generation of c-wrapper function
168+ function->setAttr (::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName (), ::mlir::UnitAttr::get (&jit._context ));
144169
145170 // Create an ::mlir function for the given prototype.
146171 // ::mlir::func::FuncOp function(fproto);
@@ -162,54 +187,17 @@ void ttt()
162187 auto c1 = createI64 (loc, builder, 1 );
163188 auto c100 = createI64 (loc, builder, 100 );
164189
190+ // return np.sum(np.arange(1,10,1)+np.arange(1,100,10)) -> 495
165191 auto rangea = builder.create <::imex::ptensor::ARangeOp>(loc, artype, c0, c10, c1, true );
166192 auto rangeb = builder.create <::imex::ptensor::ARangeOp>(loc, artype, c0, c100, c10, true );
167193 auto added = builder.create <::imex::ptensor::EWBinOp>(loc, artype, builder.getI32IntegerAttr (::imex::ptensor::ADD), rangea, rangeb);
168194 auto reduced = builder.create <::imex::ptensor::ReductionOp>(loc, rrtype, builder.getI32IntegerAttr (::imex::ptensor::SUM), added);
169195 auto ret = builder.create <::mlir::func::ReturnOp>(loc, reduced.getResult ());
170-
196+ // add the function to the module
171197 module .push_back (function);
172- module .dump ();
173198
174- if (processMLIR (module )) throw std::runtime_error (" failed to process mlir" );
175- module .dump ();
176-
177- if (runJit (module )) throw std::runtime_error (" failed to run jit" );
178-
179- #if 0
180- std::vector<int> shape = {16, 16};
181- auto elemType = builder.getF64Type();
182- auto signlessElemType = makeSignlessType(elemType);
183- auto indexType = builder.getIndexType();
184- auto count = shape.size();
185- ::llvm::SmallVector<::mlir::Value> shapeVal(count);
186- ::llvm::SmallVector<int64_t> staticShape(count); // ::mlir::ShapedType::kDynamicSize);
187-
188- for(auto it : ::llvm::enumerate(shape)) {
189- auto i = it.index();
190- auto elem = it.value();
191- auto elemVal = getInt(loc, builder, elem);
192- staticShape[i] = elem;
193- shapeVal[i] = elemVal;
194- }
195-
196- ::mlir::Value init;
197- if(true) { //initVal.is_none()) {
198- init = builder.create<::mlir::linalg::InitTensorOp>(loc, shapeVal, signlessElemType);
199- }// else {
200- // auto val = doCast(builder, loc, ctx.context.unwrapVal(loc, builder, initVal), signlessElemType);
201- // ::llvm::SmallVector<int64_t> shape(count, ::mlir::ShapedType::kDynamicSize);
202- // auto type = ::mlir::RankedTensorType::get(shape, signlessElemType);
203- // auto body = [&](::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange /*indices*/) {
204- // builder.create<::mlir::tensor::YieldOp>(loc, val);
205- // };
206- // init = builder.create<::mlir::tensor::GenerateOp>(loc, type, shapeVal, body);
207- // }
208- if (::llvm::any_of(staticShape, [](auto val) { return val >= 0; })) {
209- auto newType = ::mlir::RankedTensorType::get(staticShape, signlessElemType);
210- init = builder.create<::mlir::tensor::CastOp>(loc, newType, init);
211- }
212- auto resTensorTypeSigness = init.getType().cast<::mlir::RankedTensorType>();
213- auto resTensorType = ::mlir::RankedTensorType::get(resTensorTypeSigness.getShape(), elemType, resTensorTypeSigness.getEncoding());
214- #endif // 0
199+ // finally compile and run the module
200+ if (jit.run (module , fname)) throw std::runtime_error (" failed running jit" );
215201}
202+
203+ } // namespace jit
0 commit comments