Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 737de51

Browse files
committed
handling input args to jit-compiled function
1 parent 19b18db commit 737de51

File tree

5 files changed

+110
-22
lines changed

5 files changed

+110
-22
lines changed

src/Deferred.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,24 @@ void process_promises()
106106
if(runables.empty()) continue;
107107

108108
// create return statement and adjust function type
109-
uint64_t sz = dm.handleResult(builder);
109+
uint64_t osz = dm.handleResult(builder);
110110
// also request generation of c-wrapper function
111111
function->setAttr(::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), ::mlir::UnitAttr::get(&jit._context));
112112
// add the function to the module
113113
module.push_back(function);
114114
module.dump();
115115

116+
// get input buffers (before results!)
117+
auto input = std::move(dm.store_inputs());
118+
116119
// compile and run the module
117-
assert(sizeof(intptr_t) == sizeof(void*));
118-
intptr_t * output = new intptr_t[sz];
119-
if(jit.run(module, fname, output)) throw std::runtime_error("failed running jit");
120+
intptr_t * output = new intptr_t[osz];
121+
if(jit.run(module, fname, input, output)) throw std::runtime_error("failed running jit");
120122

121123
// push results to deliver promises
122-
dm.deliver(output, sz);
124+
dm.deliver(output, osz);
125+
126+
delete [] output;
123127
} while(!done);
124128
}
125129

src/include/ddptensor/DDPTensorImpl.hpp

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class DDPTensorImpl : public tensor_i
2929
PVSlice _slice;
3030
void * _allocated;
3131
void * _aligned;
32+
intptr_t * _sizes;
33+
intptr_t * _strides;
3234
uint64_t _offset;
3335
DTypeId _dtype;
3436

@@ -42,13 +44,17 @@ class DDPTensorImpl : public tensor_i
4244
: _owner(owner),
4345
_slice(shape_type(rank ? rank : 1, rank ? sizes[0] : 1), static_cast<int>(owner==REPLICATED ? NOSPLIT : 0)),
4446
_allocated(allocated),
45-
_aligned(nullptr),
47+
_aligned(aligned),
48+
_sizes(new intptr_t[rank]),
49+
_strides(new intptr_t[rank]),
4650
_offset(offset),
4751
_dtype(dtype)
4852
{
4953
assert(rank <= 1);
5054
assert(rank == 0 || strides[0] == 1);
51-
dispatch(_dtype, aligned, [this](auto * ptr) { this->_aligned = ptr + this->_offset; });
55+
56+
memcpy(_sizes, sizes, rank*sizeof(intptr_t));
57+
memcpy(_strides, strides, rank*sizeof(intptr_t));
5258
}
5359

5460
DDPTensorImpl(DTypeId dtype, const shape_type & shp, rank_type owner=NOOWNER)
@@ -60,23 +66,38 @@ class DDPTensorImpl : public tensor_i
6066
_dtype(dtype)
6167
{
6268
alloc();
69+
70+
intptr_t stride = 1;
71+
auto rank = shp.size();
72+
for(auto i=0; i<rank; ++i) {
73+
_sizes[i] = shp[i];
74+
_strides[rank-i-1] = stride;
75+
stride *= shp[i];
76+
}
6377
}
6478

6579
void alloc()
6680
{
6781
auto esz = sizeof_dtype(_dtype);
6882
_allocated = new (std::align_val_t(esz)) char[esz*_slice.size()];
6983
_aligned = _allocated;
84+
auto rank = _slice.ndims();
85+
_sizes = new intptr_t[rank];
86+
_strides = new intptr_t[rank];
7087
_offset = 0;
7188
}
7289

7390
~DDPTensorImpl()
7491
{
92+
delete [] _sizes;
93+
delete [] _strides;
7594
}
7695

7796
void * data()
7897
{
79-
return _aligned;
98+
void * ret;
99+
dispatch(_dtype, _aligned, [this, &ret](auto * ptr) { ret = ptr + this->_offset; });
100+
return ret;
80101
}
81102

82103
bool is_sliced() const
@@ -90,7 +111,8 @@ class DDPTensorImpl : public tensor_i
90111
const auto sz = _slice.size();
91112
std::ostringstream oss;
92113

