@@ -13,6 +13,9 @@ using namespace xt::placeholders;
1313#include " ddptensor/CollComm.hpp"
1414#include " ddptensor/Chunker.hpp"
1515
16+ #include < imex/Dialect/PTensor/IR/PTensorOps.h>
17+ #include < mlir/IR/Builders.h>
18+
1619// #######################################################################################
1720// The 2 operators/tensors can have shifted partitions, e.g. local data might not be the
1821// same on a and b. This means we
@@ -366,10 +369,71 @@ namespace x {
366369 } else throw std::runtime_error (" Unknown/invalid elementwise binary operation" );
367370 }
368371 }
369- #pragma GCC diagnostic pop
370372 };
371373} // namespace x
372374
375+ static ::imex::ptensor::EWBinOpId ddpt2mlir (const EWBinOpId bop)
376+ {
377+ switch (bop) {
378+ case __ADD__:
379+ case ADD:
380+ case __RADD__:
381+ return ::imex::ptensor::ADD;
382+ case ATAN2:
383+ return ::imex::ptensor::ATAN2;
384+ case __FLOORDIV__:
385+ case FLOOR_DIVIDE:
386+ case __RFLOORDIV__:
387+ return ::imex::ptensor::FLOOR_DIVIDE;
388+ // __MATMUL__ is handled before dispatching, see below
389+ case __MUL__:
390+ case MULTIPLY:
391+ case __RMUL__:
392+ return ::imex::ptensor::MULTIPLY;
393+ case __SUB__:
394+ case SUBTRACT:
395+ case __RSUB__:
396+ return ::imex::ptensor::SUBTRACT;
397+ case __TRUEDIV__:
398+ case DIVIDE:
399+ case __RTRUEDIV__:
400+ return ::imex::ptensor::TRUE_DIVIDE;
401+ case __POW__:
402+ case POW:
403+ case __RPOW__:
404+ return ::imex::ptensor::POWER;
405+ case LOGADDEXP:
406+ return ::imex::ptensor::LOGADDEXP;
407+ case __LSHIFT__:
408+ case BITWISE_LEFT_SHIFT:
409+ case __RLSHIFT__:
410+ return ::imex::ptensor::BITWISE_LEFT_SHIFT;
411+ case __MOD__:
412+ case REMAINDER:
413+ case __RMOD__:
414+ return ::imex::ptensor::MODULO;
415+ case __RSHIFT__:
416+ case BITWISE_RIGHT_SHIFT:
417+ case __RRSHIFT__:
418+ return ::imex::ptensor::BITWISE_RIGHT_SHIFT;
419+ case __AND__:
420+ case BITWISE_AND:
421+ case __RAND__:
422+ return ::imex::ptensor::BITWISE_AND;
423+ case __OR__:
424+ case BITWISE_OR:
425+ case __ROR__:
426+ return ::imex::ptensor::BITWISE_OR;
427+ case __XOR__:
428+ case BITWISE_XOR:
429+ case __RXOR__:
430+ return ::imex::ptensor::BITWISE_XOR;
431+ default :
432+ throw std::runtime_error (" Unknown/invalid elementwise binary operation" );
433+ }
434+ }
435+ #pragma GCC diagnostic pop
436+
373437struct DeferredEWBinOp : public Deferred
374438{
375439 id_type _a;
@@ -381,11 +445,26 @@ struct DeferredEWBinOp : public Deferred
381445 : _a(a.id()), _b(b.id()), _op(op)
382446 {}
383447
384- void run ()
448+ void run () override
385449 {
450+ #if 0
386451 const auto a = std::move(Registry::get(_a).get());
387452 const auto b = std::move(Registry::get(_b).get());
388453 set_value(std::move(TypeDispatch<x::EWBinOp>(a, b, _op)));
454+ #endif
455+ }
456+
457+ ::mlir::Value generate_mlir (::mlir::OpBuilder & builder, ::mlir::Location loc, jit::IdValueMap & ivm) override
458+ {
459+ // FIXME compute the type of the result (inputs can be heterogeneous)
460+ auto rtyp = ivm[_a].first .getType ();
461+ auto ewbo = builder.create <::imex::ptensor::EWBinOp>(loc, rtyp, builder.getI32IntegerAttr (ddpt2mlir (_op)), ivm[_a].first , ivm[_b].first );
462+ auto setter = [this ](uint64_t rank, void *allocated, void *aligned, intptr_t offset, intptr_t * sizes, intptr_t * strides) {
463+ // FIXME GC assert(allocated == aligned);
464+ this ->set_value (std::move (x::operatorx<uint64_t >::mk_tx (rank, allocated, aligned, offset, sizes, strides)));
465+ };
466+ ivm[_guid] = {ewbo, setter};
467+ return ewbo;
389468 }
390469
391470 FactoryId factory () const
0 commit comments