22#include " ddptensor/TypeDispatch.hpp"
33#include " ddptensor/x.hpp"
44#include " ddptensor/Deferred.hpp"
5+ #include " ddptensor/Factory.hpp"
56
67namespace x {
78
@@ -15,7 +16,7 @@ namespace x {
1516 static ptr_type op (CreatorId c, const shape_type & shp)
1617 {
1718 PVSlice pvslice (shp);
18- shape_type shape (std::move (pvslice.shape_of_rank ()));
19+ shape_type shape (std::move (pvslice.tile_shape ()));
1920 switch (c) {
2021 case EMPTY:
2122 return operatorx<T>::mk_tx (std::move (pvslice), std::move (xt::empty<T>(std::move (shape))));
@@ -28,14 +29,19 @@ namespace x {
2829 };
2930 };
3031
31- template <typename V>
32- static ptr_type op (CreatorId c, const shape_type & shp, const V & v)
32+ static ptr_type op (CreatorId c, const shape_type & shp, PyScalar v)
3333 {
34+ T val;
35+ if constexpr (std::is_integral<T>::value) val = static_cast <T>(v._int );
36+ else if constexpr (std::is_floating_point<T>::value) val = static_cast <T>(v._float );
3437 if (c == FULL) {
38+ if (VPROD (shp) <= 1 ) {
39+ return operatorx<T>::mk_tx (val, REPLICATED);
40+ }
3541 PVSlice pvslice (shp);
36- shape_type shape (std::move (pvslice.shape_of_rank ()));
42+ shape_type shape (std::move (pvslice.tile_shape ()));
3743 auto a = xt::empty<T>(std::move (shape));
38- a.fill (to_native<T>(v) );
44+ a.fill (val );
3945 return operatorx<T>::mk_tx (std::move (pvslice), std::move (a));
4046 }
4147 throw std::runtime_error (" Unknown creator" );
@@ -47,24 +53,39 @@ namespace x {
4753 auto lslc = pvslice.slice_of_rank ();
4854 const auto & l1dslc = lslc.dim (0 );
4955 auto a = xt::arange<T>(start + l1dslc._start *step, start + l1dslc._end * step, l1dslc._step );
50- return operatorx<T>::mk_tx (std::move (pvslice), std::move (a));
56+ auto r = operatorx<T>::mk_tx (std::move (pvslice), std::move (a));
57+ return r;
5158 }
5259 }; // class creatorx
5360} // namespace x
5461
5562struct DeferredFromShape : public Deferred
5663{
57- CreatorId _op;
5864 shape_type _shape;
5965 DTypeId _dtype;
66+ CreatorId _op;
6067
68+ DeferredFromShape () = default ;
6169 DeferredFromShape (CreatorId op, const shape_type & shape, DTypeId dtype)
62- : _op(op), _shape(shape), _dtype(dtype)
70+ : _shape(shape), _dtype(dtype), _op(op )
6371 {}
6472
6573 void run ()
6674 {
67- set_value (TypeDispatch<x::Creator>(_dtype, _op, _shape));
75+ set_value (std::move (TypeDispatch<x::Creator>(_dtype, _op, _shape)));
76+ }
77+
78+ FactoryId factory () const
79+ {
80+ return F_FROMSHAPE;
81+ }
82+
83+ template <typename S>
84+ void serialize (S & ser)
85+ {
86+ ser.template container <sizeof (shape_type::value_type)>(_shape, 8 );
87+ ser.template value <sizeof (_dtype)>(_dtype);
88+ ser.template value <sizeof (_op)>(_op);
6889 }
6990};
7091
@@ -76,30 +97,46 @@ tensor_i::future_type Creator::create_from_shape(CreatorId op, const shape_type
7697struct DeferredFull : public Deferred
7798{
7899 shape_type _shape;
79- const py::object & _val;
100+ PyScalar _val;
80101 DTypeId _dtype;
81102
82- DeferredFull (const shape_type & shape, const py::object & val, DTypeId dtype)
103+ DeferredFull () = default ;
104+ DeferredFull (const shape_type & shape, PyScalar val, DTypeId dtype)
83105 : _shape(shape), _val(val), _dtype(dtype)
84106 {}
85107
86108 void run ()
87109 {
88110 auto op = FULL;
89- set_value (TypeDispatch<x::Creator>(_dtype, op, _shape, _val));
111+ set_value (std::move (TypeDispatch<x::Creator>(_dtype, op, _shape, _val)));
112+ }
113+
114+ FactoryId factory () const
115+ {
116+ return F_FULL;
117+ }
118+
119+ template <typename S>
120+ void serialize (S & ser)
121+ {
122+ ser.template container <sizeof (shape_type::value_type)>(_shape, 8 );
123+ ser.template value <sizeof (_val)>(_val._int );
124+ ser.template value <sizeof (_dtype)>(_dtype);
90125 }
91126};
92127
93128tensor_i::future_type Creator::full (const shape_type & shape, const py::object & val, DTypeId dtype)
94129{
95- return defer<DeferredFull>(shape, val, dtype);
130+ auto v = mk_scalar (val, dtype);
131+ return defer<DeferredFull>(shape, v, dtype);
96132}
97133
98134struct DeferredArange : public Deferred
99135{
100136 uint64_t _start, _end, _step;
101137 DTypeId _dtype;
102138
139+ DeferredArange () = default ;
103140 DeferredArange (uint64_t start, uint64_t end, uint64_t step, DTypeId dtype)
104141 : _start(start), _end(end), _step(step), _dtype(dtype)
105142 {}
@@ -108,9 +145,39 @@ struct DeferredArange : public Deferred
108145 {
109146 set_value (std::move (TypeDispatch<x::Creator>(_dtype, _start, _end, _step)));
110147 };
148+
149+ FactoryId factory () const
150+ {
151+ return F_ARANGE;
152+ }
153+
154+ template <typename S>
155+ void serialize (S & ser)
156+ {
157+ ser.template value <sizeof (_start)>(_start);
158+ ser.template value <sizeof (_end)>(_end);
159+ ser.template value <sizeof (_step)>(_step);
160+ ser.template value <sizeof (_dtype)>(_dtype);
161+ }
111162};
112163
113164tensor_i::future_type Creator::arange (uint64_t start, uint64_t end, uint64_t step, DTypeId dtype)
114165{
115166 return defer<DeferredArange>(start, end, step, dtype);
116167}
168+
169+ tensor_i::future_type Creator::mk_future (const py::object & b)
170+ {
171+ if (py::isinstance<tensor_i::future_type>(b)) {
172+ return b.cast <tensor_i::future_type>();
173+ } else if (py::isinstance<py::float_>(b)) {
174+ return Creator::full ({1 }, b, FLOAT64);
175+ } else if (py::isinstance<py::int_>(b)) {
176+ return Creator::full ({1 }, b, INT64);
177+ }
178+ throw std::runtime_error (" Invalid right operand to elementwise binary operation" );
179+ };
180+
181+ FACTORY_INIT (DeferredFromShape, F_FROMSHAPE);
182+ FACTORY_INIT (DeferredFull, F_FULL);
183+ FACTORY_INIT (DeferredArange, F_ARANGE);
0 commit comments