Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit c8e106b

Browse files
committed
base for deferred execution
1 parent 2a9da9c commit c8e106b

23 files changed

+383
-83
lines changed

src/Creator.cpp

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#include "ddptensor/Creator.hpp"
22
#include "ddptensor/TypeDispatch.hpp"
33
#include "ddptensor/x.hpp"
4+
#include "ddptensor/Deferred.hpp"
45

56
namespace x {
67

78
template<typename T>
89
class Creator
910
{
1011
public:
11-
using ptr_type = DPTensorBaseX::ptr_type;
12+
using ptr_type = typename tensor_i::ptr_type;
1213
using typed_ptr_type = typename DPTensorX<T>::typed_ptr_type;
1314

1415
static ptr_type op(CreatorId c, const shape_type & shp)
@@ -51,18 +52,65 @@ namespace x {
5152
}; // class creatorx
5253
} // namespace x
5354

54-
tensor_i::ptr_type Creator::create_from_shape(CreatorId op, const shape_type & shape, DTypeId dtype)
55+
struct DeferredFromShape : public Deferred
5556
{
56-
return TypeDispatch<x::Creator>(dtype, op, shape);
57+
CreatorId _op;
58+
shape_type _shape;
59+
DTypeId _dtype;
60+
61+
DeferredFromShape(CreatorId op, const shape_type & shape, DTypeId dtype)
62+
: _op(op), _shape(shape), _dtype(dtype)
63+
{}
64+
65+
void run()
66+
{
67+
set_value(TypeDispatch<x::Creator>(_dtype, _op, _shape));
68+
}
69+
};
70+
71+
tensor_i::future_type Creator::create_from_shape(CreatorId op, const shape_type & shape, DTypeId dtype)
72+
{
73+
return defer<DeferredFromShape>(op, shape, dtype);
5774
}
5875

59-
tensor_i::ptr_type Creator::full(const shape_type & shape, const py::object & val, DTypeId dtype)
76+
struct DeferredFull : public Deferred
6077
{
61-
auto op = FULL;
62-
return TypeDispatch<x::Creator>(dtype, op, shape, val);
78+
shape_type _shape;
79+
const py::object & _val;
80+
DTypeId _dtype;
81+
82+
DeferredFull(const shape_type & shape, const py::object & val, DTypeId dtype)
83+
: _shape(shape), _val(val), _dtype(dtype)
84+
{}
85+
86+
void run()
87+
{
88+
auto op = FULL;
89+
set_value(TypeDispatch<x::Creator>(_dtype, op, _shape, _val));
90+
}
91+
};
92+
93+
tensor_i::future_type Creator::full(const shape_type & shape, const py::object & val, DTypeId dtype)
94+
{
95+
return defer<DeferredFull>(shape, val, dtype);
6396
}
6497

65-
tensor_i::ptr_type Creator::arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype)
98+
struct DeferredArange : public Deferred
99+
{
100+
uint64_t _start, _end, _step;
101+
DTypeId _dtype;
102+
103+
DeferredArange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype)
104+
: _start(start), _end(end), _step(step), _dtype(dtype)
105+
{}
106+
107+
void run()
108+
{
109+
set_value(TypeDispatch<x::Creator>(_dtype, _start, _end, _step));
110+
};
111+
};
112+
113+
tensor_i::future_type Creator::arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype)
66114
{
67-
return TypeDispatch<x::Creator>(dtype, start, end, step);
115+
return defer<DeferredArange>(start, end, step, dtype);
68116
}

src/Deferred.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
#include "include/ddptensor/Deferred.hpp"
#include <queue>

static std::queue<Deferred::ptr_type> _deferred;

Deferred::future_type Deferred::defer(Deferred::ptr_type && d)
{
    //auto f = d->get_future();
    _deferred.push(std::move(d));
    // return f;
    auto aa = Deferred::undefer_next();
    aa->run();
    return aa->get_future();
}

Deferred::ptr_type Deferred::undefer_next()
{
    auto r = std::move(_deferred.front());
    _deferred.pop();
    return r;
}

