33#include " ddptensor/x.hpp"
44#include " ddptensor/Factory.hpp"
55
6+ #include < imex/Dialect/PTensor/IR/PTensorOps.h>
7+ #include < mlir/IR/Builders.h>
8+
69namespace x {
710
811 class ReduceOp
@@ -65,6 +68,33 @@ namespace x {
6568 };
6669} // namespace x
6770
71+
72+ // convert id of our reduction op to id of imex::ptensor reduction op
73+ static ::imex::ptensor::ReduceOpId ddpt2mlir (const ReduceOpId rop)
74+ {
75+ #pragma GCC diagnostic ignored "-Wswitch"
76+ switch (rop) {
77+ case MEAN:
78+ return ::imex::ptensor::MEAN;
79+ case PROD:
80+ return ::imex::ptensor::PROD;
81+ case SUM:
82+ return ::imex::ptensor::SUM;
83+ case STD:
84+ return ::imex::ptensor::STD;
85+ case VAR:
86+ return ::imex::ptensor::VAR;
87+ case MAX:
88+ return ::imex::ptensor::MAX;
89+ case MIN:
90+ return ::imex::ptensor::MIN;
91+ default :
92+ throw std::runtime_error (" Unknown reduction operation" );
93+ }
94+ }
95+
96+ #pragma GCC diagnostic pop
97+
6898struct DeferredReduceOp : public Deferred
6999{
70100 id_type _a;
@@ -78,8 +108,32 @@ struct DeferredReduceOp : public Deferred
78108
79109 void run ()
80110 {
111+ #if 0
81112 const auto a = std::move(Registry::get(_a).get());
82113 set_value(std::move(TypeDispatch<x::ReduceOp>(a, _op, _dim)));
114+ #endif
115+ }
116+
117+ ::mlir::Value generate_mlir (::mlir::OpBuilder & builder, ::mlir::Location loc, jit::IdValueMap & ivm) override
118+ {
119+ // FIXME the type of the result is hard-coded to uint64_t
120+ // FIXME reduction over individual dimensions is not supported
121+ auto a = ivm[_a].first ;
122+ auto ptt = a.getType ().dyn_cast <::imex::ptensor::PTensorType>();
123+ assert (ptt);
124+
125+ auto rtyp = ::imex::ptensor::PTensorType::get (
126+ builder.getContext (),
127+ ::mlir::RankedTensorType::get (llvm::SmallVector<int64_t >(), ptt.getRtensor().getElementType()),
128+ true
129+ );
130+ auto rop = builder.create <::imex::ptensor::ReductionOp>(loc, rtyp, builder.getI32IntegerAttr (ddpt2mlir (_op)), a);
131+ auto setter = [this ](uint64_t rank, void *allocated, void *aligned, intptr_t offset, intptr_t * sizes, intptr_t * strides) {
132+ // FIXME GC assert(allocated == aligned);
133+ this ->set_value (std::move (x::operatorx<uint64_t >::mk_tx (rank, allocated, aligned, offset, sizes, strides)));
134+ };
135+ ivm[_guid] = {rop, setter};
136+ return rop;
83137 }
84138
85139 FactoryId factory () const
0 commit comments