Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 035a369

Browse files
committed
initial jit for ewbinops
1 parent abe5cb2 commit 035a369

File tree

3 files changed

+93
-5
lines changed

3 files changed

+93
-5
lines changed

src/Creator.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ struct DeferredArange : public Deferred
148148

149149
void run() override
150150
{
151-
set_value(std::move(TypeDispatch<x::Creator>(_dtype, _start, _end, _step)));
151+
// set_value(std::move(TypeDispatch<x::Creator>(_dtype, _start, _end, _step)));
152152
};
153153

154154
::mlir::Value generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::IdValueMap & ivm) override
@@ -166,8 +166,7 @@ struct DeferredArange : public Deferred
166166
// FIXME GC assert(allocated == aligned);
167167
assert(rank == 1);
168168
assert(strides[0] == 1);
169-
shape_type shape(1, sizes[0]);
170-
this->set_value(std::move(x::operatorx<uint64_t>::mk_tx(shape, reinterpret_cast<uint64_t*>(aligned)+offset)));
169+
this->set_value(std::move(x::operatorx<uint64_t>::mk_tx(rank, allocated, aligned, offset, sizes, strides)));
171170
};
172171
ivm[_guid] = {ar, setter};
173172
return ar;

src/EWBinOp.cpp

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ using namespace xt::placeholders;
1313
#include "ddptensor/CollComm.hpp"
1414
#include "ddptensor/Chunker.hpp"
1515

16+
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
17+
#include <mlir/IR/Builders.h>
18+
1619
// #######################################################################################
1720
// The 2 operators/tensors can have shifted partitions, e.g. local data might not be the
1821
// same on a and b. This means we
@@ -366,10 +369,71 @@ namespace x {
366369
} else throw std::runtime_error("Unknown/invalid elementwise binary operation");
367370
}
368371
}
369-
#pragma GCC diagnostic pop
370372
};
371373
} // namespace x
372374

375+
static ::imex::ptensor::EWBinOpId ddpt2mlir(const EWBinOpId bop)
376+
{
377+
switch(bop) {
378+
case __ADD__:
379+
case ADD:
380+
case __RADD__:
381+
return ::imex::ptensor::ADD;
382+
case ATAN2:
383+
return ::imex::ptensor::ATAN2;
384+
case __FLOORDIV__:
385+
case FLOOR_DIVIDE:
386+
case __RFLOORDIV__:
387+
return ::imex::ptensor::FLOOR_DIVIDE;
388+
// __MATMUL__ is handled before dispatching, see below
389+
case __MUL__:
390+
case MULTIPLY:
391+
case __RMUL__:
392+
return ::imex::ptensor::MULTIPLY;
393+
case __SUB__:
394+
case SUBTRACT:
395+
case __RSUB__:
396+
return ::imex::ptensor::SUBTRACT;
397+
case __TRUEDIV__:
398+
case DIVIDE:
399+
case __RTRUEDIV__:
400+
return ::imex::ptensor::TRUE_DIVIDE;
401+
case __POW__:
402+
case POW:
403+
case __RPOW__:
404+
return ::imex::ptensor::POWER;
405+
case LOGADDEXP:
406+
return ::imex::ptensor::LOGADDEXP;
407+
case __LSHIFT__:
408+
case BITWISE_LEFT_SHIFT:
409+
case __RLSHIFT__:
410+
return ::imex::ptensor::BITWISE_LEFT_SHIFT;
411+
case __MOD__:
412+
case REMAINDER:
413+
case __RMOD__:
414+
return ::imex::ptensor::MODULO;
415+
case __RSHIFT__:
416+
case BITWISE_RIGHT_SHIFT:
417+
case __RRSHIFT__:
418+
return ::imex::ptensor::BITWISE_RIGHT_SHIFT;
419+
case __AND__:
420+
case BITWISE_AND:
421+
case __RAND__:
422+
return ::imex::ptensor::BITWISE_AND;
423+
case __OR__:
424+
case BITWISE_OR:
425+
case __ROR__:
426+
return ::imex::ptensor::BITWISE_OR;
427+
case __XOR__:
428+
case BITWISE_XOR:
429+
case __RXOR__:
430+
return ::imex::ptensor::BITWISE_XOR;
431+
default:
432+
throw std::runtime_error("Unknown/invalid elementwise binary operation");
433+
}
434+
}
435+
#pragma GCC diagnostic pop
436+
373437
struct DeferredEWBinOp : public Deferred
374438
{
375439
id_type _a;
@@ -381,11 +445,26 @@ struct DeferredEWBinOp : public Deferred
381445
: _a(a.id()), _b(b.id()), _op(op)
382446
{}
383447

384-
void run()
448+
void run() override
385449
{
450+
#if 0
386451
const auto a = std::move(Registry::get(_a).get());
387452
const auto b = std::move(Registry::get(_b).get());
388453
set_value(std::move(TypeDispatch<x::EWBinOp>(a, b, _op)));
454+
#endif
455+
}
456+
457+
::mlir::Value generate_mlir(::mlir::OpBuilder & builder, ::mlir::Location loc, jit::IdValueMap & ivm) override
458+
{
459+
// FIXME compute the type of the result (inputs can be heterogeneous)
460+
auto rtyp = ivm[_a].first.getType();
461+
auto ewbo = builder.create<::imex::ptensor::EWBinOp>(loc, rtyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), ivm[_a].first, ivm[_b].first);
462+
auto setter = [this](uint64_t rank, void *allocated, void *aligned, intptr_t offset, intptr_t * sizes, intptr_t * strides) {
463+
// FIXME GC assert(allocated == aligned);
464+
this->set_value(std::move(x::operatorx<uint64_t>::mk_tx(rank, allocated, aligned, offset, sizes, strides)));
465+
};
466+
ivm[_guid] = {ewbo, setter};
467+
return ewbo;
389468
}
390469

391470
FactoryId factory() const

src/include/ddptensor/x.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,16 @@ namespace x
265265
return std::make_shared<DPTensorX<T>>(std::forward<Ts>(args)...);
266266
}
267267

268+
static typename DPTensorX<T>::typed_ptr_type mk_tx(uint64_t rank, void *allocated, void *aligned, intptr_t offset, intptr_t * sizes, intptr_t * strides)
269+
{
270+
// FIXME strides/slices are not used
271+
shape_type shp(rank);
272+
for(int i = 0; i < rank; ++i) {
273+
shp[i] = sizes[i];
274+
}
275+
return std::make_shared<DPTensorX<T>>(shp, reinterpret_cast<T*>(aligned) + offset);
276+
}
277+
268278
template<typename X>
269279
static DPTensorBaseX::ptr_type mk_tx_(const DPTensorX<T> & tx, X && x)
270280
{

0 commit comments

Comments
 (0)