Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit abe5cb2

Browse files
committed
arange working through deferred and jit (access possibly only after fini())
1 parent 3f96772 commit abe5cb2

File tree

7 files changed

+110
-42
lines changed

7 files changed

+110
-42
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ set(MyCppSources ${MyCppSources} ${PROJECT_SOURCE_DIR}/src/jit/mlir.cpp ${P2C_HP
8888
pybind11_add_module(_ddptensor MODULE ${MyCppSources})
8989

9090
target_compile_definitions(_ddptensor PRIVATE XTENSOR_USE_XSIMD=1 XTENSOR_USE_TBB=1 USE_MKL=1 DDPT_2TYPES=1)
91+
target_compile_options(_ddptensor PRIVATE "-ftemplate-backtrace-limit=0")
9192
target_include_directories(_ddptensor PRIVATE
9293
${PROJECT_SOURCE_DIR}/src/include
9394
${PROJECT_SOURCE_DIR}/third_party/xtl/include

src/Creator.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@ namespace x {
5555
PVSlice pvslice({static_cast<uint64_t>(Slice(start, end, step).size())});
5656
auto lslc = pvslice.local_slice();
5757
const auto & l1dslc = lslc.dim(0);
58+
5859
auto a = xt::arange<T>(start + l1dslc._start*step, start + l1dslc._end * step, l1dslc._step);
5960
auto r = operatorx<T>::mk_tx(std::move(pvslice), std::move(a));
61+
6062
return r;
6163
}
6264
}; // class creatorx
@@ -160,7 +162,14 @@ struct DeferredArange : public Deferred
160162
llvm::SmallVector<int64_t> shape(1, -1); //::mlir::ShapedType::kDynamicSize);
161163
auto artype = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get(shape, dtype), true);
162164
auto ar = builder.create<::imex::ptensor::ARangeOp>(loc, artype, start, end, step, true);
163-
ivm[_guid] = ar;
165+
auto setter = [this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, intptr_t * sizes, intptr_t * strides) {
166+
// FIXME GC assert(allocated == aligned);
167+
assert(rank == 1);
168+
assert(strides[0] == 1);
169+
shape_type shape(1, sizes[0]);
170+
this->set_value(std::move(x::operatorx<uint64_t>::mk_tx(shape, reinterpret_cast<uint64_t*>(aligned)+offset)));
171+
};
172+
ivm[_guid] = {ar, setter};
164173
return ar;
165174
}
166175

src/Deferred.cpp

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
1111
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
1212

13+
#include <iostream>
14+
1315
static tbb::concurrent_bounded_queue<Runable::ptr_type> _deferred;
1416

1517
void push_runable(Runable::ptr_type && r)
@@ -63,42 +65,95 @@ void process_promises()
6365
::mlir::OpBuilder builder(&jit._context);
6466
auto loc = builder.getUnknownLoc();
6567
jit::IdValueMap ivp;
66-
::mlir::Value ret_value;
6768

6869
// Create a MLIR module
6970
auto module = builder.create<::mlir::ModuleOp>(loc);
7071
// Create a func
7172
auto dtype = builder.getI64Type();
72-
llvm::SmallVector<int64_t> shape(1, -1); //::mlir::ShapedType::kDynamicSize);
73-
auto rrtype = ::imex::ptensor::PTensorType::get(builder.getContext(), ::mlir::RankedTensorType::get(shape, dtype), true); // llvm::SmallVector<int64_t>()
74-
auto funcType = builder.getFunctionType({}, rrtype);
73+
// create dummy type, we'll replace it with the actual type later
74+
auto dummyFuncType = builder.getFunctionType({}, dtype);
7575
std::string fname("ddpt_jit");
76-
auto function = builder.create<::mlir::func::FuncOp>(loc, fname, funcType);
77-
// request generation of c-wrapper function
78-
function->setAttr(::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), ::mlir::UnitAttr::get(&jit._context));
76+
auto function = builder.create<::mlir::func::FuncOp>(loc, fname, dummyFuncType);
7977
// create function entry block
8078
auto &entryBlock = *function.addEntryBlock();
8179
// Set the insertion point in the builder to the beginning of the function body
8280
builder.setInsertionPointToStart(&entryBlock);
83-
81+
// we need to keep runables/deferred/futures alive until we set their values below
82+
std::vector<Runable::ptr_type> runables;
83+
8484
while(true) {
8585
Runable::ptr_type d;
8686
_deferred.pop(d);
8787
if(d) {
88-
d->run();
89-
ret_value = d->generate_mlir(builder, loc, ivp);
90-
d.reset();
88+
// d->run();
89+
(void) d->generate_mlir(builder, loc, ivp);
90+
// keep alive for later set_value
91+
runables.push_back(std::move(d));
92+
//d.reset();
9193
} else {
9294
break;
9395
}
9496
}
9597

