Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 24aac84

Browse files
committed
adding (but not using) inplace binop support
1 parent 015829d commit 24aac84

File tree

6 files changed

+95
-118
lines changed

6 files changed

+95
-118
lines changed

ddptensor/ddptensor.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,12 @@ def __repr__(self):
2424
f"{method} = lambda self, other: dtensor(_cdt.EWBinOp.op(_cdt.{METHOD}, self._t, other._t if isinstance(other, dtensor) else other))"
2525
)
2626

27-
def _inplace(self, t):
28-
self._t = t
29-
return self
30-
31-
for method in api.api_categories["IEWBinOp"]:
32-
METHOD = method.upper()
33-
exec(
34-
f"{method} = lambda self, other: self._inplace(_cdt.IEWBinOp.op(_cdt.{METHOD}, self._t, other._t if isinstance(other, dtensor) else other))"
35-
)
27+
# inplace operators still lead to an assignment, needs more involved analysis
28+
# for method in api.api_categories["IEWBinOp"]:
29+
# METHOD = method.upper()
30+
# exec(
31+
# f"{method} = lambda self, other: (self, _cdt.IEWBinOp.op(_cdt.{METHOD}, self._t, other._t if isinstance(other, dtensor) else other)[0])"
32+
# )
3633

3734
for method in api.api_categories["EWUnyOp"]:
3835
if method.startswith("__"):

src/EWBinOp.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
*/
66

