11#include " ddptensor/Creator.hpp"
22#include " ddptensor/TypeDispatch.hpp"
3- #include " ddptensor/x.hpp"
43#include " ddptensor/Deferred.hpp"
54#include " ddptensor/Factory.hpp"
5+ #include " ddptensor/DDPTensorImpl.hpp"
66
77#include < imex/Dialect/PTensor/IR/PTensorOps.h>
88#include < mlir/IR/Builders.h>
99
10+ #if 0
1011namespace x {
1112
1213 template<typename T>
@@ -63,6 +64,7 @@ namespace x {
6364 }
6465 }; // class creatorx
6566} // namespace x
67+ #endif // if 0
6668
6769struct DeferredFromShape : public Deferred
6870{
@@ -72,14 +74,17 @@ struct DeferredFromShape : public Deferred
7274
7375 DeferredFromShape () = default ;
7476 DeferredFromShape (CreatorId op, const shape_type & shape, DTypeId dtype)
75- : _shape(shape), _dtype(dtype), _op(op)
77+ : Deferred(dtype, shape.size()),
78+ _shape (shape), _dtype(dtype), _op(op)
7679 {}
7780
7881 void run ()
7982 {
80- set_value (std::move (TypeDispatch<x::Creator>(_dtype, _op, _shape)));
83+ // set_value(std::move(TypeDispatch<x::Creator>(_dtype, _op, _shape)));
8184 }
8285
86+ // FIXME mlir
87+
8388 FactoryId factory () const
8489 {
8590 return F_FROMSHAPE;
@@ -107,15 +112,18 @@ struct DeferredFull : public Deferred
107112
108113 DeferredFull () = default ;
109114 DeferredFull (const shape_type & shape, PyScalar val, DTypeId dtype)
110- : _shape(shape), _val(val), _dtype(dtype)
115+ : Deferred(dtype, shape.size()),
116+ _shape (shape), _val(val), _dtype(dtype)
111117 {}
112118
113119 void run ()
114120 {
115- auto op = FULL;
116- set_value (std::move (TypeDispatch<x::Creator>(_dtype, op, _shape, _val)));
121+ // auto op = FULL;
122+ // set_value(std::move(TypeDispatch<x::Creator>(_dtype, op, _shape, _val)));
117123 }
118124
125+ // FIXME mlir
126+
119127 FactoryId factory () const
120128 {
121129 return F_FULL;
@@ -139,11 +147,11 @@ ddptensor * Creator::full(const shape_type & shape, const py::object & val, DTyp
139147struct DeferredArange : public Deferred
140148{
141149 uint64_t _start, _end, _step;
142- DTypeId _dtype;
143150
144151 DeferredArange () = default ;
145152 DeferredArange (uint64_t start, uint64_t end, uint64_t step, DTypeId dtype)
146- : _start(start), _end(end), _step(step), _dtype(dtype)
153+ : Deferred(dtype, 1 ),
154+ _start (start), _end(end), _step(step)
147155 {}
148156
149157 void run () override
@@ -153,21 +161,20 @@ struct DeferredArange : public Deferred
153161
154162 ::mlir::Value generate_mlir (::mlir::OpBuilder & builder, ::mlir::Location loc, jit::IdValueMap & ivm) override
155163 {
156- // FIXME the type of the result is hard-coded to uint64_t
157164 // create start, stop and step
158165 auto start = jit::createI64 (loc, builder, _start);
159166 auto end = jit::createI64 (loc, builder, _end);
160167 auto step = jit::createI64 (loc, builder, _step);
161168 // create arange
162169 auto dtype = builder.getI64Type ();
170+ assert (_dtype == INT64 || _dtype == UINT64); // FIXME
163171 llvm::SmallVector<int64_t > shape (1 , -1 ); // ::mlir::ShapedType::kDynamicSize);
164172 auto artype = ::imex::ptensor::PTensorType::get (builder.getContext (), ::mlir::RankedTensorType::get (shape, dtype), true );
165173 auto ar = builder.create <::imex::ptensor::ARangeOp>(loc, artype, start, end, step, true );
166174 auto setter = [this ](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides) {
167- // FIXME GC assert(allocated == aligned);
168175 assert (rank == 1 );
169176 assert (strides[0 ] == 1 );
170- this ->set_value (std::move (x::operatorx< uint64_t >:: mk_tx ( rank, allocated, aligned, offset, sizes, strides)));
177+ this ->set_value (std::move (mk_tnsr (_dtype, rank, allocated, aligned, offset, sizes, strides)));
171178 };
172179 ivm[_guid] = {ar, setter};
173180 return ar;
0 commit comments