96-
(void)builder.create<::mlir::func::ReturnOp>(loc, ret_value);
98+
// Now we have to define the return type as a ValueRange of all arrays which we have created
99+
// (runnables have put them into ivp)
100+
// We also compute the total size of the struct llvm created for this return type
101+
// llvm will basically return a struct with all the arrays as members, each of type JIT::MemRefDescriptor
102+
103+
// Need a container to put all return values, will be used to construct TypeRange
104+
std::vector<::mlir::Type> ret_types;
105+
ret_types.reserve(ivp.size());
106+
std::vector<::mlir::Value> ret_values;
107+
ret_types.reserve(ivp.size());
108+
std::unordered_map<id_type, uint64_t> rank_map;
109+
// here we store the total size of the llvm struct
110+
uint64_t sz = 0;
111+
for(auto & v : ivp) {
112+
auto value = v.second.first;
113+
// append the type and array/value
114+
ret_types.push_back(value.getType());
115+
ret_values.push_back(value);
116+
auto ptt = value.getType().dyn_cast<::imex::ptensor::PTensorType>();
117+
assert(ptt);
118+
auto rank = ptt.getRtensor().getShape().size();
119+
rank_map[v.first] = rank;
120+
// add sizeof(MemRefDescriptor<elementtype, rank>) to sz
121+
sz += 3 + 2 * rank;
122+
}
123+
::mlir::TypeRange ret_tr(ret_types);
124+
::mlir::ValueRange ret_vr(ret_values);
125+
126+
// add return statement
127+
auto ret_value = builder.create<::mlir::func::ReturnOp>(loc, ret_vr);
128+
// Define and assign correct function type
129+
auto funcTypeAttr = ::mlir::TypeAttr::get(builder.getFunctionType({}, ret_tr));
130+
function.setFunctionTypeAttr(funcTypeAttr);
131+
// also request generation of c-wrapper function
132+
function->setAttr(::mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), ::mlir::UnitAttr::get(&jit._context));
97133
// add the function to the module
98134
module.push_back(function);
99135
module.dump();
100136
// finally compile and run the module
101-
if(jit.run(module, fname)) throw std::runtime_error("failed running jit");
137+
assert(sizeof(intptr_t) == sizeof(void*));
138+
intptr_t * output = new intptr_t[sz];
139+
std::cout << ivp.size() << " sz: " << sz << std::endl;
140+
if(jit.run(module, fname, output)) throw std::runtime_error("failed running jit");
141+
142+
size_t pos = 0;
143+
for(auto & v : ivp) {
144+
auto value = v.second.first;
145+
auto rank = rank_map[v.first];
146+
void * allocated = (void*)output[pos];
147+
void * aligned = (void*)output[pos+1];
148+
intptr_t offset = output[pos+2];
149+
intptr_t * sizes = output + pos + 3;
150+
intptr_t * stride = output + pos + 3 + rank;
151+
pos += 3 + 2 * rank;
152+
v.second.second(rank, allocated, aligned, offset, sizes, stride);
153+
}
154+
155+
// finally release all our runables/tasks/deferred/futures
156+
runables.clear();
102157
}
103158

104159
void sync()

src/include/ddptensor/Deferred.hpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ struct Runable
1616
/// actually execute, a deferred will set value of future
1717
virtual void run() = 0;
1818
/// generate MLIR code for jit
19-
virtual ::mlir::Value generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location, jit::IdValueMap & ivm) = 0;
19+
virtual ::mlir::Value generate_mlir(::mlir::OpBuilder &, ::mlir::Location, jit::IdValueMap &)
20+
{
21+
throw(std::runtime_error("No MLIR support for this operation."));
22+
return {};
23+
};
2024
virtual FactoryId factory() const = 0;
2125
virtual void defer(ptr_type &&);
2226
static void fini();
@@ -33,12 +37,6 @@ struct DeferredT : public P, public Runable
3337

3438
DeferredT() = default;
3539
DeferredT(const DeferredT<P, F> &) = delete;
36-
37-
// FIXME: from Runable but should be in most derived classes
38-
::mlir::Value generate_mlir(::mlir::OpBuilder &, ::mlir::Location, jit::IdValueMap &) override
39-
{
40-
return {};
41-
};
4240
};
4341

4442
struct Deferred : public DeferredT<tensor_i::promise_type, tensor_i::future_type>
@@ -109,13 +107,6 @@ struct DeferredLambda : public Runable
109107
_l();
110108
}
111109