93-
dispatch(_dtype, _aligned, [sz, &oss](auto * ptr) {
114+
dispatch(_dtype, _aligned, [this, sz, &oss](auto * ptr) {
115+
ptr += this->_offset;
94116
for(auto i=0; i<sz; ++i) {
95117
oss << ptr[i] << " ";
96118
}
@@ -127,7 +149,7 @@ class DDPTensorImpl : public tensor_i
127149
throw(std::runtime_error("Cast to scalar bool: tensor is not replicated"));
128150

129151
bool res;
130-
dispatch(_dtype, _aligned, [&res](auto * ptr) { res = static_cast<bool>(*ptr); });
152+
dispatch(_dtype, _aligned, [this, &res](auto * ptr) { res = static_cast<bool>(ptr[this->_offset]); });
131153
return res;
132154
}
133155

@@ -137,7 +159,7 @@ class DDPTensorImpl : public tensor_i
137159
throw(std::runtime_error("Cast to scalar float: tensor is not replicated"));
138160

139161
double res;
140-
dispatch(_dtype, _aligned, [&res](auto * ptr) { res = static_cast<double>(*ptr); });
162+
dispatch(_dtype, _aligned, [this, &res](auto * ptr) { res = static_cast<double>(ptr[this->_offset]); });
141163
return res;
142164
}
143165

@@ -147,7 +169,7 @@ class DDPTensorImpl : public tensor_i
147169
throw(std::runtime_error("Cast to scalar int: tensor is not replicated"));
148170

149171
float res;
150-
dispatch(_dtype, _aligned, [&res](auto * ptr) { res = static_cast<float>(*ptr); });
172+
dispatch(_dtype, _aligned, [this, &res](auto * ptr) { res = static_cast<float>(ptr[this->_offset]); });
151173
return res;
152174
}
153175

@@ -198,7 +220,18 @@ class DDPTensorImpl : public tensor_i
198220
auto sz = _slice.size()*item_size();
199221
buff.resize(pos + sz);
200222
void * out = buff.data() + pos;
201-
memcpy(out, _aligned, sz);
223+
dispatch(_dtype, _aligned, [this, sz, out](auto * ptr) { memcpy(out, ptr + this->_offset, sz); });
224+
}
225+
226+
virtual uint64_t store_memref(intptr_t * buff, int rank)
227+
{
228+
assert(rank == _slice.ndims() || (_slice.ndims() == 1 && _slice.size() == 1));
229+
buff[0] = reinterpret_cast<intptr_t>(_allocated);
230+
buff[1] = reinterpret_cast<intptr_t>(_aligned);
231+
buff[2] = static_cast<intptr_t>(_offset);
232+
memcpy(buff+3, _sizes, rank*sizeof(intptr_t));
233+
memcpy(buff+3+rank, _strides, rank*sizeof(intptr_t));
234+
return 3 + 2*rank;
202235
}
203236
};
204237

src/include/ddptensor/jit/mlir.hpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@ class DepManager
3333
private:
3434
using IdValueMap = std::unordered_map<id_type, std::pair<::mlir::Value, SetResFunc>>;
3535
using IdRankMap = std::unordered_map<id_type, int>;
36+
using ArgList = std::vector<std::pair<id_type, int>>;
37+
3638
::mlir::func::FuncOp & _func; // MLIR function to which ops are added
3739
IdValueMap _ivm; // guid -> {mlir::Value, deliver-callback}
3840
IdRankMap _irm; // guid -> rank as computed in MLIR
39-
std::vector<id_type> _args; // input args to generated function
41+
ArgList _args; // input arguments of the generated function
4042

4143
public:
4244
DepManager(::mlir::func::FuncOp & f)
@@ -55,11 +57,19 @@ class DepManager
5557
void drop(id_type guid);
5658

5759
/// create return statement and add results to function
60+
/// this must be called after store_inputs
5861
/// @return size of output in number of intptr_t's
5962
uint64_t handleResult(::mlir::OpBuilder & builder);
6063

6164
/// deliver promise after execution
6265
void deliver(intptr_t *, uint64_t);
66+
67+
/// @return total size of all input arguments in number of intptr_t
68+
uint64_t arg_size();
69+
70+
/// store all inputs into given buffer
71+
/// This must be called before handleResult()
72+
std::vector<void*> store_inputs();
6373
};
6474

6575
// A class to manage the MLIR business (compilation and execution).
@@ -77,7 +87,7 @@ class JIT {
7787

7888
JIT();
7989
// run
80-
int run(::mlir::ModuleOp &, const std::string &, void *);
90+
int run(::mlir::ModuleOp &, const std::string &, std::vector<void*> &, intptr_t *);
8191

8292
::mlir::MLIRContext _context;
8393
::mlir::PassManager _pm;

src/include/ddptensor/tensor_i.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ class tensor_i
6969
virtual void bufferize(const NDSlice & slice, Buffer & buff) const = 0;
7070
// size of a single element (in bytes)
7171
virtual int item_size() const = 0;
72+
// store tensor information in form of corresponding jit::JIT::MemRefDescriptor
73+
// @return stored size in number of intptr_t
74+
virtual uint64_t store_memref(intptr_t * buff, int rank) = 0;
7275
};
7376

7477
#if 0

