Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 9803955

Browse files
committed
simplifying code for deferring execution
1 parent ee0bb9e commit 9803955

20 files changed

+123
-199
lines changed

ddptensor/ddptensor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@ def __repr__(self):
2323
f"{method} = lambda self, other: dtensor(_cdt.EWBinOp.op(_cdt.{METHOD}, self._t, other._t if isinstance(other, dtensor) else other))"
2424
)
2525

26+
def _inplace(self, t):
27+
self._t = t
28+
return self
29+
2630
for method in api.api_categories["IEWBinOp"]:
2731
METHOD = method.upper()
2832
exec(
29-
f"{method} = lambda self, other: (self, _cdt.IEWBinOp.op(_cdt.{METHOD}, self._t, other._t))[0]" # if isinstance(other, dtensor) else other))[0]"
33+
f"{method} = lambda self, other: self._inplace(_cdt.IEWBinOp.op(_cdt.{METHOD}, self._t, other._t if isinstance(other, dtensor) else other))"
3034
)
3135

3236
for method in api.api_categories["EWUnyOp"]:

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def build_cmake(self, ext):
2929
extdir.parent.mkdir(parents=True, exist_ok=True)
3030

3131
# example of cmake args
32-
config = 'Debug' if self.debug else 'Release' #'RelWithDebInfo'
32+
config = 'Debug'# if self.debug else 'Release' #'RelWithDebInfo'
3333
cmake_args = [
3434
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(extdir.parent.absolute()),
3535
'-DCMAKE_BUILD_TYPE=' + config
@@ -38,7 +38,7 @@ def build_cmake(self, ext):
3838
# example of build args
3939
build_args = [
4040
'--config', config,
41-
'--', '-j8'
41+
#'--', '-j8'
4242
]
4343

4444
os.chdir(str(build_temp))

src/EWBinOp.cpp

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -133,29 +133,15 @@ namespace x {
133133
};
134134
} // namespace x
135135

136-
struct DeferredEWBinOp : public Deferred
137-
{
138-
tensor_i::future_type _a;
139-
tensor_i::future_type _b;
140-
EWBinOpId _op;
141-
142-
DeferredEWBinOp(EWBinOpId op, tensor_i::future_type & a, tensor_i::future_type & b)
143-
: _a(a), _b(b), _op(op)
144-
{}
145-
146-
void run()
147-
{
148-
auto a = std::move(_a.get());
149-
auto b = std::move(_b.get());
150-
set_value(TypeDispatch<x::EWBinOp>(a, b, _op));
151-
}
152-
};
153-
154-
tensor_i::future_type EWBinOp::op(EWBinOpId op, tensor_i::future_type & a, py::object & b)
136+
tensor_i::future_type EWBinOp::op(EWBinOpId op, const tensor_i::future_type & a, const py::object & b)
155137
{
138+
auto bb = x::mk_ftx(b);
156139
if(op == __MATMUL__) {
157-
auto bb = x::mk_ftx(b);
158140
return LinAlgOp::vecdot(a, bb, 0);
159141
}
160-
return defer<DeferredEWBinOp>(op, a, x::mk_ftx(b));
142+
auto aa = std::move(a.get());
143+
auto bbb = std::move(bb.get());
144+
return defer([op, aa, bbb](){
145+
return TypeDispatch<x::EWBinOp>(aa, bbb, op);
146+
});
161147
}

src/EWUnyOp.cpp

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,23 +108,10 @@ namespace x {
108108
};
109109
} //namespace x
110110

111-
struct DeferredEWUnyOp : public Deferred
111+
tensor_i::future_type EWUnyOp::op(EWUnyOpId op, const tensor_i::future_type & a)
112112
{
113-
tensor_i::future_type _a;
114-
EWUnyOpId _op;
115-
116-
DeferredEWUnyOp(EWUnyOpId op, tensor_i::future_type & a)
117-
: _a(a), _op(op)
118-
{}
119-
120-
void run()
121-
{
122-
auto a = std::move(_a.get());
123-
set_value(TypeDispatch<x::EWUnyOp>(a, _op));
124-
}
125-
};
126-
127-
tensor_i::future_type EWUnyOp::op(EWUnyOpId op, tensor_i::future_type & a)
128-
{
129-
return defer<DeferredEWUnyOp>(op, a);
113+
auto aa = std::move(a.get());
114+
return defer([op, aa](){
115+
return TypeDispatch<x::EWUnyOp>(aa, op);
116+
});
130117
}