112-
// FIXME: from Runable but should be in most derived classes
113-
::mlir::Value generate_mlir(::mlir::OpBuilder &, ::mlir::Location, jit::IdValueMap &) override
114-
{
115-
throw(std::runtime_error("No MLIR support for DeferredLambda."));
116-
return {};
117-
};
118-
119110
FactoryId factory() const
120111
{
121112
throw(std::runtime_error("No Factory for DeferredLambda."));

src/include/ddptensor/jit/mlir.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010
#include <mlir/IR/Builders.h>
1111

1212
#include <unordered_map>
13+
#include <functional>
14+
#include <utility>
1315

1416
namespace jit {
1517

16-
using IdValueMap = std::unordered_map<id_type, ::mlir::Value>;
18+
// callback type for delivering a JIT result: receives the unpacked fields of a
// returned memref descriptor (rank, allocated/aligned pointers, offset, sizes, strides)
19+
using SetResFunc = std::function<void(
20+
uint64_t rank, void *allocated, void *aligned, intptr_t offset, intptr_t * sizes, intptr_t * strides)>;
21+
using IdValueMap = std::unordered_map<id_type, std::pair<::mlir::Value, SetResFunc>>;
1722

1823
// initialize jit
1924
void init();
@@ -38,7 +43,7 @@ class JIT {
3843

3944
JIT();
4045
// run
41-
int run(::mlir::ModuleOp &, const std::string &);
46+
int run(::mlir::ModuleOp &, const std::string &, void *);
4247

4348
::mlir::MLIRContext _context;
4449
::mlir::PassManager _pm;

src/include/ddptensor/x.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,20 @@ namespace x
7474
}
7575

7676
template<typename I>
77-
DPTensorX(shape_type && slc, I && ax, rank_type owner=NOOWNER)
77+
DPTensorX(shape_type && shp, I && ax, rank_type owner=NOOWNER)
7878
: _owner(owner),
79-
_slice(std::move(slc), static_cast<int>(owner==REPLICATED ? NOSPLIT : 0)),
79+
_slice(std::move(shp), static_cast<int>(owner==REPLICATED ? NOSPLIT : 0)),
8080
_xarray(std::make_shared<xt::xarray<T>>(std::forward<I>(ax)))
8181
{
8282
}
8383

84+
DPTensorX(const shape_type & shp, T * ptr, rank_type owner=NOOWNER)
85+
: _owner(owner),
86+
_slice(shp, static_cast<int>(owner==REPLICATED ? NOSPLIT : 0)),
87+
_xarray(std::make_shared<xt::xarray<T>>(xt::adapt(ptr, VPROD(shp), xt::no_ownership(), shp)))
88+
{
89+
}
90+
8491
DPTensorX(const shape_type & shp, rank_type owner=NOOWNER)
8592
: _owner(owner),
8693
_slice(shp, static_cast<int>(owner==REPLICATED ? NOSPLIT : 0)),

src/jit/mlir.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ ::mlir::Value createI64(const ::mlir::Location & loc, ::mlir::OpBuilder & builde
6363
return builder.create<::mlir::arith::ConstantOp>(loc, attr).getResult();
6464
}
6565

66-
int JIT::run(::mlir::ModuleOp & module, const std::string & fname)
66+
int JIT::run(::mlir::ModuleOp & module, const std::string & fname, void * out)
6767
{
6868
if (::mlir::failed(_pm.run(module)))
6969
throw std::runtime_error("failed to run pass manager");
@@ -85,16 +85,11 @@ int JIT::run(::mlir::ModuleOp & module, const std::string & fname)
8585
const char * fn = getenv("DDPT_FN");
8686
if(!fn) fn = fname.c_str();
8787

88-
MemRefDescriptor<int64_t, 1> result;
89-
auto r_ptr = &result;
90-
// int64_t arg = 7;
9188
// Invoke the JIT-compiled function.
92-
if(engine->invoke(fn, ::mlir::ExecutionEngine::result(r_ptr))) {
89+
if(engine->invoke(fn, ::mlir::ExecutionEngine::result(out))) {
9390
::llvm::errs() << "JIT invocation failed\n";
9491
throw std::runtime_error("JIT invocation failed");
9592
}
96-
std::cout << "aptr=" << result.allocated << " dptr=" << result.aligned << " offset=" << result.offset << std::endl;
97-
std::cout << ((int64_t*)result.aligned)[result.offset] << std::endl;
9893

9994
return 0;
10095
}
@@ -120,8 +115,8 @@ JIT::JIT()
120115
if(::mlir::failed(::mlir::parsePassPipeline(pass_pipeline, _pm)))
121116
throw std::runtime_error("failed to parse pass pipeline");
122117
// some verbosity
123-
_pm.enableStatistics();
124-
_pm.enableIRPrinting();
118+
// _pm.enableStatistics();
119+
// _pm.enableIRPrinting();
125120
_pm.dump();
126121
}
127122

@@ -196,8 +191,13 @@ void ttt()
196191
// add the function to the module
197192
module.push_back(function);
198193

194+
JIT::MemRefDescriptor<int64_t, 1> result;
195+
void * r_ptr = &result;
199196
// finally compile and run the module
200-
if(jit.run(module, fname)) throw std::runtime_error("failed running jit");
197+
if(jit.run(module, fname, r_ptr)) throw std::runtime_error("failed running jit");
198+
199+
std::cout << "aptr=" << result.allocated << " dptr=" << result.aligned << " offset=" << result.offset << std::endl;
200+
std::cout << ((int64_t*)result.aligned)[result.offset] << std::endl;
201201
}
202202

203203
} // namespace jit

0 commit comments

Comments
 (0)