88#include " ddptensor/Creator.hpp"
99#include " ddptensor/DDPTensorImpl.hpp"
1010#include " ddptensor/Factory.hpp"
11+ #include " ddptensor/Registry.hpp"
1112#include " ddptensor/TypeDispatch.hpp"
1213
13- #if 0
14- namespace x {
14+ #include < imex/Dialect/Dist/IR/DistOps.h>
15+ #include < imex/Dialect/PTensor/IR/PTensorOps.h>
16+ #include < mlir/Dialect/Shape/IR/Shape.h>
17+ #include < mlir/IR/Builders.h>
18+ #include < mlir/IR/BuiltinTypeInterfaces.h>
1519
16- class IEWBinOp
17- {
18- public:
19- using ptr_type = DPTensorBaseX::ptr_type;
20-
21- template<typename A, typename B>
22- static ptr_type op(IEWBinOpId iop, std::shared_ptr<DPTensorX<A>> a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr)
23- {
24- auto & ax = a_ptr->xarray();
25- const auto & bx = b_ptr->xarray();
26- if(a_ptr->is_sliced() || b_ptr->is_sliced()) {
27- auto av = xt::strided_view(ax, a_ptr->lslice());
28- const auto & bv = xt::strided_view(bx, b_ptr->lslice());
29- return do_op(iop, av, bv, a_ptr);
30- }
31- return do_op(iop, ax, bx, a_ptr);
32- }
33-
34- #pragma GCC diagnostic ignored "-Wswitch"
35- template<typename A, typename T1, typename T2>
36- static ptr_type do_op(IEWBinOpId iop, T1 & a, const T2 & b, std::shared_ptr<DPTensorX<A>> a_ptr)
37- {
38- switch(iop) {
39- case __IADD__:
40- a += b;
41- return a_ptr;
42- case __IFLOORDIV__:
43- a = xt::floor(a / b);
44- return a_ptr;
45- case __IMUL__:
46- a *= b;
47- return a_ptr;
48- case __ISUB__:
49- a -= b;
50- return a_ptr;
51- case __ITRUEDIV__:
52- a /= b;
53- return a_ptr;
54- case __IPOW__:
55- throw std::runtime_error("Binary inplace operation not implemented");
56- }
57- if constexpr (std::is_integral<typename T1::value_type>::value && std::is_integral<typename T2::value_type>::value) {
58- switch(iop) {
59- case __IMOD__:
60- a %= b;
61- return a_ptr;
62- case __IOR__:
63- a |= b;
64- return a_ptr;
65- case __IAND__:
66- a &= b;
67- return a_ptr;
68- case __IXOR__:
69- a ^= b;
70- case __ILSHIFT__:
71- a = xt::left_shift(a, b);
72- return a_ptr;
73- case __IRSHIFT__:
74- a = xt::right_shift(a, b);
75- return a_ptr;
76- }
77- }
78- throw std::runtime_error("Unknown/invalid inplace elementwise binary operation");
79- }
80- #pragma GCC diagnostic pop
81-
82- };
83- } // namespace x
84- #endif // if 0
20+ // convert id of our binop to id of imex::ptensor binop
21+ static ::imex::ptensor::EWBinOpId ddpt2mlir (const IEWBinOpId bop) {
22+ switch (bop) {
23+ case __IADD__:
24+ return ::imex::ptensor::ADD;
25+ case __IAND__:
26+ return ::imex::ptensor::BITWISE_AND;
27+ case __IFLOORDIV__:
28+ return ::imex::ptensor::FLOOR_DIVIDE;
29+ case __ILSHIFT__:
30+ return ::imex::ptensor::BITWISE_LEFT_SHIFT;
31+ case __IMOD__:
32+ return ::imex::ptensor::MODULO;
33+ case __IMUL__:
34+ return ::imex::ptensor::MULTIPLY;
35+ case __IOR__:
36+ return ::imex::ptensor::BITWISE_OR;
37+ case __IPOW__:
38+ return ::imex::ptensor::POWER;
39+ case __IRSHIFT__:
40+ return ::imex::ptensor::BITWISE_RIGHT_SHIFT;
41+ case __ISUB__:
42+ return ::imex::ptensor::SUBTRACT;
43+ case __ITRUEDIV__:
44+ return ::imex::ptensor::TRUE_DIVIDE;
45+ case __IXOR__:
46+ return ::imex::ptensor::BITWISE_XOR;
47+ default :
48+ throw std::runtime_error (
49+ " Unknown/invalid inplace elementwise binary operation" );
50+ }
51+ }
8552
8653struct DeferredIEWBinOp : public Deferred {
8754 id_type _a;
@@ -91,15 +58,45 @@ struct DeferredIEWBinOp : public Deferred {
9158 DeferredIEWBinOp () = default ;
9259 DeferredIEWBinOp (IEWBinOpId op, const tensor_i::future_type &a,
9360 const tensor_i::future_type &b)
94- : _a(a.id()), _b(b.id()), _op(op) {}
61+ : Deferred(a.dtype(), a.rank(), a.balanced()), _a(a.id()), _b(b.id()),
62+ _op (op) {}
63+
64+ bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
65+ jit::DepManager &dm) override {
66+ // FIXME the type of the result is based on a only
67+ auto av = dm.getDependent (builder, _a);
68+ auto bv = dm.getDependent (builder, _b);
9569
96- void run () {
97- // const auto a = std::move(Registry::get(_a).get());
98- // const auto b = std::move(Registry::get(_b).get());
99- // set_value(std::move(TypeDispatch<x::IEWBinOp>(a, b, _op)));
70+ auto aTyp = ::imex::dist::getPTensorType (av);
71+ ::mlir::SmallVector<int64_t > shape (rank (), ::mlir::ShapedType::kDynamic );
72+ auto outTyp =
73+ ::imex::ptensor::PTensorType::get (shape, aTyp.getElementType());
74+
75+ auto binop = builder.create <::imex::ptensor::EWBinOp>(
76+ loc, outTyp, builder.getI32IntegerAttr (ddpt2mlir (_op)), av, bv);
77+ // insertsliceop has no return value, so we just create the op...
78+ auto zero = ::imex::createIndex (loc, builder, 0 );
79+ auto one = ::imex::createIndex (loc, builder, 1 );
80+ auto dyn = ::imex::createIndex (loc, builder, ::mlir::ShapedType::kDynamic );
81+ ::mlir::SmallVector<::mlir::Value> offs (rank (), zero);
82+ ::mlir::SmallVector<::mlir::Value> szs (rank (), dyn);
83+ ::mlir::SmallVector<::mlir::Value> strds (rank (), one);
84+ (void )builder.create <::imex::ptensor::InsertSliceOp>(loc, av, binop, offs,
85+ szs, strds);
86+ // ... and use av as to later create the ptensor
87+ dm.addVal (this ->guid (), av,
88+ [this ](Transceiver *transceiver, uint64_t rank, void *allocated,
89+ void *aligned, intptr_t offset, const intptr_t *sizes,
90+ const intptr_t *strides, uint64_t *gs_allocated,
91+ uint64_t *gs_aligned, uint64_t *lo_allocated,
92+ uint64_t *lo_aligned, uint64_t balanced) {
93+ this ->set_value (Registry::get (this ->_a ).get ());
94+ });
95+ return false ;
10096 }
10197
10298 FactoryId factory () const { return F_IEWBINOP; }
99+
103100 template <typename S> void serialize (S &ser) {
104101 ser.template value <sizeof (_a)>(_a);
105102 ser.template value <sizeof (_b)>(_b);
0 commit comments