src/IEWBinOp.cpp

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -73,25 +73,12 @@ namespace x {
7373
};
7474
} // namespace x
7575

76-
struct DeferredIEWBinOp : public Deferred
76+
tensor_i::future_type IEWBinOp::op(IEWBinOpId op, tensor_i::future_type & a, const py::object & b)
7777
{
78-
tensor_i::future_type _a;
79-
tensor_i::future_type _b;
80-
IEWBinOpId _op;
81-
82-
DeferredIEWBinOp(IEWBinOpId op, tensor_i::future_type & a, tensor_i::future_type & b)
83-
: _a(a), _b(b), _op(op)
84-
{}
85-
86-
void run()
87-
{
88-
auto a = std::move(_a.get());
89-
auto b = std::move(_b.get());
90-
set_value(TypeDispatch<x::IEWBinOp>(a, b, _op));
91-
}
92-
};
93-
94-
tensor_i::future_type IEWBinOp::op(IEWBinOpId op, tensor_i::future_type & a, py::object & b)
95-
{
96-
return defer<DeferredIEWBinOp>(op, a, x::mk_ftx(b));
78+
auto bb = x::mk_ftx(b);
79+
auto aa = std::move(a.get());
80+
auto bbb = std::move(bb.get());
81+
return defer([op, aa, bbb](){
82+
return TypeDispatch<x::IEWBinOp>(aa, bbb, op);
83+
});
9784
}

src/LinAlgOp.cpp

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -109,25 +109,11 @@ namespace x {
109109
};
110110
}
111111

112-
struct DeferredLinAlgOp : public Deferred
112+
tensor_i::future_type LinAlgOp::vecdot(const tensor_i::future_type & a, const tensor_i::future_type & b, int axis)
113113
{
114-
tensor_i::future_type _a;
115-
tensor_i::future_type _b;
116-
int _axis;
117-
118-
DeferredLinAlgOp(tensor_i::future_type & a, tensor_i::future_type & b, int axis)
119-
: _a(a), _b(b), _axis(axis)
120-
{}
121-
122-
void run()
123-
{
124-
auto a = std::move(_a.get());
125-
auto b = std::move(_b.get());
126-
set_value(TypeDispatch<x::LinAlgOp>(a, b, _axis));
127-
}
128-
};
129-
130-
tensor_i::future_type LinAlgOp::vecdot(tensor_i::future_type & a, tensor_i::future_type & b, int axis)
131-
{
132-
return defer<DeferredLinAlgOp>(a, b, axis);
114+
auto aa = std::move(a.get());
115+
auto bb = std::move(b.get());
116+
return defer([aa, bb, axis](){
117+
return TypeDispatch<x::LinAlgOp>(aa, bb, axis);
118+
});
133119
}

src/ManipOp.cpp

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,10 @@ namespace x {
2323
};
2424
}
2525

26-
struct DeferredManipOp : public Deferred
26+
tensor_i::future_type ManipOp::reshape(const tensor_i::future_type & a, const shape_type & shape)
2727
{
28-
tensor_i::future_type _a;
29-
shape_type _shape;
30-
31-
DeferredManipOp(tensor_i::future_type & a, const shape_type & shape)
32-
: _a(a), _shape(shape)
33-
{}
34-
35-
void run()
36-
{
37-
auto a = std::move(_a.get());
38-
set_value(TypeDispatch<x::ManipOp>(a, _shape));
39-
}
40-
};
41-
42-
tensor_i::future_type ManipOp::reshape(tensor_i::future_type & a, const shape_type & shape)
43-
{
44-
return defer<DeferredManipOp>(a, shape);
28+
auto aa = std::move(a.get());
29+
return defer([aa, shape](){
30+
return TypeDispatch<x::ManipOp>(aa, shape);
31+
});
4532
}

