@@ -329,57 +329,41 @@ void DepManager::deliver(std::vector<intptr_t> &outputV, uint64_t sz) {
329329std::vector<intptr_t > JIT::run (::mlir::ModuleOp &module ,
330330 const std::string &fname,
331331 std::vector<void *> &inp, size_t osz) {
332- ::mlir::ModuleOp cached;
332+
333+ ::mlir::ExecutionEngine *enginePtr;
334+ std::unique_ptr<::mlir::ExecutionEngine> tmpEngine;
335+
333336 if (_useCache) {
334- static std::vector<
335- std::pair<std::array<unsigned char , 20 >, ::mlir::ModuleOp>>
336- cache;
337- llvm::raw_sha1_ostream xxx;
338- module ->print (xxx);
339- auto cksm = xxx.sha1 ();
340- for (auto x : cache) {
341- if (x.first == cksm) {
342- cached = x.second ;
343- break ;
344- }
345- }
346- if (cached) {
347- module = cached;
348- if (_verbose)
349- std::cerr << " cached..." << std::endl;
337+ static std::map<std::array<unsigned char , 20 >,
338+ std::unique_ptr<::mlir::ExecutionEngine>>
339+ engineCache;
340+
341+ llvm::raw_sha1_ostream shaOS;
342+ module ->print (shaOS);
343+ auto cksm = shaOS.sha1 ();
344+
345+ if (auto search = engineCache.find (cksm); search == engineCache.end ()) {
346+ engineCache[cksm] = createExecutionEngine (module );
350347 } else {
351348 if (_verbose)
352- std::cerr << " compiling..." << std::endl;
353- cache.push_back (std::make_pair (cksm, module ));
354- if (_verbose > 1 )
355- module .dump ();
349+ std::cerr << " cached..." << std::endl;
356350 }
351+ enginePtr = engineCache[cksm].get ();
352+ } else {
353+ tmpEngine = createExecutionEngine (module );
354+ enginePtr = tmpEngine.get ();
357355 }
358356
359- // An optimization pipeline to use within the execution engine.
360- auto optPipeline =
361- ::mlir::makeOptimizingTransformer (/* optLevel=*/ 0 ,
362- /* sizeLevel=*/ 0 ,
363- /* targetMachine=*/ nullptr );
364-
365- // Create an ::mlir execution engine. The execution engine eagerly
366- // JIT-compiles the module.
367- ::mlir::ExecutionEngineOptions opts;
368- opts.transformer = optPipeline;
369- opts.sharedLibPaths = _sharedLibPaths;
370- opts.enableObjectDump = _useCache;
371-
372- // lower to LLVM
373- if (::mlir::failed (_pm.run (module )))
374- throw std::runtime_error (" failed to run pass manager" );
375-
376- if (_verbose > 2 && !cached)
377- module .dump ();
378-
379- auto maybeEngine = ::mlir::ExecutionEngine::create (module , opts);
380- assert (maybeEngine && " failed to construct an execution engine" );
381- auto &engine = maybeEngine.get ();
357+ auto expectedFPtr =
358+ enginePtr->lookupPacked (std::string (" _mlir_ciface_" ) + fname);
359+ if (auto err = expectedFPtr.takeError ()) {
360+ ::llvm::errs () << "JIT invocation failed: " << toString(std::move(err))
361+ << "\n";
362+ throw std::runtime_error (" JIT invocation failed" );
363+ }
364+ auto jittedFuncPtr = *expectedFPtr;
382365
366+ // pack function arguments
383367 llvm::SmallVector<void *> args;
384368 std::vector<intptr_t > out (osz);
385369 auto tmp = out.data ();
@@ -393,16 +377,39 @@ std::vector<intptr_t> JIT::run(::mlir::ModuleOp &module,
393377 args.push_back (&arg);
394378 }
395379
396- // Invoke the JIT-compiled function.
397- if (engine->invokePacked (std::string (" _mlir_ciface_" ) + fname.c_str (),
398- args)) {
399- ::llvm::errs () << "JIT invocation failed\n";
400- throw std::runtime_error (" JIT invocation failed" );
401- }
380+ // call function
381+ (*jittedFuncPtr)(args.data ());
402382
403383 return out;
404384}
405385
386+ std::unique_ptr<::mlir::ExecutionEngine>
387+ JIT::createExecutionEngine (::mlir::ModuleOp &module ) {
388+ if (_verbose)
389+ std::cerr << " compiling..." << std::endl;
390+ if (_verbose > 1 )
391+ module .dump ();
392+
393+ // Create an ::mlir execution engine. The execution engine eagerly
394+ // JIT-compiles the module.
395+ ::mlir::ExecutionEngineOptions opts;
396+ opts.transformer = _optPipeline;
397+ opts.jitCodeGenOptLevel = llvm::CodeGenOpt::getLevel (_jit_opt_level);
398+ opts.sharedLibPaths = _sharedLibPaths;
399+ opts.enableObjectDump = true ;
400+
401+ // lower to LLVM
402+ if (::mlir::failed (_pm.run (module )))
403+ throw std::runtime_error (" failed to run pass manager" );
404+
405+ if (_verbose > 2 )
406+ module .dump ();
407+
408+ auto maybeEngine = ::mlir::ExecutionEngine::create (module , opts);
409+ assert (maybeEngine && " failed to construct an execution engine" );
410+ return std::move (maybeEngine.get ());
411+ }
412+
406413static const char *pass_pipeline =
407414 getenv (" DDPT_PASSES" ) ? getenv(" DDPT_PASSES" )
408415 : " func.func(ptensor-dist),"
@@ -443,7 +450,7 @@ static const char *pass_pipeline =
443450 " reconcile-unrealized-casts" ;
444451JIT::JIT ()
445452 : _context(::mlir::MLIRContext::Threading::DISABLED), _pm(&_context),
446- _verbose (0 ) {
453+ _verbose (0 ), _jit_opt_level( 3 ) {
447454 // Register the translation from ::mlir to LLVM IR, which must happen before
448455 // we can JIT-compile.
449456 ::mlir::registerLLVMDialectTranslation (_context);
@@ -481,12 +488,39 @@ JIT::JIT()
481488 _useCache = c == " 1" || c == " y" || c == " Y" || c == " on" || c == " ON" ;
482489 std::cerr << " enableObjectDump=" << _useCache << std::endl;
483490 }
491+ const char *ol_ = getenv (" DDPT_OPT_LEVEL" );
492+ if (ol_) {
493+ _jit_opt_level = std::stoi (ol_);
494+ if (_jit_opt_level < 0 || _jit_opt_level > 3 ) {
495+ throw std::runtime_error (std::string (" Bad optimization level: " ) + ol_);
496+ }
497+ }
484498
485499 const char *crunner = getenv (" DDPT_CRUNNER_SO" );
486500 crunner = crunner ? crunner : " libmlir_c_runner_utils.so" ;
487501 const char *idtr = getenv (" DDPT_IDTR_SO" );
488502 idtr = idtr ? idtr : " libidtr.so" ;
489503 _sharedLibPaths = {idtr, crunner};
504+
505+ // detect target architecture
506+ auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost ();
507+ if (!tmBuilderOrError) {
508+ throw std::runtime_error (
509+ " Failed to create a JITTargetMachineBuilder for the host\n " );
510+ }
511+
512+ // build TargetMachine
513+ auto tmOrError = tmBuilderOrError->createTargetMachine ();
514+ if (!tmOrError) {
515+ throw std::runtime_error (" Failed to create a TargetMachine for the host\n " );
516+ }
517+ _tm = std::move (tmOrError.get ());
518+
519+ // build optimizing pipeline
520+ _optPipeline = ::mlir::makeOptimizingTransformer (
521+ /* optLevel=*/ _jit_opt_level,
522+ /* sizeLevel=*/ 0 ,
523+ /* targetMachine=*/ _tm.get ());
490524}
491525
492526// register dialects and passes
0 commit comments