Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit c22137c

Browse files
committed
initial jit for reduction
1 parent 035a369 commit c22137c

File tree

4 files changed

+60
-1
lines changed

4 files changed

+60
-1
lines changed

src/Creator.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ struct DeferredArange : public Deferred
153153

154154
::mlir::Value generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::IdValueMap & ivm) override
155155
{
156+
// FIXME the type of the result is hard-coded to uint64_t
156157
// create start, stop and step
157158
auto start = jit::createI64(loc, builder, _start);
158159
auto end = jit::createI64(loc, builder, _end);

src/EWBinOp.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ namespace x {
372372
};
373373
} // namespace x
374374

375+
// convert id of our binop to id of imex::ptensor binop
375376
static ::imex::ptensor::EWBinOpId ddpt2mlir(const EWBinOpId bop)
376377
{
377378
switch(bop) {
@@ -456,7 +457,7 @@ struct DeferredEWBinOp : public Deferred
456457

457458
::mlir::Value generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::IdValueMap & ivm) override
458459
{
459-
// FIXME compute the type of the result (inputs can be heterogeneous)
460+
// FIXME the type of the result is hard-coded to uint64_t
460461
auto rtyp = ivm[_a].first.getType();
461462
auto ewbo = builder.create<::imex::ptensor::EWBinOp>(loc, rtyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), ivm[_a].first, ivm[_b].first);
462463
auto setter = [this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, intptr_t * sizes, intptr_t * strides) {

src/ReduceOp.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
#include "ddptensor/x.hpp"
44
#include "ddptensor/Factory.hpp"
55

6+
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
7+
#include <mlir/IR/Builders.h>
8+
69
namespace x {
710

811
class ReduceOp
@@ -65,6 +68,33 @@ namespace x {
6568
};
6669
} // namespace x
6770

71+
72+
// convert id of our reduction op to id of imex::ptensor reduction op
73+
static ::imex::ptensor::ReduceOpId ddpt2mlir(const ReduceOpId rop)
74+
{
75+
#pragma GCC diagnostic ignored "-Wswitch"
76+
switch(rop) {
77+
case MEAN:
78+
return ::imex::ptensor::MEAN;
79+
case PROD:
80+
return ::imex::ptensor::PROD;
81+
case SUM:
82+
return ::imex::ptensor::SUM;
83+
case STD:
84+
return ::imex::ptensor::STD;
85+
case VAR:
86+
return ::imex::ptensor::VAR;
87+
case MAX:
88+
return ::imex::ptensor::MAX;
89+
case MIN:
90+
return ::imex::ptensor::MIN;
91+
default:
92+
throw std::runtime_error("Unknown reduction operation");
93+
}
94+
}
95+
96+
#pragma GCC diagnostic pop
97+
6898
struct DeferredReduceOp : public Deferred
6999
{
70100
id_type _a;
@@ -78,8 +108,32 @@ struct DeferredReduceOp : public Deferred
78108

79109
void run()
80110
{
111+
#if 0
81112
const auto a = std::move(Registry::get(_a).get());
82113
set_value(std::move(TypeDispatch<x::ReduceOp>(a, _op, _dim)));
114+
#endif
115+
}
116+
117+
::mlir::Value generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::IdValueMap & ivm) override
118+
{
119+
// FIXME the type of the result is hard-coded to uint64_t
120+
// FIXME reduction over individual dimensions is not supported
121+
auto a = ivm[_a].first;
122+
auto ptt = a.getType().dyn_cast<::imex::ptensor::PTensorType>();
123+
assert(ptt);
124+
125+
auto rtyp = ::imex::ptensor::PTensorType::get(
126+
builder.getContext(),
127+
::mlir::RankedTensorType::get(llvm::SmallVector<int64_t>(), ptt.getRtensor().getElementType()),
128+
true
129+
);
130+
auto rop = builder.create<::imex::ptensor::ReductionOp>(loc, rtyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), a);
131+
auto setter = [this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, intptr_t * sizes, intptr_t * strides) {
132+
// FIXME GC assert(allocated == aligned);
133+
this->set_value(std::move(x::operatorx<uint64_t>::mk_tx(rank, allocated, aligned, offset, sizes, strides)));
134+
};
135+
ivm[_guid] = {rop, setter};
136+
return rop;
83137
}
84138

85139
FactoryId factory() const

src/include/ddptensor/x.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ namespace x
268268
static typename DPTensorX<T>::typed_ptr_type mk_tx(uint64_t rank, void *allocated, void *aligned, intptr_t offset, intptr_t * sizes, intptr_t * strides)
269269
{
270270
// FIXME strides/slices are not used
271+
if(rank == 0) {
272+
return std::make_shared<DPTensorX<T>>(static_cast<T>(*reinterpret_cast<T*>(aligned)+offset));
273+
}
271274
shape_type shp(rank);
272275
for(int i = 0; i < rank; ++i) {
273276
shp[i] = sizes[i];

0 commit comments

Comments
 (0)