@@ -128,10 +128,10 @@ ::mlir::Value DepManager::getDependent(::mlir::OpBuilder & builder, id_type guid
128128 _func.insertArgument (idx, typ, {}, loc);
129129 auto val = _func.getArgument (idx);
130130 _args.push_back ({guid, fut.rank ()});
131- _ivm[guid] = { val, {}} ;
131+ _ivm[guid] = val;
132132 return val;
133133 } else {
134- return d->second . first ;
134+ return d->second ;
135135 }
136136}
137137
@@ -151,22 +151,23 @@ std::vector<void*> DepManager::store_inputs()
151151 auto f = Registry::get (a.first );
152152 f.get ().get ()->add_to_args (res, a.second );
153153 _ivm.erase (a.first ); // inputs need no delivery
154+ _icm.erase (a.first );
154155 }
155156 return res;
156157}
157158
158159void DepManager::addVal (id_type guid, ::mlir::Value val, SetResFunc cb)
159160{
160161 assert (_ivm.find (guid) == _ivm.end ());
161- _ivm[guid] = {val, cb};
162+ _ivm[guid] = val;
163+ _icm[guid] = cb;
162164}
163165
164166void DepManager::drop (id_type guid)
165167{
166- if (auto e = _ivm.find (guid); e != _ivm.end ()) {
167- _ivm.erase (e);
168- // FIXME create delete op
169- }
168+ _ivm.erase (guid);
169+ _icm.erase (guid);
170+ // FIXME create delete op
170171}
171172
172173// Now we have to define the return type as a ValueRange of all arrays which we have created
@@ -186,7 +187,7 @@ uint64_t DepManager::handleResult(::mlir::OpBuilder & builder)
186187 uint64_t sz = 0 ;
187188 unsigned idx = 0 ;
188189 for (auto & v : _ivm) {
189- ::mlir::Value value = v.second . first ;
190+ ::mlir::Value value = v.second ;
190191 // append the type and array/value
191192 auto retDtTyp = value.getType ().dyn_cast <::imex::dist::DistTensorType>();
192193 if (!retDtTyp) {
@@ -207,44 +208,49 @@ uint64_t DepManager::handleResult(::mlir::OpBuilder & builder)
207208 // add return statement
208209 auto ret_value = builder.create <::mlir::func::ReturnOp>(builder.getUnknownLoc (), ret_values);
209210
211+ // clear any reference to MLIR values
212+ _ivm.clear ();
210213 return sz;
211214}
212215
213216void DepManager::deliver (intptr_t * output, uint64_t sz)
214217{
215218 size_t pos = 0 ;
216- for (auto & v : _ivm) {
217- auto value = v.second .first ;
219+ for (auto & v : _icm) {
218220 auto rank = _irm[v.first ];
219- // first extract global shape
220- uint64_t * gs_allocated = reinterpret_cast <uint64_t *>(output[pos]);
221- uint64_t * gs_aligned = reinterpret_cast <uint64_t *>(output[pos+1 ]);
222- intptr_t gs_offset = output[pos+2 ];
223- // no sizes/stride needed
224- pos += memref_sz (1 );
225- // second extract tensor
221+ // first extract tensor
226222 void * t_allocated = reinterpret_cast <void *>(output[pos]);
227223 void * t_aligned = reinterpret_cast <void *>(output[pos+1 ]);
228224 intptr_t t_offset = output[pos+2 ];
229225 intptr_t * t_sizes = output + pos + 3 ;
230226 intptr_t * t_stride = output + pos + 3 + rank;
231227 pos += memref_sz (rank);
232- // third extract local offsets
233- uint64_t * lo_allocated = reinterpret_cast <uint64_t *>(output[pos]);
234- uint64_t * lo_aligned = reinterpret_cast <uint64_t *>(output[pos+1 ]);
235- intptr_t lo_offset = output[pos+2 ];
236- // no sizes/stride needed
237- pos += memref_sz (1 );
238- // last is the team
228+ // second is the team
239229 // auto team = output[pos];
240230 pos += 1 ;
241- // call finalization
242- v.second .second (
243- rank,
244- t_allocated, t_aligned, t_offset, t_sizes, t_stride, // tensor
245- gs_allocated, gs_aligned + gs_offset, // global shape is 1d tensor of uint64_t
246- lo_allocated, lo_aligned + lo_offset // local offset is 1d tensor of uint64_t
247- );
231+ if (rank > 0 ) {
232+ // third extract global shape
233+ uint64_t * gs_allocated = reinterpret_cast <uint64_t *>(output[pos]);
234+ uint64_t * gs_aligned = reinterpret_cast <uint64_t *>(output[pos+1 ]);
235+ intptr_t gs_offset = output[pos+2 ];
236+ // no sizes/stride needed
237+ pos += memref_sz (1 );
238+ // lastly extract local offsets
239+ uint64_t * lo_allocated = reinterpret_cast <uint64_t *>(output[pos]);
240+ uint64_t * lo_aligned = reinterpret_cast <uint64_t *>(output[pos+1 ]);
241+ intptr_t lo_offset = output[pos+2 ];
242+ // no sizes/stride needed
243+ pos += memref_sz (1 );
244+ // call finalization
245+ v.second (rank,
246+ t_allocated, t_aligned, t_offset, t_sizes, t_stride, // tensor
247+ gs_allocated, gs_aligned + gs_offset, // global shape is 1d tensor of uint64_t
248+ lo_allocated, lo_aligned + lo_offset // local offset is 1d tensor of uint64_t
249+ );
250+ } else { // 0d tensor
251+ v.second (rank, t_allocated, t_aligned, t_offset, t_sizes, t_stride,
252+ nullptr , nullptr , nullptr , nullptr );
253+ }
248254 }
249255}
250256
@@ -296,8 +302,9 @@ int JIT::run(::mlir::ModuleOp & module, const std::string & fname, std::vector<v
// MLIR lowering pipeline: ptensor -> dist -> linalg -> bufferization ->
// parallel loops -> LLVM. Overridable at runtime through the DDPT_PASSES
// environment variable. The lambda performs a single getenv lookup
// (the previous form queried the environment twice).
static const char * pass_pipeline = []() -> const char * {
    const char * env = getenv ("DDPT_PASSES");
    // earlier pipeline variants, kept for reference:
    // "func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-expand,canonicalize,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize),func-bufferize,canonicalize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,fold-memref-alias-ops,lower-affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts"
    // "builtin.module(func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize,bufferization-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,fold-memref-alias-ops,expand-strided-metadata,lower-affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)"
    return env
        ? env
        : "func.func(ptensor-dist),convert-dist-to-standard,convert-ptensor-to-linalg,arith-bufferize,func.func(empty-tensor-to-alloc-tensor,scf-bufferize,linalg-bufferize,tensor-bufferize,bufferization-bufferize),func-bufferize,func.func(finalizing-bufferize,convert-linalg-to-parallel-loops),canonicalize,fold-memref-alias-ops,expand-strided-metadata,lower-affine,convert-scf-to-cf,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts";
}();
301308JIT::JIT ()
302309 : _context(::mlir::MLIRContext::Threading::DISABLED),
303310 _pm (&_context),
0 commit comments