src/Random.cpp

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,48 +9,53 @@ namespace x {
99
template<typename T>
1010
struct Rand
1111
{
12-
template<typename L, typename U>
13-
static ptr_type op(const shape_type & shp, const L & lower, const U & upper)
12+
//template<typename L, typename U>
13+
static ptr_type op(const shape_type & shp, T lower, T upper)
1414
{
15-
if constexpr (std::is_floating_point<T>::value) {
16-
PVSlice pvslice(shp);
17-
shape_type shape(std::move(pvslice.shape_of_rank()));
18-
return operatorx<T>::mk_tx(std::move(pvslice), std::move(xt::random::rand(std::move(shape), to_native<T>(lower), to_native<T>(upper))));
19-
}
15+
PVSlice pvslice(shp);
16+
shape_type shape(std::move(pvslice.shape_of_rank()));
17+
return operatorx<T>::mk_tx(std::move(pvslice), std::move(xt::random::rand(std::move(shape), lower, upper)));
2018
}
2119
};
2220
}
2321

22+
template<typename T>
2423
struct DeferredRandomOp : public Deferred
2524
{
2625
shape_type _shape;
27-
py::object _lower, _upper;
28-
DTypeId _dtype;
26+
T _lower, _upper;
2927

30-
DeferredRandomOp(DTypeId dtype, const shape_type & shape, const py::object & lower, const py::object & upper)
31-
: _shape(shape), _lower(lower), _upper(upper), _dtype(dtype)
28+
DeferredRandomOp(const shape_type & shape, T lower, T upper)
29+
: _shape(shape), _lower(lower), _upper(upper)
3230
{}
3331

3432
void run()
3533
{
36-
switch(_dtype) {
37-
case FLOAT64:
38-
set_value(x::Rand<double>::op(_shape, _lower, _upper));
39-
return;
40-
case FLOAT32:
41-
set_value(x::Rand<float>::op(_shape, _lower, _upper));
42-
return;
43-
}
44-
throw std::runtime_error("rand: dtype must be a floating point type");
34+
set_value(x::Rand<T>::op(_shape, _lower, _upper));
4535
}
4636
};
4737

4838
Random::future_type Random::rand(DTypeId dtype, const shape_type & shape, const py::object & lower, const py::object & upper)
4939
{
50-
return defer<DeferredRandomOp>(dtype, shape, lower, upper);
40+
switch(dtype) {
41+
case FLOAT64: {
42+
double lo = x::to_native<double>(lower);
43+
double up = x::to_native<double>(upper);
44+
return defer([shape, lo, up](){return x::Rand<double>::op(shape, lo, up);});
45+
//return defer<DeferredRandomOp<double>>(shape, x::to_native<double>(lower), x::to_native<double>(upper));
46+
}
47+
case FLOAT32: {
48+
float lo = x::to_native<float>(lower);
49+
float up = x::to_native<float>(upper);
50+
return defer([shape, lo, up](){return x::Rand<float>::op(shape, lo, up);});
51+
//return defer<DeferredRandomOp<float>>(shape, x::to_native<double>(lower), x::to_native<double>(upper));
52+
}
53+
default:
54+
throw std::runtime_error("rand: dtype must be a floating point type");
55+
}
5156
}
5257

5358
void Random::seed(uint64_t s)
5459
{
55-
xt::random::seed(s);
60+
defer([s](){xt::random::seed(s); return tensor_i::ptr_type();});
5661
}

src/ReduceOp.cpp

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,24 +63,10 @@ namespace x {
6363
};
6464
} // namespace x
6565