77
#include "ddptensor/EWBinOp.hpp"
8-
#include "ddptensor/CollComm.hpp"
98
#include "ddptensor/Creator.hpp"
109
#include "ddptensor/DDPTensorImpl.hpp"
1110
#include "ddptensor/Factory.hpp"
@@ -449,14 +448,6 @@ struct DeferredEWBinOp : public Deferred {
449448
: Deferred(a.dtype(), std::max(a.rank(), b.rank()), true), _a(a.id()),
450449
_b(b.id()), _op(op) {}
451450

452-
void run() override {
453-
#if 0
454-
const auto a = std::move(Registry::get(_a).get());
455-
const auto b = std::move(Registry::get(_b).get());
456-
set_value(std::move(TypeDispatch<x::EWBinOp>(a, b, _op)));
457-
#endif
458-
}
459-
460451
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
461452
jit::DepManager &dm) override {
462453
// FIXME the type of the result is based on a only

src/IEWBinOp.cpp

Lines changed: 73 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -8,80 +8,47 @@
88
#include "ddptensor/Creator.hpp"
99
#include "ddptensor/DDPTensorImpl.hpp"
1010
#include "ddptensor/Factory.hpp"
11+
#include "ddptensor/Registry.hpp"
1112
#include "ddptensor/TypeDispatch.hpp"
1213

13-
#if 0
14-
namespace x {
14+
#include <imex/Dialect/Dist/IR/DistOps.h>
15+
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
16+
#include <mlir/Dialect/Shape/IR/Shape.h>
17+
#include <mlir/IR/Builders.h>
18+
#include <mlir/IR/BuiltinTypeInterfaces.h>
1519

16-
class IEWBinOp
17-
{
18-
public:
19-
using ptr_type = DPTensorBaseX::ptr_type;
20-
21-
template<typename A, typename B>
22-
static ptr_type op(IEWBinOpId iop, std::shared_ptr<DPTensorX<A>> a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr)
23-
{
24-
auto & ax = a_ptr->xarray();
25-
const auto & bx = b_ptr->xarray();
26-
if(a_ptr->is_sliced() || b_ptr->is_sliced()) {
27-
auto av = xt::strided_view(ax, a_ptr->lslice());
28-
const auto & bv = xt::strided_view(bx, b_ptr->lslice());
29-
return do_op(iop, av, bv, a_ptr);
30-
}
31-
return do_op(iop, ax, bx, a_ptr);
32-
}
33-
34-
#pragma GCC diagnostic ignored "-Wswitch"
35-
template<typename A, typename T1, typename T2>
36-
static ptr_type do_op(IEWBinOpId iop, T1 & a, const T2 & b, std::shared_ptr<DPTensorX<A>> a_ptr)
37-
{
38-
switch(iop) {
39-
case __IADD__:
40-
a += b;
41-
return a_ptr;
42-
case __IFLOORDIV__:
43-
a = xt::floor(a / b);
44-
return a_ptr;
45-
case __IMUL__:
46-
a *= b;
47-
return a_ptr;
48-
case __ISUB__:
49-
a -= b;
50-
return a_ptr;
51-
case __ITRUEDIV__:
52-
a /= b;
53-
return a_ptr;
54-
case __IPOW__:
55-
throw std::runtime_error("Binary inplace operation not implemented");
56-
}
57-
if constexpr (std::is_integral<typename T1::value_type>::value && std::is_integral<typename T2::value_type>::value) {
58-
switch(iop) {
59-
case __IMOD__:
60-
a %= b;
61-
return a_ptr;
62-
case __IOR__:
63-
a |= b;
64-
return a_ptr;
65-
case __IAND__:
66-
a &= b;
67-
return a_ptr;
68-
case __IXOR__:
69-
a ^= b;
70-
case __ILSHIFT__:
71-
a = xt::left_shift(a, b);
72-
return a_ptr;
73-
case __IRSHIFT__:
74-
a = xt::right_shift(a, b);
75-
return a_ptr;
76-
}
77-
}
78-
throw std::runtime_error("Unknown/invalid inplace elementwise binary operation");
79-
}
80-
#pragma GCC diagnostic pop
81-
82-
};
83-
} // namespace x
84-
#endif // if 0
20+
// convert id of our binop to id of imex::ptensor binop
21+
static ::imex::ptensor::EWBinOpId ddpt2mlir(const IEWBinOpId bop) {
22+
switch (bop) {
23+
case __IADD__:
24+
return ::imex::ptensor::ADD;
25+
case __IAND__:
26+
return ::imex::ptensor::BITWISE_AND;
27+
case __IFLOORDIV__:
28+
return ::imex::ptensor::FLOOR_DIVIDE;
29+
case __ILSHIFT__:
30+
return ::imex::ptensor::BITWISE_LEFT_SHIFT;
31+
case __IMOD__:
32+
return ::imex::ptensor::MODULO;
33+
case __IMUL__:
34+
return ::imex::ptensor::MULTIPLY;
35+
case __IOR__:
36+
return ::imex::ptensor::BITWISE_OR;
37+
case __IPOW__:
38+
return ::imex::ptensor::POWER;
39+
case __IRSHIFT__:
40+
return ::imex::ptensor::BITWISE_RIGHT_SHIFT;
41+
case __ISUB__:
42+
return ::imex::ptensor::SUBTRACT;
43+
case __ITRUEDIV__:
44+
return ::imex::ptensor::TRUE_DIVIDE;
45+
case __IXOR__:
46+
return ::imex::ptensor::BITWISE_XOR;
47+
default:
48+
throw std::runtime_error(
49+
"Unknown/invalid inplace elementwise binary operation");
50+
}
51+
}
8552

8653
struct DeferredIEWBinOp : public Deferred {
8754
id_type _a;
@@ -91,15 +58,45 @@ struct DeferredIEWBinOp : public Deferred {
9158
DeferredIEWBinOp() = default;
9259
DeferredIEWBinOp(IEWBinOpId op, const tensor_i::future_type &a,
9360
const tensor_i::future_type &b)
94-
: _a(a.id()), _b(b.id()), _op(op) {}
61+
: Deferred(a.dtype(), a.rank(), a.balanced()), _a(a.id()), _b(b.id()),
62+
_op(op) {}
63+
64+
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
65+
jit::DepManager &dm) override {
66+
// FIXME the type of the result is based on a only
67+
auto av = dm.getDependent(builder, _a);
68+
auto bv = dm.getDependent(builder, _b);
9569

96-
void run() {
97-
// const auto a = std::move(Registry::get(_a).get());
98-
// const auto b = std::move(Registry::get(_b).get());
99-
// set_value(std::move(TypeDispatch<x::IEWBinOp>(a, b, _op)));
70+
auto aTyp = ::imex::dist::getPTensorType(av);
71+
::mlir::SmallVector<int64_t> shape(rank(), ::mlir::ShapedType::kDynamic);
72+
auto outTyp =
73+
::imex::ptensor::PTensorType::get(shape, aTyp.getElementType());
74+
75+
auto binop = builder.create<::imex::ptensor::EWBinOp>(
76+
loc, outTyp, builder.getI32IntegerAttr(ddpt2mlir(_op)), av, bv);
77+
// insertsliceop has no return value, so we just create the op...
78+
auto zero = ::imex::createIndex(loc, builder, 0);
79+
auto one = ::imex::createIndex(loc, builder, 1);
80+
auto dyn = ::imex::createIndex(loc, builder, ::mlir::ShapedType::kDynamic);
81+
::mlir::SmallVector<::mlir::Value> offs(rank(), zero);
82+
::mlir::SmallVector<::mlir::Value> szs(rank(), dyn);
83+
::mlir::SmallVector<::mlir::Value> strds(rank(), one);
84+
(void)builder.create<::imex::ptensor::InsertSliceOp>(loc, av, binop, offs,
85+
szs, strds);
86+
// ... and use av as to later create the ptensor
87+
dm.addVal(this->guid(), av,
88+
[this](Transceiver *transceiver, uint64_t rank, void *allocated,
89+
void *aligned, intptr_t offset, const intptr_t *sizes,
90+
const intptr_t *strides, uint64_t *gs_allocated,
91+
uint64_t *gs_aligned, uint64_t *lo_allocated,
92+
uint64_t *lo_aligned, uint64_t balanced) {
93+
this->set_value(Registry::get(this->_a).get());
94+
});
95+
return false;
10096
}
10197

10298
FactoryId factory() const { return F_IEWBINOP; }
99+
103100
template <typename S> void serialize(S &ser) {
104101
ser.template value<sizeof(_a)>(_a);
105102
ser.template value<sizeof(_b)>(_b);

src/SetGetItem.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,6 @@ struct DeferredSetItem : public Deferred {
204204
const std::vector<py::slice> &v)
205205
: _a(a.id()), _b(b.id()), _slc(v) {}
206206

207-
void run() {
208-
// const auto a = std::move(Registry::get(_a).get());
209-
// const auto b = std::move(Registry::get(_b).get());
210-
// set_value(std::move(TypeDispatch<x::SetItem>(a, b, _slc, _b)));
211-
}
212-
213207
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
214208
jit::DepManager &dm) override {
215209
// get params and extract offsets/sizes/strides
@@ -229,9 +223,9 @@ struct DeferredSetItem : public Deferred {
229223
sizesV[i] = ::imex::createIndex(loc, builder, sizes[i]);
230224
stridesV[i] = ::imex::createIndex(loc, builder, strides[i]);
231225
}
232-
// insertsliceop has no return value, so we just craete the op...
233-
builder.create<::imex::ptensor::InsertSliceOp>(loc, av, bv, offsV, sizesV,
234-
stridesV);
226+
// insertsliceop has no return value, so we just create the op...
227+
(void)builder.create<::imex::ptensor::InsertSliceOp>(loc, av, bv, offsV,
228+
sizesV, stridesV);
235229
// ... and use av as to later create the ptensor
236230
dm.addVal(this->guid(), av,
237231
[this](Transceiver *transceiver, uint64_t rank, void *allocated,

src/include/ddptensor/Deferred.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ struct Runable {
2323
using ptr_type = std::unique_ptr<Runable>;
2424
virtual ~Runable(){};
2525
/// actually execute, a deferred will set value of future
26-
virtual void run() = 0;
26+
virtual void run() {
27+
throw(std::runtime_error(
28+
"No immediate execution support for this operation."));
29+
};
2730
/// generate MLIR code for jit
2831
/// the runable might not generate MLIR and instead return true
2932
/// to request the scheduler to execute the run method instead.

test/stencil-2d.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ def main():
144144
for i in range(n):
145145
for j in range(n):
146146
A[i, j] = float(i + j)
147-
print(A.dtype)
148147
B = numpy.zeros((n, n), dtype=numpy.float64)
149148

150149
for k in range(iterations + 1):
@@ -154,9 +153,8 @@ def main():
154153

155154
if pattern == "star":
156155
if r == 2:
157-
B[2 : n - 2, 2 : n - 2] = (
158-
B[2 : n - 2, 2 : n - 2]
159-
+ W[2, 2] * A[2 : n - 2, 2 : n - 2]
156+
B[2 : n - 2, 2 : n - 2] += (
157+
W[2, 2] * A[2 : n - 2, 2 : n - 2]
160158
+ W[2, 0] * A[2 : n - 2, 0 : n - 4]
161159
+ W[2, 1] * A[2 : n - 2, 1 : n - 3]
162160
+ W[2, 3] * A[2 : n - 2, 3 : n - 1]
@@ -168,11 +166,10 @@ def main():
168166
)
169167
else:
170168
b = n - r
171-
B[r:b, r:b] = B[r:b, r:b] + W[r, r] * A[r:b, r:b]
169+
B[r:b, r:b] += W[r, r] * A[r:b, r:b]
172170
for s in range(1, r + 1):
173-
B[r:b, r:b] = (
174-
B[r:b, r:b]
175-
+ W[r, r - s] * A[r:b, r - s : b - s]
171+
B[r:b, r:b] += (
172+
W[r, r - s] * A[r:b, r - s : b - s]
176173
+ W[r, r + s] * A[r:b, r + s : b + s]
177174
+ W[r - s, r] * A[r - s : b - s, r:b]
178175
+ W[r + s, r] * A[r + s : b + s, r:b]
@@ -182,11 +179,7 @@ def main():
182179
b = n - r
183180
for s in range(-r, r + 1):
184181
for t in range(-r, r + 1):
185-
B[r:b, r:b] = (
186-
B[r:b, r:b]
187-
+ W[r + t, r + s] * A[r + t : b + t, r + s : b + s]
188-
)
189-
182+
B[r:b, r:b] += W[r + t, r + s] * A[r + t : b + t, r + s : b + s]
190183
A = A + 1.0
191184

192185
t1 = timer()
@@ -196,7 +189,9 @@ def main():
196189
# * Analyze and output results.
197190
# ******************************************************************************
198191

199-
print(W, B)
192+
print(W)
193+
print("********************************")
194+
print(B)
200195
# norm = numpy.linalg.norm(numpy.reshape(B,n*n),ord=1)
201196
# active_points = (n-2*r)**2
202197
# norm /= active_points

0 commit comments

Comments
 (0)