src/jit/mlir.cpp

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,20 +108,48 @@ ::mlir::Value DepManager::getDependent(::mlir::OpBuilder & builder, id_type guid
108108
if(auto d = _ivm.find(guid); d == _ivm.end()) {
109109
// Not found -> this must be an input argument to the jit function
110110
auto idx = _args.size();
111-
auto fut = Registry::get(d->first);
111+
auto fut = Registry::get(guid);
112112
auto typ = getPTType(builder, fut.dtype(), fut.rank());
113113
_func.insertArgument(idx, typ, {}, loc);
114114
auto val = _func.getArgument(idx);
115-
_args.push_back(guid);
115+
_args.push_back({guid, fut.rank()});
116116
_ivm[guid] = {val, {}};
117117
return val;
118118
} else {
119119
return d->second.first;
120120
}
121121
}
122122

123+
// size of memreftype in number of intptr_t's
124+
static inline uint64_t memref_sz(int rank) { return 3 + 2 * rank; }
125+
126+
uint64_t DepManager::arg_size()
127+
{
128+
uint64_t sz = 0;
129+
for(auto a : _args) {
130+
sz += memref_sz(a.second);
131+
}
132+
return sz;
133+
}
134+
135+
std::vector<void*> DepManager::store_inputs()
136+
{
137+
std::vector<void*> res(_args.size());
138+
int i = 0;
139+
for(auto a : _args) {
140+
auto f = Registry::get(a.first);
141+
intptr_t * buff = new intptr_t[memref_sz(a.second)];
142+
auto sz = f.get().get()->store_memref(buff, a.second);
143+
res[i] = buff;
144+
_ivm.erase(a.first); // inputs need no delivery
145+
++i;
146+
}
147+
return res;
148+
}
149+
123150
void DepManager::addVal(id_type guid, ::mlir::Value val, SetResFunc cb)
124151
{
152+
assert(_ivm.find(guid) == _ivm.end());
125153
_ivm[guid] = {val, cb};
126154
}
127155

@@ -158,7 +186,7 @@ uint64_t DepManager::handleResult(::mlir::OpBuilder & builder)
158186
auto rank = ptt.getRtensor().getShape().size();
159187
_irm[v.first] = rank;
160188
// add sizeof(MemRefDescriptor<elementtype, rank>) to sz
161-
sz += 3 + 2 * rank;
189+
sz += memref_sz(rank);
162190
++idx;
163191
}
164192

@@ -179,16 +207,17 @@ void DepManager::deliver(intptr_t * output, uint64_t sz)
179207
intptr_t offset = output[pos+2];
180208
intptr_t * sizes = output + pos + 3;
181209
intptr_t * stride = output + pos + 3 + rank;
182-
pos += 3 + 2 * rank;
210+
pos += memref_sz(rank);
183211
v.second.second(rank, allocated, aligned, offset, sizes, stride);
184212
}
185213
}
186214

187-
int JIT::run(::mlir::ModuleOp & module, const std::string & fname, void * out)
215+
int JIT::run(::mlir::ModuleOp & module, const std::string & fname, std::vector<void*> & inp, intptr_t * out)
188216
{
189217
if (::mlir::failed(_pm.run(module)))
190218
throw std::runtime_error("failed to run pass manager");
191219

220+
module.dump();
192221
// An optimization pipeline to use within the execution engine.
193222
auto optPipeline = ::mlir::makeOptimizingTransformer(/*optLevel=*/0,
194223
/*sizeLevel=*/0,
@@ -202,12 +231,20 @@ int JIT::run(::mlir::ModuleOp & module, const std::string & fname, void * out)
202231
assert(maybeEngine && "failed to construct an execution engine");
203232
auto &engine = maybeEngine.get();
204233

205-
206234
const char * fn = getenv("DDPT_FN");
207235
if(!fn) fn = fname.c_str();
208236

237+
llvm::SmallVector<void *> args;
238+
// first arg must be the result ptr
239+
args.push_back(&out);
240+
// we need a void*& for every input tensor
241+
// we refer directly to the storage in inp
242+
for(auto & arg : inp) {
243+
args.push_back(&arg);
244+
}
245+
209246
// Invoke the JIT-compiled function.
210-
if(engine->invoke(fn, ::mlir::ExecutionEngine::result(out))) {
247+
if(engine->invokePacked(std::string("_mlir_ciface_") + fn, args)) {
211248
::llvm::errs() << "JIT invocation failed\n";
212249
throw std::runtime_error("JIT invocation failed");
213250
}
@@ -243,6 +280,7 @@ JIT::JIT()
243280

244281
void init()
245282
{
283+
assert(sizeof(intptr_t) == sizeof(void*));
246284
::mlir::registerAllPasses();
247285
::imex::registerAllPasses();
248286

0 commit comments

Comments
 (0)