Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 1edc643

Browse files
committed
adding caching of compiled code; adding Tosa bits
1 parent fe1976a commit 1edc643

File tree

6 files changed

+84
-39
lines changed

6 files changed

+84
-39
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ target_link_libraries(_ddpt_rt PRIVATE
185185
MLIRLinalgTransforms
186186
MLIRLLVMDialect
187187
MLIRMathDialect
188+
MLIRMathToFuncs
189+
MLIRMathToLibm
188190
MLIRMathToLLVM
189191
MLIRMathTransforms
190192
MLIRMemRefDialect
@@ -197,6 +199,8 @@ target_link_libraries(_ddpt_rt PRIVATE
197199
MLIRShapeDialect
198200
MLIRShapeOpsTransforms
199201
MLIRShapeToStandard
202+
MLIRTosaDialect
203+
MLIRTosaToLinalg
200204
MLIRTensorTransforms
201205
)
202206
# LLVM${LLVM_NATIVE_ARCH}CodeGen

src/DDPTensorImpl.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,6 @@ void DDPTensorImpl::add_to_args(std::vector<void *> &args, int ndims) {
162162
buff[2] = static_cast<intptr_t>(_offset);
163163
memcpy(buff + 3, _sizes, ndims * sizeof(intptr_t));
164164
memcpy(buff + 3 + ndims, _strides, ndims * sizeof(intptr_t));
165-
for (auto i = 0; i < 3 + 2 * ndims; ++i)
166-
std::cerr << " " << buff[i];
167165
args.push_back(buff);
168166
// second the transceiver
169167
args.push_back(&_transceiver);

src/Deferred.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include <oneapi/tbb/concurrent_queue.h>
2121

2222
#include <iostream>
23-
#include <unordered_set>
2423

2524
// thread-safe FIFO queue holding deferred objects
2625
static tbb::concurrent_bounded_queue<Runable::ptr_type> _deferred;
@@ -71,7 +70,7 @@ void Runable::defer(Runable::ptr_type &&p) { push_runable(std::move(p)); }
7170
void Runable::fini() { _deferred.clear(); }
7271

7372
// process promises as they arrive through calls to defer
74-
// This is run in a separate thread until shutdon is requested.
73+
// This is run in a separate thread until shutdown is requested.
7574
// Shutdown is indicated by a Deferred object which evaluates to false.
7675
// The loop repeatedly creates MLIR functions for jit-compilation by letting
7776
// Deferred objects add their MLIR code until an object can not produce MLIR
@@ -138,14 +137,12 @@ void process_promises() {
138137

139138
if (osz > 0 || !input.empty()) {
140139
// compile and run the module
141-
intptr_t *output = new intptr_t[osz];
142-
if (jit.run(module, fname, input, output))
140+
auto output = jit.run(module, fname, input, osz);
141+
if (output.size() != osz)
143142
throw std::runtime_error("failed running jit");
144143

145144
// push results to deliver promises
146145
dm.deliver(output, osz);
147-
148-
delete[] output;
149146
} else {
150147
std::cerr << "\tskipping\n";
151148
}

src/include/ddptensor/jit/mlir.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class DepManager {
113113
uint64_t handleResult(::mlir::OpBuilder &builder);
114114

115115
/// deliver promise after execution
116-
void deliver(intptr_t *, uint64_t);
116+
void deliver(std::vector<intptr_t> &, uint64_t);
117117

118118
/// @return total size of all input arguments in number of intptr_t
119119
uint64_t arg_size();
@@ -137,12 +137,13 @@ class JIT {
137137

138138
JIT();
139139
// run
140-
int run(::mlir::ModuleOp &, const std::string &, std::vector<void *> &,
141-
intptr_t *);
140+
std::vector<intptr_t> run(::mlir::ModuleOp &, const std::string &,
141+
std::vector<void *> &, size_t);
142142

143143
::mlir::MLIRContext _context;
144144
::mlir::PassManager _pm;
145-
bool _verbose;
145+
bool _verbose, _useCache;
146+
const char *_sharedLibPaths;
146147
};
147148

148149
// size of memreftype in number of intptr_t's

src/jit/mlir.cpp

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,20 @@
6565
// #include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
6666
// #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
6767
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
68-
// #include "mlir/Dialect/Tosa/Transforms/Passes.h"
68+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
6969
// #include "mlir/Dialect/Transform/Transforms/Passes.h"
7070
// #include "mlir/Dialect/Vector/Transforms/Passes.h"
7171
#include "mlir/Transforms/Passes.h"
7272
// #include <mlir/InitAllPasses.h>
7373

7474
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
7575
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
76-
7776
#include "mlir/ExecutionEngine/ExecutionEngine.h"
7877
#include "mlir/ExecutionEngine/OptUtils.h"
7978
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
8079

80+
#include <llvm/Support/raw_sha1_ostream.h>
81+
8182
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
8283
#include <imex/InitIMEXDialects.h>
8384
#include <imex/InitIMEXPasses.h>
@@ -178,7 +179,6 @@ std::vector<void *> DepManager::store_inputs() {
178179
std::vector<void *> res;
179180
for (auto a : _args) {
180181
auto f = Registry::get(a.first);
181-
std::cerr << " store guid " << a.first;
182182
f.get().get()->add_to_args(res, a.second);
183183
_ivm.erase(a.first); // inputs need no delivery
184184
_icm.erase(a.first);
@@ -254,7 +254,8 @@ uint64_t DepManager::handleResult(::mlir::OpBuilder &builder) {
254254
return 2 * sz;
255255
}
256256

257-
void DepManager::deliver(intptr_t *output, uint64_t sz) {
257+
void DepManager::deliver(std::vector<intptr_t> &outputV, uint64_t sz) {
258+
auto output = outputV.data();
258259
size_t pos = 0;
259260
for (auto &v : _icm) {
260261
auto rank = _irm[v.first];
@@ -305,14 +306,30 @@ void DepManager::deliver(intptr_t *output, uint64_t sz) {
305306
}
306307
}
307308

308-
int JIT::run(::mlir::ModuleOp &module, const std::string &fname,
309-
std::vector<void *> &inp, intptr_t *out) {
310-
// lower to LLVM
311-
if (::mlir::failed(_pm.run(module)))
312-
throw std::runtime_error("failed to run pass manager");
313-
314-
if (_verbose)
315-
module.dump();
309+
std::vector<intptr_t> JIT::run(::mlir::ModuleOp &module,
310+
const std::string &fname,
311+
std::vector<void *> &inp, size_t osz) {
312+
if (_useCache) {
313+
::mlir::ModuleOp cached;
314+
static std::vector<
315+
std::pair<std::array<unsigned char, 20>, ::mlir::ModuleOp>>
316+
cache;
317+
llvm::raw_sha1_ostream xxx;
318+
module->print(xxx);
319+
auto cksm = xxx.sha1();
320+
for (auto x : cache) {
321+
if (x.first == cksm) {
322+
cached = x.second;
323+
break;
324+
}
325+
}
326+
if (cached) {
327+
module = cached;
328+
std::cerr << "using cached module" << std::endl;
329+
} else {
330+
cache.push_back(std::make_pair(cksm, module));
331+
}
332+
}
316333

317334
// An optimization pipeline to use within the execution engine.
318335
auto optPipeline =
@@ -322,21 +339,27 @@ int JIT::run(::mlir::ModuleOp &module, const std::string &fname,
322339

323340
// Create an ::mlir execution engine. The execution engine eagerly
324341
// JIT-compiles the module.
325-
::mlir::ExecutionEngineOptions engineOptions;
326-
engineOptions.transformer = optPipeline;
327-
// const char * crunner = getenv("DDPT_CRUNNER_SO");
328-
// crunner = crunner ? crunner : "libmlir_c_runner_utils.so";
329-
const char *idtr = getenv("DDPT_IDTR_SO");
330-
idtr = idtr ? idtr : "libidtr.so";
331-
// ::llvm::ArrayRef<::llvm::StringRef> shlibs = {crunner, idtr};
332-
engineOptions.sharedLibPaths = {idtr};
333-
auto maybeEngine = ::mlir::ExecutionEngine::create(module, engineOptions);
342+
::mlir::ExecutionEngineOptions opts;
343+
opts.transformer = optPipeline;
344+
opts.sharedLibPaths = {_sharedLibPaths};
345+
opts.enableObjectDump = _useCache;
346+
347+
// lower to LLVM
348+
if (::mlir::failed(_pm.run(module)))
349+
throw std::runtime_error("failed to run pass manager");
350+
351+
if (_verbose)
352+
module.dump();
353+
354+
auto maybeEngine = ::mlir::ExecutionEngine::create(module, opts);
334355
assert(maybeEngine && "failed to construct an execution engine");
335356
auto &engine = maybeEngine.get();
336357

337358
llvm::SmallVector<void *> args;
359+
std::vector<intptr_t> out(osz);
360+
auto tmp = out.data();
338361
// first arg must be the result ptr
339-
args.push_back(&out);
362+
args.push_back(&tmp);
340363
// we need a void*& for every input tensor
341364
// we refer directly to the storage in inp
342365
for (auto &arg : inp) {
@@ -350,7 +373,7 @@ int JIT::run(::mlir::ModuleOp &module, const std::string &fname,
350373
throw std::runtime_error("JIT invocation failed");
351374
}
352375

353-
return 0;
376+
return out;
354377
}
355378

356379
static const char *pass_pipeline =
@@ -362,11 +385,13 @@ static const char *pass_pipeline =
362385
// "builtin.module(func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize,bufferization-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,fold-memref-alias-ops,expand-strided-metadata,lower-affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)";
363386
: "func.func(ptensor-dist,dist-coalesce),convert-dist-to-standard,"
364387
"convert-ptensor-to-linalg,canonicalize,convert-shape-to-std,arith-"
365-
"expand,canonicalize,arith-bufferize,func-bufferize,func.func(empty-"
366-
"tensor-to-alloc-tensor,scf-bufferize,tensor-bufferize,linalg-"
388+
"expand,canonicalize,arith-bufferize,func-bufferize,func.func(tosa-"
389+
"to-linalg,"
390+
"empty-tensor-to-alloc-tensor,scf-bufferize,tensor-bufferize,linalg-"
367391
"bufferize,bufferization-bufferize,linalg-detensorize,tensor-"
368392
"bufferize,finalizing-bufferize,convert-linalg-to-parallel-loops),"
369-
"canonicalize,fold-memref-alias-ops,expand-strided-metadata,lower-"
393+
"canonicalize,fold-memref-alias-ops,expand-strided-metadata,convert-"
394+
"math-to-funcs,convert-math-to-libm,lower-"
370395
"affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-"
371396
"llvm,reconcile-unrealized-casts";
372397
JIT::JIT()
@@ -391,12 +416,27 @@ JIT::JIT()
391416
if (v == "1" || v == "y" || v == "Y" || v == "on" || v == "ON")
392417
_verbose = true;
393418
}
419+
_pm.enableTiming();
394420
// some verbosity
395421
if (_verbose) {
396422
_pm.enableStatistics();
397423
_pm.enableIRPrinting();
398424
_pm.dump();
399425
}
426+
427+
const char *envptr = getenv("DDPT_USE_CACHE");
428+
envptr = envptr ? envptr : "1";
429+
{
430+
auto c = std::string(envptr);
431+
_useCache = c == "1" || c == "y" || c == "Y" || c == "on" || c == "ON";
432+
std::cerr << "enableObjectDump=" << _useCache << std::endl;
433+
}
434+
435+
// const char * crunner = getenv("DDPT_CRUNNER_SO");
436+
// crunner = crunner ? crunner : "libmlir_c_runner_utils.so";
437+
envptr = getenv("DDPT_IDTR_SO");
438+
_sharedLibPaths = envptr ? envptr : "libidtr.so";
439+
// ::llvm::ArrayRef<::llvm::StringRef> shlibs = {crunner, envptr};
400440
}
401441

402442
// register dialects and passes
@@ -411,6 +451,10 @@ void init() {
411451
::mlir::registerConvertShapeToStandardPass();
412452
::mlir::tensor::registerTensorPasses();
413453
::mlir::registerLinalgPasses();
454+
::mlir::registerTosaToLinalg();
455+
::mlir::registerConvertMathToFuncs();
456+
::mlir::registerConvertMathToLibm();
457+
::mlir::tosa::registerTosaOptPasses();
414458
::mlir::func::registerFuncPasses();
415459
::mlir::registerConvertFuncToLLVMPass();
416460
::mlir::bufferization::registerBufferizationPasses();

test/stencil-2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def main():
119119

120120
# there is certainly a more Pythonic way to initialize W,
121121
# but it will have no impact on performance.
122+
t0 = timer()
122123
W = np.zeros(((2 * r + 1), (2 * r + 1)), dtype=np.float64)
123124
A = np.empty((n, n), dtype=np.float64)
124125
B = np.zeros((n, n), dtype=np.float64)
@@ -149,8 +150,8 @@ def main():
149150

150151
for k in range(iterations + 1):
151152
# start timer after a warmup iteration
153+
np.sync()
152154
if k <= 1:
153-
np.sync()
154155
t0 = timer()
155156

156157
if pattern == "star":

0 commit comments

Comments (0)