1818#include < mlir/Dialect/Tensor/IR/Tensor.h>
1919#include < mlir/IR/Builders.h>
2020
21- #if 0
22- namespace x {
23-
24- template<typename T>
25- class Creator
26- {
27- public:
28- using ptr_type = typename tensor_i::ptr_type;
29- using typed_ptr_type = typename DPTensorX<T>::typed_ptr_type;
30-
31- static ptr_type op(CreatorId c, const shape_type & shp)
32- {
33- PVSlice pvslice(shp);
34- shape_type shape(std::move(pvslice.tile_shape()));
35- switch(c) {
36- case EMPTY:
37- return operatorx<T>::mk_tx(std::move(pvslice), std::move(xt::empty<T>(std::move(shape))));
38- case ONES:
39- return operatorx<T>::mk_tx(std::move(pvslice), std::move(xt::ones<T>(std::move(shape))));
40- case ZEROS:
41- return operatorx<T>::mk_tx(std::move(pvslice), std::move(xt::zeros<T>(std::move(shape))));
42- default:
43- throw std::runtime_error("Unknown creator");
44- };
45- };
46-
47- static ptr_type op(CreatorId c, const shape_type & shp, PyScalar v)
48- {
49- T val;
50- if constexpr (std::is_integral<T>::value) val = static_cast<T>(v._int);
51- else if constexpr (std::is_floating_point<T>::value) val = static_cast<T>(v._float);
52- if(c == FULL) {
53- if(VPROD(shp) <= 1) {
54- return operatorx<T>::mk_tx(val, REPLICATED);
55- }
56- PVSlice pvslice(shp);
57- shape_type shape(std::move(pvslice.tile_shape()));
58- auto a = xt::empty<T>(std::move(shape));
59- a.fill(val);
60- return operatorx<T>::mk_tx(std::move(pvslice), std::move(a));
61- }
62- throw std::runtime_error("Unknown creator");
63- }
64-
65- static ptr_type op(uint64_t start, uint64_t end, uint64_t step)
66- {
67- PVSlice pvslice({static_cast<uint64_t>(Slice(start, end, step).size())});
68- auto lslc = pvslice.local_slice();
69- const auto & l1dslc = lslc.dim(0);
70-
71- auto a = xt::arange<T>(start + l1dslc._start*step, start + l1dslc._end * step, l1dslc._step);
72- auto r = operatorx<T>::mk_tx(std::move(pvslice), std::move(a));
73-
74- return r;
75- }
76- }; // class creatorx
77- } // namespace x
78- #endif // if 0
79-
80- struct DeferredFromShape : public Deferred {
81- shape_type _shape;
82- CreatorId _op;
83-
84- DeferredFromShape () = default ;
85- DeferredFromShape (CreatorId op, const shape_type &shape, DTypeId dtype)
86- : Deferred(dtype, shape.size(), true ), _shape(shape), _op(op) {}
87-
88- void run () {
89- // set_value(std::move(TypeDispatch<x::Creator>(_dtype, _op, _shape)));
21+ inline uint64_t mkTeam (uint64_t team) {
22+ if (team && getTransceiver ()->nranks () > 1 ) {
23+ return 1 ;
9024 }
91-
92- // FIXME mlir
93-
94- FactoryId factory () const { return F_FROMSHAPE; }
95-
96- template <typename S> void serialize (S &ser) {
97- ser.template container <sizeof (shape_type::value_type)>(_shape, 8 );
98- ser.template value <sizeof (_op)>(_op);
99- }
100- };
101-
102- ddptensor *Creator::create_from_shape (CreatorId op, const shape_type &shape,
103- DTypeId dtype) {
104- return new ddptensor (defer<DeferredFromShape>(op, shape, dtype));
25+ return 0 ;
10526}
10627
10728struct DeferredFull : public Deferred {
10829 shape_type _shape;
10930 PyScalar _val;
11031
11132 DeferredFull () = default ;
112- DeferredFull (const shape_type &shape, PyScalar val, DTypeId dtype)
113- : Deferred(dtype, shape.size(), true ), _shape(shape), _val(val) {}
33+ DeferredFull (const shape_type &shape, PyScalar val, DTypeId dtype,
34+ uint64_t team)
35+ : Deferred(dtype, shape.size(), team, true ), _shape(shape), _val(val) {}
11436
11537 void run () {
11638 // auto op = FULL;
@@ -146,19 +68,19 @@ struct DeferredFull : public Deferred {
14668 ::imex::ptensor::DType dtyp;
14769 ::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);
14870
149- auto team = /* getTransceiver()->nranks() <= 1
150- ? ::mlir::Value()
151- : */
152- ::imex::createIndex (loc, builder,
153- reinterpret_cast <uint64_t >(getTransceiver()));
71+ auto team =
72+ _team == 0
73+ ? :: mlir::Value ()
74+ : ::imex::createIndex (loc, builder,
75+ reinterpret_cast <uint64_t >(getTransceiver ()));
15476
15577 dm.addVal (this ->guid (),
15678 builder.create <::imex::ptensor::CreateOp>(loc, shp, dtyp, val,
15779 nullptr , team),
15880 [this ](Transceiver *transceiver, uint64_t rank, void *allocated,
15981 void *aligned, intptr_t offset, const intptr_t *sizes,
160- const intptr_t *strides, uint64_t *gs_allocated,
161- uint64_t *gs_aligned, uint64_t *lo_allocated,
82+ const intptr_t *strides, int64_t *gs_allocated,
83+ int64_t *gs_aligned, uint64_t *lo_allocated,
16284 uint64_t *lo_aligned, uint64_t balanced) {
16385 assert (rank == this ->_shape .size ());
16486 this ->set_value (std::move (
@@ -179,21 +101,20 @@ struct DeferredFull : public Deferred {
179101};
180102
181103ddptensor *Creator::full (const shape_type &shape, const py::object &val,
182- DTypeId dtype) {
104+ DTypeId dtype, uint64_t team ) {
183105 auto v = mk_scalar (val, dtype);
184- return new ddptensor (defer<DeferredFull>(shape, v, dtype));
106+ return new ddptensor (defer<DeferredFull>(shape, v, dtype, mkTeam (team) ));
185107}
186108
187109// ***************************************************************************
188110
189111struct DeferredArange : public Deferred {
190- uint64_t _start, _end, _step, _team ;
112+ uint64_t _start, _end, _step;
191113
192114 DeferredArange () = default ;
193115 DeferredArange (uint64_t start, uint64_t end, uint64_t step, DTypeId dtype,
194- uint64_t team = 0 )
195- : Deferred(dtype, 1 , true ), _start(start), _end(end), _step(step),
196- _team (team) {}
116+ uint64_t team)
117+ : Deferred(dtype, 1 , team, true ), _start(start), _end(end), _step(step) {}
197118
198119 void run () override {
199120 // set_value(std::move(TypeDispatch<x::Creator>(_dtype, _start, _end,
@@ -203,11 +124,11 @@ struct DeferredArange : public Deferred {
203124 bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
204125 jit::DepManager &dm) override {
205126 // ::mlir::Value
206- auto team = /* getTransceiver()->nranks() <= 1
207- ? ::mlir::Value()
208- : */
209- ::imex::createIndex (loc, builder,
210- reinterpret_cast <uint64_t >(getTransceiver()));
127+ auto team =
128+ _team == 0
129+ ? :: mlir::Value ()
130+ : ::imex::createIndex (loc, builder,
131+ reinterpret_cast <uint64_t >(getTransceiver ()));
211132
212133 auto _num = (_end - _start + _step + (_step < 0 ? 1 : -1 )) / _step;
213134
@@ -223,8 +144,8 @@ struct DeferredArange : public Deferred {
223144 loc, rTyp, start, stop, num, false , nullptr , team),
224145 [this ](Transceiver *transceiver, uint64_t rank, void *allocated,
225146 void *aligned, intptr_t offset, const intptr_t *sizes,
226- const intptr_t *strides, uint64_t *gs_allocated,
227- uint64_t *gs_aligned, uint64_t *lo_allocated,
147+ const intptr_t *strides, int64_t *gs_allocated,
148+ int64_t *gs_aligned, uint64_t *lo_allocated,
228149 uint64_t *lo_aligned, uint64_t balanced) {
229150 assert (rank == 1 );
230151 assert (strides[0 ] == 1 );
@@ -247,21 +168,22 @@ struct DeferredArange : public Deferred {
247168
248169ddptensor *Creator::arange (uint64_t start, uint64_t end, uint64_t step,
249170 DTypeId dtype, uint64_t team) {
250- return new ddptensor (defer<DeferredArange>(start, end, step, dtype, team));
171+ return new ddptensor (
172+ defer<DeferredArange>(start, end, step, dtype, mkTeam (team)));
251173}
252174
253175// ***************************************************************************
254176
255177struct DeferredLinspace : public Deferred {
256178 double _start, _end;
257- uint64_t _num, _team ;
179+ uint64_t _num;
258180 bool _endpoint;
259181
260182 DeferredLinspace () = default ;
261183 DeferredLinspace (double start, double end, uint64_t num, bool endpoint,
262- DTypeId dtype, uint64_t team = 0 )
263- : Deferred(dtype, 1 , true ), _start(start), _end(end), _num(num),
264- _team (team), _endpoint(endpoint) {}
184+ DTypeId dtype, uint64_t team)
185+ : Deferred(dtype, 1 , team, true ), _start(start), _end(end), _num(num),
186+ _endpoint (endpoint) {}
265187
266188 void run () override {
267189 // set_value(std::move(TypeDispatch<x::Creator>(_dtype, _start, _end,
@@ -271,11 +193,11 @@ struct DeferredLinspace : public Deferred {
271193 bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
272194 jit::DepManager &dm) override {
273195 // ::mlir::Value
274- auto team = /* getTransceiver()->nranks() <= 1
275- ? ::mlir::Value()
276- : */
277- ::imex::createIndex (loc, builder,
278- reinterpret_cast <uint64_t >(getTransceiver()));
196+ auto team =
197+ _team == 0
198+ ? :: mlir::Value ()
199+ : ::imex::createIndex (loc, builder,
200+ reinterpret_cast <uint64_t >(getTransceiver ()));
279201
280202 auto start = ::imex::createFloat (loc, builder, _start);
281203 auto stop = ::imex::createFloat (loc, builder, _end);
@@ -289,8 +211,8 @@ struct DeferredLinspace : public Deferred {
289211 loc, rTyp, start, stop, num, _endpoint, nullptr , team),
290212 [this ](Transceiver *transceiver, uint64_t rank, void *allocated,
291213 void *aligned, intptr_t offset, const intptr_t *sizes,
292- const intptr_t *strides, uint64_t *gs_allocated,
293- uint64_t *gs_aligned, uint64_t *lo_allocated,
214+ const intptr_t *strides, int64_t *gs_allocated,
215+ int64_t *gs_aligned, uint64_t *lo_allocated,
294216 uint64_t *lo_aligned, uint64_t balanced) {
295217 assert (rank == 1 );
296218 assert (strides[0 ] == 1 );
@@ -315,24 +237,24 @@ struct DeferredLinspace : public Deferred {
315237ddptensor *Creator::linspace (double start, double end, uint64_t num,
316238 bool endpoint, DTypeId dtype, uint64_t team) {
317239 return new ddptensor (
318- defer<DeferredLinspace>(start, end, num, endpoint, dtype, team));
240+ defer<DeferredLinspace>(start, end, num, endpoint, dtype, mkTeam ( team) ));
319241}
320242
321243// ***************************************************************************
322244
323- std::pair<ddptensor *, bool > Creator::mk_future (const py::object &b) {
245+ std::pair<ddptensor *, bool > Creator::mk_future (const py::object &b,
246+ uint64_t team) {
324247 if (py::isinstance<ddptensor>(b)) {
325248 return {b.cast <ddptensor *>(), false };
326249 } else if (py::isinstance<py::float_>(b)) {
327- return {Creator::full ({}, b, FLOAT64), true };
250+ return {Creator::full ({}, b, FLOAT64, team ), true };
328251 } else if (py::isinstance<py::int_>(b)) {
329- return {Creator::full ({}, b, INT64), true };
252+ return {Creator::full ({}, b, INT64, team ), true };
330253 }
331254 throw std::runtime_error (
332255 " Invalid right operand to elementwise binary operation" );
333256};
334257
335- FACTORY_INIT (DeferredFromShape, F_FROMSHAPE);
336258FACTORY_INIT (DeferredFull, F_FULL);
337259FACTORY_INIT (DeferredArange, F_ARANGE);
338260FACTORY_INIT (DeferredLinspace, F_LINSPACE);
0 commit comments