|
1 | 1 | #include "ddptensor/Creator.hpp" |
2 | 2 | #include "ddptensor/TypeDispatch.hpp" |
3 | 3 | #include "ddptensor/x.hpp" |
| 4 | +#include "ddptensor/Deferred.hpp" |
4 | 5 |
|
5 | 6 | namespace x { |
6 | 7 |
|
7 | 8 | template<typename T> |
8 | 9 | class Creator |
9 | 10 | { |
10 | 11 | public: |
11 | | - using ptr_type = DPTensorBaseX::ptr_type; |
| 12 | + using ptr_type = typename tensor_i::ptr_type; |
12 | 13 | using typed_ptr_type = typename DPTensorX<T>::typed_ptr_type; |
13 | 14 |
|
14 | 15 | static ptr_type op(CreatorId c, const shape_type & shp) |
@@ -51,18 +52,65 @@ namespace x { |
51 | 52 | }; // class creatorx |
52 | 53 | } // namespace x |
53 | 54 |
|
54 | | -tensor_i::ptr_type Creator::create_from_shape(CreatorId op, const shape_type & shape, DTypeId dtype) |
| 55 | +struct DeferredFromShape : public Deferred |
55 | 56 | { |
56 | | - return TypeDispatch<x::Creator>(dtype, op, shape); |
| 57 | + CreatorId _op; |
| 58 | + shape_type _shape; |
| 59 | + DTypeId _dtype; |
| 60 | + |
| 61 | + DeferredFromShape(CreatorId op, const shape_type & shape, DTypeId dtype) |
| 62 | + : _op(op), _shape(shape), _dtype(dtype) |
| 63 | + {} |
| 64 | + |
| 65 | + void run() |
| 66 | + { |
| 67 | + set_value(TypeDispatch<x::Creator>(_dtype, _op, _shape)); |
| 68 | + } |
| 69 | +}; |
| 70 | + |
| 71 | +tensor_i::future_type Creator::create_from_shape(CreatorId op, const shape_type & shape, DTypeId dtype) |
| 72 | +{ |
| 73 | + return defer<DeferredFromShape>(op, shape, dtype); |
57 | 74 | } |
58 | 75 |
|
59 | | -tensor_i::ptr_type Creator::full(const shape_type & shape, const py::object & val, DTypeId dtype) |
| 76 | +struct DeferredFull : public Deferred |
60 | 77 | { |
61 | | - auto op = FULL; |
62 | | - return TypeDispatch<x::Creator>(dtype, op, shape, val); |
| 78 | + shape_type _shape; |
| 79 | + const py::object & _val; |
| 80 | + DTypeId _dtype; |
| 81 | + |
| 82 | + DeferredFull(const shape_type & shape, const py::object & val, DTypeId dtype) |
| 83 | + : _shape(shape), _val(val), _dtype(dtype) |
| 84 | + {} |
| 85 | + |
| 86 | + void run() |
| 87 | + { |
| 88 | + auto op = FULL; |
| 89 | + set_value(TypeDispatch<x::Creator>(_dtype, op, _shape, _val)); |
| 90 | + } |
| 91 | +}; |
| 92 | + |
| 93 | +tensor_i::future_type Creator::full(const shape_type & shape, const py::object & val, DTypeId dtype) |
| 94 | +{ |
| 95 | + return defer<DeferredFull>(shape, val, dtype); |
63 | 96 | } |
64 | 97 |
|
65 | | -tensor_i::ptr_type Creator::arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype) |
| 98 | +struct DeferredArange : public Deferred |
| 99 | +{ |
| 100 | + uint64_t _start, _end, _step; |
| 101 | + DTypeId _dtype; |
| 102 | + |
| 103 | + DeferredArange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype) |
| 104 | + : _start(start), _end(end), _step(step), _dtype(dtype) |
| 105 | + {} |
| 106 | + |
| 107 | + void run() |
| 108 | + { |
| 109 | + set_value(TypeDispatch<x::Creator>(_dtype, _start, _end, _step)); |
| 110 | + }; |
| 111 | +}; |
| 112 | + |
| 113 | +tensor_i::future_type Creator::arange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype) |
66 | 114 | { |
67 | | - return TypeDispatch<x::Creator>(dtype, start, end, step); |
| 115 | + return defer<DeferredArange>(start, end, step, dtype); |
68 | 116 | } |
0 commit comments