@@ -64,32 +64,33 @@ struct DeferredFull : public Deferred {
6464 ::imex::ptensor::DType dtyp;
6565 ::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);
6666
67- auto team =
68- _team == 0
69- ? ::mlir::Value ()
70- : ::imex::createIndex (loc, builder,
71- reinterpret_cast <uint64_t >(getTransceiver ()));
67+ auto transceiver = getTransceiver ();
68+ auto teamV = team () == 0
69+ ? ::mlir::Value ()
70+ : ::imex::createIndex (loc, builder,
71+ reinterpret_cast <uint64_t >(team ()));
7272
7373 auto rTyp = ::imex::ptensor::PTensorType::get (
7474 shape (), imex::ptensor::toMLIR (builder, dtyp));
7575
7676 dm.addVal (this ->guid (),
7777 builder.create <::imex::ptensor::CreateOp>(loc, rTyp, shp, dtyp,
78- val, nullptr , team ),
79- [this ](Transceiver *transceiver, uint64_t rank, void *l_allocated,
80- void *l_aligned, intptr_t l_offset ,
81- const intptr_t *l_sizes, const intptr_t *l_strides ,
82- void *o_allocated, void * o_aligned, intptr_t o_offset,
78+ val, nullptr , teamV ),
79+ [this ](uint64_t rank, void *l_allocated, void *l_aligned ,
80+ intptr_t l_offset, const intptr_t *l_sizes ,
81+ const intptr_t *l_strides, void *o_allocated ,
82+ void *o_aligned, intptr_t o_offset,
8383 const intptr_t *o_sizes, const intptr_t *o_strides,
8484 void *r_allocated, void *r_aligned, intptr_t r_offset,
8585 const intptr_t *r_sizes, const intptr_t *r_strides,
8686 uint64_t *lo_allocated, uint64_t *lo_aligned) {
8787 assert (rank == this ->rank ());
8888 this ->set_value (std::move (mk_tnsr (
89- transceiver, _dtype, this ->shape (), l_allocated, l_aligned,
90- l_offset, l_sizes, l_strides, o_allocated, o_aligned,
91- o_offset, o_sizes, o_strides, r_allocated, r_aligned,
92- r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
89+ reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
90+ this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
91+ l_strides, o_allocated, o_aligned, o_offset, o_sizes,
92+ o_strides, r_allocated, r_aligned, r_offset, r_sizes,
93+ r_strides, lo_allocated, lo_aligned)));
9394 });
9495 return false ;
9596 }
@@ -126,11 +127,11 @@ struct DeferredArange : public Deferred {
126127 bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
127128 jit::DepManager &dm) override {
128129 // ::mlir::Value
129- auto team =
130- _team == 0
131- ? ::mlir::Value ()
132- : ::imex::createIndex (loc, builder,
133- reinterpret_cast <uint64_t >(getTransceiver ()));
130+ auto transceiver = getTransceiver ();
131+ auto teamV = team () == 0
132+ ? ::mlir::Value ()
133+ : ::imex::createIndex (loc, builder,
134+ reinterpret_cast <uint64_t >(team ()));
134135
135136 auto _num = shape ()[0 ];
136137
@@ -142,22 +143,23 @@ struct DeferredArange : public Deferred {
142143
143144 dm.addVal (this ->guid (),
144145 builder.create <::imex::ptensor::LinSpaceOp>(
145- loc, rTyp, start, stop, num, false , nullptr , team ),
146- [this ](Transceiver *transceiver, uint64_t rank, void *l_allocated,
147- void *l_aligned, intptr_t l_offset ,
148- const intptr_t *l_sizes, const intptr_t *l_strides ,
149- void *o_allocated, void * o_aligned, intptr_t o_offset,
146+ loc, rTyp, start, stop, num, false , nullptr , teamV ),
147+ [this ](uint64_t rank, void *l_allocated, void *l_aligned ,
148+ intptr_t l_offset, const intptr_t *l_sizes ,
149+ const intptr_t *l_strides, void *o_allocated ,
150+ void *o_aligned, intptr_t o_offset,
150151 const intptr_t *o_sizes, const intptr_t *o_strides,
151152 void *r_allocated, void *r_aligned, intptr_t r_offset,
152153 const intptr_t *r_sizes, const intptr_t *r_strides,
153154 uint64_t *lo_allocated, uint64_t *lo_aligned) {
154155 assert (rank == 1 );
155156 assert (l_strides[0 ] == 1 );
156157 this ->set_value (std::move (mk_tnsr (
157- transceiver, _dtype, this ->shape (), l_allocated, l_aligned,
158- l_offset, l_sizes, l_strides, o_allocated, o_aligned,
159- o_offset, o_sizes, o_strides, r_allocated, r_aligned,
160- r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
158+ reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
159+ this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
160+ l_strides, o_allocated, o_aligned, o_offset, o_sizes,
161+ o_strides, r_allocated, r_aligned, r_offset, r_sizes,
162+ r_strides, lo_allocated, lo_aligned)));
161163 });
162164 return false ;
163165 }
@@ -193,11 +195,10 @@ struct DeferredLinspace : public Deferred {
193195 bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
194196 jit::DepManager &dm) override {
195197 // ::mlir::Value
196- auto team =
197- _team == 0
198- ? ::mlir::Value ()
199- : ::imex::createIndex (loc, builder,
200- reinterpret_cast <uint64_t >(getTransceiver ()));
198+ auto teamV = team () == 0
199+ ? ::mlir::Value ()
200+ : ::imex::createIndex (loc, builder,
201+ reinterpret_cast <uint64_t >(team ()));
201202
202203 auto start = ::imex::createFloat (loc, builder, _start);
203204 auto stop = ::imex::createFloat (loc, builder, _end);
@@ -207,22 +208,23 @@ struct DeferredLinspace : public Deferred {
207208
208209 dm.addVal (this ->guid (),
209210 builder.create <::imex::ptensor::LinSpaceOp>(
210- loc, rTyp, start, stop, num, _endpoint, nullptr , team ),
211- [this ](Transceiver *transceiver, uint64_t rank, void *l_allocated,
212- void *l_aligned, intptr_t l_offset ,
213- const intptr_t *l_sizes, const intptr_t *l_strides ,
214- void *o_allocated, void * o_aligned, intptr_t o_offset,
211+ loc, rTyp, start, stop, num, _endpoint, nullptr , teamV ),
212+ [this ](uint64_t rank, void *l_allocated, void *l_aligned ,
213+ intptr_t l_offset, const intptr_t *l_sizes ,
214+ const intptr_t *l_strides, void *o_allocated ,
215+ void *o_aligned, intptr_t o_offset,
215216 const intptr_t *o_sizes, const intptr_t *o_strides,
216217 void *r_allocated, void *r_aligned, intptr_t r_offset,
217218 const intptr_t *r_sizes, const intptr_t *r_strides,
218219 uint64_t *lo_allocated, uint64_t *lo_aligned) {
219220 assert (rank == 1 );
220221 assert (l_strides[0 ] == 1 );
221222 this ->set_value (std::move (mk_tnsr (
222- transceiver, _dtype, this ->shape (), l_allocated, l_aligned,
223- l_offset, l_sizes, l_strides, o_allocated, o_aligned,
224- o_offset, o_sizes, o_strides, r_allocated, r_aligned,
225- r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
223+ reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
224+ this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
225+ l_strides, o_allocated, o_aligned, o_offset, o_sizes,
226+ o_strides, r_allocated, r_aligned, r_offset, r_sizes,
227+ r_strides, lo_allocated, lo_aligned)));
226228 });
227229 return false ;
228230 }
0 commit comments