66-
struct DeferredReduceOp : public Deferred
66+
tensor_i::future_type ReduceOp::op(ReduceOpId op, const tensor_i::future_type & a, const dim_vec_type & dim)
6767
{
68-
tensor_i::future_type _a;
69-
dim_vec_type _dim;
70-
ReduceOpId _op;
71-
72-
DeferredReduceOp(ReduceOpId op, tensor_i::future_type & a, const dim_vec_type & dim)
73-
: _a(a), _dim(dim), _op(op)
74-
{}
75-
76-
void run()
77-
{
78-
auto a = std::move(_a.get());
79-
set_value(TypeDispatch<x::ReduceOp>(a, _op, _dim));
80-
}
81-
};
82-
83-
tensor_i::future_type ReduceOp::op(ReduceOpId op, tensor_i::future_type & a, const dim_vec_type & dim)
84-
{
85-
return defer<DeferredReduceOp>(op, a, dim);
68+
auto aa = std::move(a.get());
69+
return defer([aa, op, dim](){
70+
return TypeDispatch<x::ReduceOp>(aa, op, dim);
71+
});
8672
}

src/SetGetItem.cpp

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -134,58 +134,33 @@ namespace x {
134134

135135
} // namespace x
136136

137-
struct DeferredSetItem : public Deferred
137+
tensor_i::future_type SetItem::__setitem__(tensor_i::future_type & a, const std::vector<py::slice> & v, const tensor_i::future_type & b)
138138
{
139-
tensor_i::future_type _a;
140-
tensor_i::future_type _b;
141-
NDSlice _slc;
142-
143-
DeferredSetItem(tensor_i::future_type & a, tensor_i::future_type & b, const std::vector<py::slice> & v)
144-
: _a(a), _b(b), _slc(v)
145-
{}
146-
147-
void run()
148-
{
149-
auto a = std::move(_a.get());
150-
auto b = std::move(_b.get());
151-
set_value(TypeDispatch<x::SetItem>(a, b, _slc));
152-
}
153-
};
154-
155-
tensor_i::future_type SetItem::__setitem__(tensor_i::future_type & a, const std::vector<py::slice> & v, tensor_i::future_type & b)
156-
{
157-
return defer<DeferredSetItem>(a, b, v);
139+
auto aa = std::move(a.get());
140+
auto bb = std::move(b.get());
141+
NDSlice _slc(v);
142+
return defer([aa, bb, _slc](){
143+
return TypeDispatch<x::SetItem>(aa, bb, _slc);
144+
});
158145
}
159146

160-
struct DeferredGetItem : public Deferred
161-
{
162-
tensor_i::future_type _a;
163-
NDSlice _slc;
164-
165-
DeferredGetItem(tensor_i::future_type & a, const std::vector<py::slice> & v)
166-
: _a(a), _slc(v)
167-
{}
168-
169-
void run()
170-
{
171-
auto a = std::move(_a.get());
172-
set_value(TypeDispatch<x::GetItem>(a, _slc));
173-
}
174-
};
175-
176-
tensor_i::future_type GetItem::__getitem__(tensor_i::future_type & a, const std::vector<py::slice> & v)
147+
tensor_i::future_type GetItem::__getitem__(const tensor_i::future_type & a, const std::vector<py::slice> & v)
177148
{
178-
return defer<DeferredGetItem>(a, v);
149+
auto aa = std::move(a.get());
150+
NDSlice _slc(v);
151+
return defer([aa, _slc](){
152+
return TypeDispatch<x::GetItem>(aa, _slc);
153+
});
179154
}
180155

181-
py::object GetItem::get_slice(tensor_i::future_type & a, const std::vector<py::slice> & v)
156+
py::object GetItem::get_slice(const tensor_i::future_type & a, const std::vector<py::slice> & v)
182157
{
183-
auto aa = std::move(a.get());
158+
const auto & aa = std::move(a.get());
184159
return TypeDispatch<x::SPMD>(aa, NDSlice(v));
185160
}
186161

187-
py::object GetItem::get_local(tensor_i::future_type & a, py::handle h)
162+
py::object GetItem::get_local(const tensor_i::future_type & a, py::handle h)
188163
{
189-
auto aa = std::move(a.get());
164+
const auto & aa = std::move(a.get());
190165
return TypeDispatch<x::SPMD>(aa, h);
191166
}

0 commit comments

Comments
 (0)