@@ -79,13 +79,11 @@ namespace x {
7979
8080struct DeferredFromShape : public Deferred {
8181 shape_type _shape;
82- DTypeId _dtype;
8382 CreatorId _op;
8483
8584 DeferredFromShape () = default ;
8685 DeferredFromShape (CreatorId op, const shape_type &shape, DTypeId dtype)
87- : Deferred(dtype, shape.size(), true ), _shape(shape), _dtype(dtype),
88- _op (op) {}
86+ : Deferred(dtype, shape.size(), true ), _shape(shape), _op(op) {}
8987
9088 void run () {
9189 // set_value(std::move(TypeDispatch<x::Creator>(_dtype, _op, _shape)));
@@ -97,7 +95,6 @@ struct DeferredFromShape : public Deferred {
9795
9896 template <typename S> void serialize (S &ser) {
9997 ser.template container <sizeof (shape_type::value_type)>(_shape, 8 );
100- ser.template value <sizeof (_dtype)>(_dtype);
10198 ser.template value <sizeof (_op)>(_op);
10299 }
103100};
@@ -110,12 +107,10 @@ ddptensor *Creator::create_from_shape(CreatorId op, const shape_type &shape,
110107struct DeferredFull : public Deferred {
111108 shape_type _shape;
112109 PyScalar _val;
113- DTypeId _dtype;
114110
115111 DeferredFull () = default ;
116112 DeferredFull (const shape_type &shape, PyScalar val, DTypeId dtype)
117- : Deferred(dtype, shape.size(), true ), _shape(shape), _val(val),
118- _dtype (dtype) {}
113+ : Deferred(dtype, shape.size(), true ), _shape(shape), _val(val) {}
119114
120115 void run () {
121116 // auto op = FULL;
@@ -189,6 +184,8 @@ ddptensor *Creator::full(const shape_type &shape, const py::object &val,
189184 return new ddptensor (defer<DeferredFull>(shape, v, dtype));
190185}
191186
187+ // ***************************************************************************
188+
192189struct DeferredArange : public Deferred {
193190 uint64_t _start, _end, _step, _team;
194191
@@ -205,19 +202,25 @@ struct DeferredArange : public Deferred {
205202
206203 bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
207204 jit::DepManager &dm) override {
208- auto start = ::imex::createInt (loc, builder, _start);
209- auto stop = ::imex::createInt (loc, builder, _end);
210- auto step = ::imex::createInt (loc, builder, _step);
211205 // ::mlir::Value
212206 auto team = /* getTransceiver()->nranks() <= 1
213207 ? ::mlir::Value()
214208 :*/
215209 ::imex::createIndex (loc, builder,
216210 reinterpret_cast <uint64_t >(getTransceiver()));
217211
212+ auto _num = (_end - _start + _step + (_step < 0 ? 1 : -1 )) / _step;
213+
214+ auto start = ::imex::createFloat (loc, builder, _start);
215+ auto stop = ::imex::createFloat (loc, builder, _start + _num * _step);
216+ auto num = ::imex::createIndex (loc, builder, _num);
217+ auto rTyp = ::imex::ptensor::PTensorType::get (
218+ ::llvm::ArrayRef<int64_t >{::mlir::ShapedType::kDynamic },
219+ imex::ptensor::toMLIR (builder, jit::getPTDType (_dtype)));
220+
218221 dm.addVal (this ->guid (),
219- builder.create <::imex::ptensor::ARangeOp>(loc, start, stop, step,
220- nullptr , team),
222+ builder.create <::imex::ptensor::LinSpaceOp>(
223+ loc, rTyp, start, stop, num, false , nullptr , team),
221224 [this ](Transceiver *transceiver, uint64_t rank, void *allocated,
222225 void *aligned, intptr_t offset, const intptr_t *sizes,
223226 const intptr_t *strides, uint64_t *gs_allocated,
@@ -239,7 +242,6 @@ struct DeferredArange : public Deferred {
239242 ser.template value <sizeof (_start)>(_start);
240243 ser.template value <sizeof (_end)>(_end);
241244 ser.template value <sizeof (_step)>(_step);
242- ser.template value <sizeof (_dtype)>(_dtype);
243245 }
244246};
245247
@@ -248,6 +250,76 @@ ddptensor *Creator::arange(uint64_t start, uint64_t end, uint64_t step,
248250 return new ddptensor (defer<DeferredArange>(start, end, step, dtype, team));
249251}
250252
253+ // ***************************************************************************
254+
255+ struct DeferredLinspace : public Deferred {
256+ double _start, _end;
257+ uint64_t _num, _team;
258+ bool _endpoint;
259+
260+ DeferredLinspace () = default ;
261+ DeferredLinspace (double start, double end, uint64_t num, bool endpoint,
262+ DTypeId dtype, uint64_t team = 0 )
263+ : Deferred(dtype, 1 , true ), _start(start), _end(end), _num(num),
264+ _team (team), _endpoint(endpoint) {}
265+
266+ void run () override {
267+ // set_value(std::move(TypeDispatch<x::Creator>(_dtype, _start, _end,
268+ // _num)));
269+ };
270+
271+ bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
272+ jit::DepManager &dm) override {
273+ // ::mlir::Value
274+ auto team = /* getTransceiver()->nranks() <= 1
275+ ? ::mlir::Value()
276+ :*/
277+ ::imex::createIndex (loc, builder,
278+ reinterpret_cast <uint64_t >(getTransceiver()));
279+
280+ auto start = ::imex::createFloat (loc, builder, _start);
281+ auto stop = ::imex::createFloat (loc, builder, _end);
282+ auto num = ::imex::createIndex (loc, builder, _num);
283+ auto rTyp = ::imex::ptensor::PTensorType::get (
284+ ::llvm::ArrayRef<int64_t >{::mlir::ShapedType::kDynamic },
285+ imex::ptensor::toMLIR (builder, jit::getPTDType (_dtype)));
286+
287+ dm.addVal (this ->guid (),
288+ builder.create <::imex::ptensor::LinSpaceOp>(
289+ loc, rTyp, start, stop, num, _endpoint, nullptr , team),
290+ [this ](Transceiver *transceiver, uint64_t rank, void *allocated,
291+ void *aligned, intptr_t offset, const intptr_t *sizes,
292+ const intptr_t *strides, uint64_t *gs_allocated,
293+ uint64_t *gs_aligned, uint64_t *lo_allocated,
294+ uint64_t *lo_aligned, uint64_t balanced) {
295+ assert (rank == 1 );
296+ assert (strides[0 ] == 1 );
297+ this ->set_value (std::move (
298+ mk_tnsr (transceiver, _dtype, rank, allocated, aligned,
299+ offset, sizes, strides, gs_allocated, gs_aligned,
300+ lo_allocated, lo_aligned, balanced)));
301+ });
302+ return false ;
303+ }
304+
305+ FactoryId factory () const { return F_ARANGE; }
306+
307+ template <typename S> void serialize (S &ser) {
308+ ser.template value <sizeof (_start)>(_start);
309+ ser.template value <sizeof (_end)>(_end);
310+ ser.template value <sizeof (_num)>(_num);
311+ ser.template value <sizeof (_endpoint)>(_endpoint);
312+ }
313+ };
314+
315+ ddptensor *Creator::linspace (double start, double end, uint64_t num,
316+ bool endpoint, DTypeId dtype, uint64_t team) {
317+ return new ddptensor (
318+ defer<DeferredLinspace>(start, end, num, endpoint, dtype, team));
319+ }
320+
321+ // ***************************************************************************
322+
251323std::pair<ddptensor *, bool > Creator::mk_future (const py::object &b) {
252324 if (py::isinstance<ddptensor>(b)) {
253325 return {b.cast <ddptensor *>(), false };
@@ -263,3 +335,4 @@ std::pair<ddptensor *, bool> Creator::mk_future(const py::object &b) {
263335FACTORY_INIT (DeferredFromShape, F_FROMSHAPE);
264336FACTORY_INIT (DeferredFull, F_FULL);
265337FACTORY_INIT (DeferredArange, F_ARANGE);
338+ FACTORY_INIT (DeferredLinspace, F_LINSPACE);
0 commit comments