@@ -81,7 +81,7 @@ struct DeferredFromShape : public Deferred
8181
8282 DeferredFromShape () = default ;
8383 DeferredFromShape (CreatorId op, const shape_type & shape, DTypeId dtype)
84- : Deferred(dtype, shape.size()),
84+ : Deferred(dtype, shape.size(), true ),
8585 _shape (shape), _dtype(dtype), _op(op)
8686 {}
8787
@@ -119,7 +119,7 @@ struct DeferredFull : public Deferred
119119
120120 DeferredFull () = default ;
121121 DeferredFull (const shape_type & shape, PyScalar val, DTypeId dtype)
122- : Deferred(dtype, shape.size()),
122+ : Deferred(dtype, shape.size(), true ),
123123 _shape (shape), _val(val), _dtype(dtype)
124124 {}
125125
@@ -158,11 +158,11 @@ struct DeferredFull : public Deferred
158158
159159 dm.addVal (this ->guid (),
160160 builder.create <::imex::ptensor::CreateOp>(loc, shp, dtyp, val, nullptr , team),
161- [this ](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
162- uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
161+ [this ](Transceiver * transceiver, uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
162+ uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned, uint64_t balanced ) {
163163 assert (rank == this ->_shape .size ());
164- this ->set_value (std::move (mk_tnsr (_dtype, rank, allocated, aligned, offset, sizes, strides,
165- gs_allocated, gs_aligned, lo_allocated, lo_aligned)));
164+ this ->set_value (std::move (mk_tnsr (transceiver, _dtype, rank, allocated, aligned, offset, sizes, strides,
165+ gs_allocated, gs_aligned, lo_allocated, lo_aligned, balanced )));
166166 });
167167 return false ;
168168 }
@@ -193,7 +193,7 @@ struct DeferredArange : public Deferred
193193
194194 DeferredArange () = default ;
195195 DeferredArange (uint64_t start, uint64_t end, uint64_t step, DTypeId dtype, uint64_t team = 0 )
196- : Deferred(dtype, 1 ),
196+ : Deferred(dtype, 1 , true ),
197197 _start (start), _end(end), _step(step), _team(team)
198198 {}
199199
@@ -211,12 +211,12 @@ struct DeferredArange : public Deferred
211211 auto team = ::imex::createIndex (loc, builder, reinterpret_cast <uint64_t >(getTransceiver ()));
212212 dm.addVal (this ->guid (),
213213 builder.create <::imex::ptensor::ARangeOp>(loc, start, stop, step, nullptr , team),
214- [this ](uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
215- uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned) {
214+ [this ](Transceiver * transceiver, uint64_t rank, void *allocated, void *aligned, intptr_t offset, const intptr_t * sizes, const intptr_t * strides,
215+ uint64_t * gs_allocated, uint64_t * gs_aligned, uint64_t * lo_allocated, uint64_t * lo_aligned, uint64_t balanced ) {
216216 assert (rank == 1 );
217217 assert (strides[0 ] == 1 );
218- this ->set_value (std::move (mk_tnsr (_dtype, rank, allocated, aligned, offset, sizes, strides,
219- gs_allocated, gs_aligned, lo_allocated, lo_aligned)));
218+ this ->set_value (std::move (mk_tnsr (transceiver, _dtype, rank, allocated, aligned, offset, sizes, strides,
219+ gs_allocated, gs_aligned, lo_allocated, lo_aligned, balanced )));
220220 });
221221 return false ;
222222 }
0 commit comments