|
1 | 1 | // SPDX-License-Identifier: BSD-3-Clause |
2 | 2 |
|
3 | | -#include <oneapi/tbb/concurrent_queue.h> |
4 | 3 | #include "include/ddptensor/Deferred.hpp" |
5 | 4 | #include "include/ddptensor/Transceiver.hpp" |
6 | 5 | #include "include/ddptensor/Mediator.hpp" |
7 | 6 | #include "include/ddptensor/Registry.hpp" |
8 | 7 |
|
| 8 | +#include <oneapi/tbb/concurrent_queue.h> |
| 9 | +#include <mlir/Dialect/Func/IR/FuncOps.h> |
| 10 | +#include <imex/Dialect/PTensor/IR/PTensorOps.h> |
| 11 | +#include <mlir/Dialect/LLVMIR/LLVMDialect.h> |
| 12 | + |
9 | 13 | static tbb::concurrent_bounded_queue<Runable::ptr_type> _deferred; |
10 | 14 |
|
11 | 15 | void push_runable(Runable::ptr_type && r) |
@@ -55,13 +59,46 @@ void Runable::fini() |
55 | 59 |
|
56 | 60 | void process_promises() |
57 | 61 | { |
| 62 | + jit::JIT jit; |
| 63 | + ::mlir::OpBuilder builder(&jit._context); |
| 64 | + auto loc = builder.getUnknownLoc(); |
| 65 | + jit::IdValueMap ivp; |
| 66 | + ::mlir::Value ret_value; |
| 67 | + |
| 68 | + // Create a MLIR module |
| 69 | + auto module = builder.create<::mlir::ModuleOp>(loc); |
| 70 | + // Create a func |
| 71 | + auto dtype = builder.getI64Type(); |
| 72 | + llvm::SmallVector<int64_t> shape(1, -1); //::mlir::ShapedType::kDynamicSize); |
| 73 | + auto rrtype = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get(shape, dtype), true); // llvm::SmallVector<int64_t>() |
| 74 | + auto funcType = builder.getFunctionType({}, rrtype); |
| 75 | + std::string fname("ddpt_jit"); |
| 76 | + auto function = builder.create<::mlir::func::FuncOp>(loc, fname, funcType); |
| 77 | + // request generation of c-wrapper function |
| 78 | + function->setAttr(::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), ::mlir::UnitAttr::get(&jit._context)); |
| 79 | + // create function entry block |
| 80 | + auto &entryBlock = *function.addEntryBlock(); |
| 81 | + // Set the insertion point in the builder to the beginning of the function body |
| 82 | + builder.setInsertionPointToStart(&entryBlock); |
| 83 | + |
58 | 84 | while(true) { |
59 | 85 | Runable::ptr_type d; |
60 | 86 | _deferred.pop(d); |
61 | | - if(d) d->run(); |
62 | | - else break; |
63 | | - d.reset(); |
| 87 | + if(d) { |
| 88 | + d->run(); |
| 89 | + ret_value = d->generate_mlir(builder, loc, ivp); |
| 90 | + d.reset(); |
| 91 | + } else { |
| 92 | + break; |
| 93 | + } |
64 | 94 | } |
| 95 | + |
| 96 | + (void)builder.create<::mlir::func::ReturnOp>(loc, ret_value); |
| 97 | + // add the function to the module |
| 98 | + module.push_back(function); |
| 99 | + module.dump(); |
| 100 | + // finally compile and run the module |
| 101 | + if(jit.run(module, fname)) throw std::runtime_error("failed running jit"); |
65 | 102 | } |
66 | 103 |
|
67 | 104 | void sync() |
|
0 commit comments