src/EWBinOp.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ namespace x {
5353
case __LT__:
5454
case LESS:
5555
return operatorx<A>::mk_tx_(a_ptr, a < b);
56+
// __MATMUL__ is handled before dispatching, see below
5657
case __MUL__:
5758
case MULTIPLY:
5859
return operatorx<A>::mk_tx_(a_ptr, a * b);
@@ -73,8 +74,6 @@ namespace x {
7374
return operatorx<A>::mk_tx_(a_ptr, b - a);
7475
case __RTRUEDIV__:
7576
return operatorx<A>::mk_tx_(a_ptr, b / a);
76-
case __MATMUL__:
77-
return LinAlgOp::vecdot(a_ptr, b_ptr, 0);
7877
case __POW__:
7978
case POW:
8079
return operatorx<A>::mk_tx_(a_ptr, xt::pow(a, b));
@@ -133,9 +132,30 @@ namespace x {
133132

134133
};
135134
} // namespace x
136-
137-
tensor_i::ptr_type EWBinOp::op(EWBinOpId op, x::DPTensorBaseX::ptr_type a, py::object & b)
135+
136+
struct DeferredEWBinOp : public Deferred
137+
{
138+
tensor_i::future_type _a;
139+
tensor_i::future_type _b;
140+
EWBinOpId _op;
141+
142+
DeferredEWBinOp(EWBinOpId op, tensor_i::future_type & a, tensor_i::future_type & b)
143+
: _a(a), _b(b), _op(op)
144+
{}
145+
146+
void run()
147+
{
148+
auto a = std::move(_a.get());
149+
auto b = std::move(_b.get());
150+
set_value(TypeDispatch<x::EWBinOp>(a, b, _op));
151+
}
152+
};
153+
154+
tensor_i::future_type EWBinOp::op(EWBinOpId op, tensor_i::future_type & a, py::object & b)
138155
{
139-
auto bb = x::mk_tx(b);
140-
return TypeDispatch<x::EWBinOp>(a, bb, op);
156+
if(op == __MATMUL__) {
157+
auto bb = x::mk_ftx(b);
158+
return LinAlgOp::vecdot(a, bb, 0);
159+
}
160+
return defer<DeferredEWBinOp>(op, a, x::mk_ftx(b));
141161
}

src/EWUnyOp.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,23 @@ namespace x {
108108
};
109109
} //namespace x
110110

111-
tensor_i::ptr_type EWUnyOp::op(EWUnyOpId op, x::DPTensorBaseX::ptr_type a)
111+
struct DeferredEWUnyOp : public Deferred
112112
{
113-
return TypeDispatch<x::EWUnyOp>(a, op);
113+
tensor_i::future_type _a;
114+
EWUnyOpId _op;
115+
116+
DeferredEWUnyOp(EWUnyOpId op, tensor_i::future_type & a)
117+
: _a(a), _op(op)
118+
{}
119+
120+
void run()
121+
{
122+
auto a = std::move(_a.get());
123+
set_value(TypeDispatch<x::EWUnyOp>(a, _op));
124+
}
125+
};
126+
127+
tensor_i::future_type EWUnyOp::op(EWUnyOpId op, tensor_i::future_type & a)
128+
{
129+
return defer<DeferredEWUnyOp>(op, a);
114130
}

src/IEWBinOp.cpp

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,61 +10,60 @@ namespace x {
1010
using ptr_type = DPTensorBaseX::ptr_type;
1111

1212
template<typename A, typename B>
13-
static void op(IEWBinOpId iop, std::shared_ptr<DPTensorX<A>> a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr)
13+
static ptr_type op(IEWBinOpId iop, std::shared_ptr<DPTensorX<A>> a_ptr, const std::shared_ptr<DPTensorX<B>> & b_ptr)
1414
{
1515
auto & ax = a_ptr->xarray();
1616
const auto & bx = b_ptr->xarray();
1717
if(a_ptr->is_sliced() || b_ptr->is_sliced()) {
1818
auto av = xt::strided_view(ax, a_ptr->lslice());
1919
const auto & bv = xt::strided_view(bx, b_ptr->lslice());
20-
do_op(iop, av, bv);
21-
} else {
22-
do_op(iop, ax, bx);
20+
return do_op(iop, av, bv, a_ptr);
2321
}
22+
return do_op(iop, ax, bx, a_ptr);
2423
}
2524

2625
#pragma GCC diagnostic ignored "-Wswitch"
27-
template<typename T1, typename T2>
28-
static void do_op(IEWBinOpId iop, T1 & a, const T2 & b)
26+
template<typename A, typename T1, typename T2>
27+
static ptr_type do_op(IEWBinOpId iop, T1 & a, const T2 & b, std::shared_ptr<DPTensorX<A>> a_ptr)
2928
{
3029
switch(iop) {
3130
case __IADD__:
3231
a += b;
33-
return;
32+
return a_ptr;
3433
case __IFLOORDIV__:
3534
a = xt::floor(a / b);
36-
return;
35+
return a_ptr;
3736
case __IMUL__:
3837
a *= b;
39-
return;
38+
return a_ptr;
4039
case __ISUB__:
4140
a -= b;
42-
return;
41+
return a_ptr;
4342
case __ITRUEDIV__:
4443
a /= b;
45-
return;
44+
return a_ptr;
4645
case __IPOW__:
4746
throw std::runtime_error("Binary inplace operation not implemented");
4847
}
4948
if constexpr (std::is_integral<typename T1::value_type>::value && std::is_integral<typename T2::value_type>::value) {
5049
switch(iop) {
5150
case __IMOD__:
5251
a %= b;
53-
return;
52+
return a_ptr;
5453
case __IOR__:
5554
a |= b;
56-
return;
55+
return a_ptr;
5756
case __IAND__:
5857
a &= b;
59-
return;
58+
return a_ptr;
6059
case __IXOR__:
6160
a ^= b;
6261
case __ILSHIFT__:
6362
a = xt::left_shift(a, b);
64-
return;
63+
return a_ptr;
6564
case __IRSHIFT__:
6665
a = xt::right_shift(a, b);
67-
return;
66+
return a_ptr;
6867
}
6968
}
7069
throw std::runtime_error("Unknown/invalid inplace elementwise binary operation");
@@ -74,8 +73,25 @@ namespace x {
7473
};
7574
} // namespace x
7675

77-
void IEWBinOp::op(IEWBinOpId op, x::DPTensorBaseX::ptr_type a, py::object & b)
76+
struct DeferredIEWBinOp : public Deferred
7877
{
79-
auto bb = x::mk_tx(b);
80-
TypeDispatch<x::IEWBinOp>(a, bb, op);
78+
tensor_i::future_type _a;
79+
tensor_i::future_type _b;
80+
IEWBinOpId _op;
81+
82+
DeferredIEWBinOp(IEWBinOpId op, tensor_i::future_type & a, tensor_i::future_type & b)
83+
: _a(a), _b(b), _op(op)
84+
{}
85+
86+
void run()
87+
{
88+
auto a = std::move(_a.get());
89+
auto b = std::move(_b.get());
90+
set_value(TypeDispatch<x::IEWBinOp>(a, b, _op));
91+
}
92+
};
93+
94+
tensor_i::future_type IEWBinOp::op(IEWBinOpId op, tensor_i::future_type & a, py::object & b)
95+
{
96+
return defer<DeferredIEWBinOp>(op, a, x::mk_ftx(b));
8197
}

src/LinAlgOp.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,25 @@ namespace x {
109109
};
110110
}
111111

