@@ -273,7 +273,8 @@ struct DeferredSetItem : public Deferred {
273273 DeferredSetItem (const tensor_i::future_type &a,
274274 const tensor_i::future_type &b,
275275 const std::vector<py::slice> &v)
276- : _a(a.id()), _b(b.id()), _slc(v) {}
276+ : Deferred(a.id(), a.dtype(), a.rank(), a.balanced()), _a(a.id()),
277+ _b (b.id()), _slc(v) {}
277278
278279 bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
279280 jit::DepManager &dm) override {
@@ -298,19 +299,10 @@ struct DeferredSetItem : public Deferred {
298299 (void )builder.create <::imex::ptensor::InsertSliceOp>(loc, av, bv, offsV,
299300 sizesV, stridesV);
300301 // ... and use av as to later create the ptensor
301- dm.addVal (this ->guid (), av,
302- [this ](Transceiver *transceiver, uint64_t rank, void *allocated,
303- void *aligned, intptr_t offset, const intptr_t *sizes,
304- const intptr_t *strides, uint64_t *gs_allocated,
305- uint64_t *gs_aligned, uint64_t *lo_allocated,
306- uint64_t *lo_aligned, uint64_t balanced) {
307- this ->set_value (Registry::get (this ->_a ).get ());
308- // this->set_value(std::move(mk_tnsr(dtype, rank, allocated,
309- // aligned, offset, sizes, strides,
310- // gs_allocated, gs_aligned,
311- // lo_allocated,
312- // lo_aligned)));
313- });
302+ dm.addReady (this ->guid (), [this ](id_type guid) {
303+ assert (this ->guid () == guid);
304+ this ->set_value (Registry::get (this ->_a ).get ());
305+ });
314306 return false ;
315307 }
316308
@@ -368,6 +360,10 @@ struct DeferredGetItem : public Deferred {
368360 // auto outnd = nd == 0 || _slc.size() == 1 ? 0 : nd;
369361 auto outTyp =
370362 ::imex::ptensor::PTensorType::get (shape, oTyp.getElementType());
363+ // if(auto dtyp = av.getType().dyn_cast<::imex::dist::DistTensorType>()) {
364+ // av = builder.create<::mlir::UnrealizedConversionCastOp>(loc,
365+ // dtyp.getPTensorType(), av).getResult(0);
366+ // }
371367 // now we can create the PTensor op using the above Values
372368 auto res = builder.create <::imex::ptensor::SubviewOp>(
373369 loc, outTyp, av, offsV, sizesV, stridesV);
@@ -413,10 +409,10 @@ ddptensor *SetItem::__setitem__(ddptensor &a, const std::vector<py::slice> &v,
413409 const py::object &b) {
414410
415411 auto bb = Creator::mk_future (b);
416- auto res = new ddptensor (defer<DeferredSetItem>(a.get (), bb.first ->get (), v));
412+ a. put (defer<DeferredSetItem>(a.get (), bb.first ->get (), v));
417413 if (bb.second )
418414 delete bb.first ;
419- return res ;
415+ return &a ;
420416}
421417
422418py::object GetItem::get_slice (const ddptensor &a,
0 commit comments