112-
tensor_i::ptr_type LinAlgOp::vecdot(tensor_i::ptr_type a, tensor_i::ptr_type b, int axis)
112+
struct DeferredLinAlgOp : public Deferred
113113
{
114-
return TypeDispatch<x::LinAlgOp>(a, b, axis);
114+
tensor_i::future_type _a;
115+
tensor_i::future_type _b;
116+
int _axis;
117+
118+
DeferredLinAlgOp(tensor_i::future_type & a, tensor_i::future_type & b, int axis)
119+
: _a(a), _b(b), _axis(axis)
120+
{}
121+
122+
void run()
123+
{
124+
auto a = std::move(_a.get());
125+
auto b = std::move(_b.get());
126+
set_value(TypeDispatch<x::LinAlgOp>(a, b, _axis));
127+
}
128+
};
129+
130+
tensor_i::future_type LinAlgOp::vecdot(tensor_i::future_type & a, tensor_i::future_type & b, int axis)
131+
{
132+
return defer<DeferredLinAlgOp>(a, b, axis);
115133
}

src/ManipOp.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,23 @@ namespace x {
2323
};
2424
}
2525

26-
tensor_i::ptr_type ManipOp::reshape(x::DPTensorBaseX::ptr_type a, const shape_type & shape)
26+
struct DeferredManipOp : public Deferred
2727
{
28-
return TypeDispatch<x::ManipOp>(a, shape);
28+
tensor_i::future_type _a;
29+
shape_type _shape;
30+
31+
DeferredManipOp(tensor_i::future_type & a, const shape_type & shape)
32+
: _a(a), _shape(shape)
33+
{}
34+
35+
void run()
36+
{
37+
auto a = std::move(_a.get());
38+
set_value(TypeDispatch<x::ManipOp>(a, _shape));
39+
}
40+
};
41+
42+
tensor_i::future_type ManipOp::reshape(tensor_i::future_type & a, const shape_type & shape)
43+
{
44+
return defer<DeferredManipOp>(a, shape);
2945
}

src/Random.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,31 @@ namespace x {
2121
};
2222
}
2323

24-
ptr_type Random::rand(DTypeId dtype, const shape_type & shape, const py::object & lower, const py::object & upper)
24+
struct DeferredRandomOp : public Deferred
2525
{
26-
switch(dtype) {
27-
case FLOAT64:
28-
return x::Rand<double>::op(shape, lower, upper);
29-
case FLOAT32:
30-
return x::Rand<double>::op(shape, lower, upper);
26+
shape_type _shape;
27+
py::object _lower, _upper;
28+
DTypeId _dtype;
29+
30+
DeferredRandomOp(DTypeId dtype, const shape_type & shape, const py::object & lower, const py::object & upper)
31+
: _shape(shape), _lower(lower), _upper(upper), _dtype(dtype)
32+
{}
33+
34+
void run()
35+
{
36+
switch(_dtype) {
37+
case FLOAT64:
38+
set_value(x::Rand<double>::op(_shape, _lower, _upper));
39+
case FLOAT32:
40+
set_value(x::Rand<float>::op(_shape, _lower, _upper));
41+
}
42+
throw std::runtime_error("rand: dtype must be a floating point type");
3143
}
32-
throw std::runtime_error("rand: dtype must be a floating point type");
44+
};
45+
46+
Random::future_type Random::rand(DTypeId dtype, const shape_type & shape, const py::object & lower, const py::object & upper)
47+
{
48+
return defer<DeferredRandomOp>(dtype, shape, lower, upper);
3349
}
3450

3551
void Random::seed(uint64_t s)

0 commit comments

Comments (0)