@@ -71,6 +71,8 @@
 #include "mlir/Transforms/Passes.h"
 // #include <mlir/InitAllPasses.h>

+#include "mlir/Parser/Parser.h"
+
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
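
Note on the newly added include: mlir/Parser/Parser.h provides MLIR's textual parsing entry points. How this PR uses it is not visible in this hunk; as a hedged illustration only, parsing a module from a string looks roughly like the sketch below (the helper name and the IR string are placeholders, not part of the PR):

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/IR/MLIRContext.h"
    #include "mlir/Parser/Parser.h"

    // Hypothetical helper (not part of this PR): parse textual IR into a ModuleOp.
    mlir::OwningOpRef<mlir::ModuleOp> parseModule(mlir::MLIRContext &context,
                                                  llvm::StringRef ir) {
      mlir::ParserConfig config(&context);
      // Returns a null OwningOpRef if parsing fails.
      return mlir::parseSourceString<mlir::ModuleOp>(ir, config);
    }
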
@@ -242,66 +244,80 @@ uint64_t DepManager::handleResult(::mlir::OpBuilder &builder) {
     _irm[v.first] = rank;
     // add sizes of dtensor (3 memrefs + team + balanced) to sz
     sz += dtensor_sz(rank);
+    // clear reference to MLIR value
+    v.second = nullptr;
     ++idx;
   }

   // add return statement
   auto ret_value = builder.create<::mlir::func::ReturnOp>(
       builder.getUnknownLoc(), ret_values);

-  // clear any reference to MLIR values
-  _ivm.clear();
+  // _ivm defines the order of return values -> do not clear
+
   return 2 * sz;
 }

 void DepManager::deliver(std::vector<intptr_t> &outputV, uint64_t sz) {
   auto output = outputV.data();
   size_t pos = 0;
-  for (auto &v : _icm) {
-    auto rank = _irm[v.first];
-    // first extract tensor
-    void *t_allocated = reinterpret_cast<void *>(output[pos]);
-    void *t_aligned = reinterpret_cast<void *>(output[pos + 1]);
-    intptr_t t_offset = output[pos + 2];
-    intptr_t *t_sizes = output + pos + 3;
-    intptr_t *t_stride = output + pos + 3 + rank;
-    pos += memref_sz(rank);
-    // second is the team
-    auto team = output[pos];
-    pos += 1;
-    // third is balanced
-    auto balanced = output[pos];
-    pos += 1;
-    if (rank > 0) {
-      // third extract global shape
-      uint64_t *gs_allocated = reinterpret_cast<uint64_t *>(output[pos]);
-      uint64_t *gs_aligned = reinterpret_cast<uint64_t *>(output[pos + 1]);
-      intptr_t gs_offset = output[pos + 2];
-      // no sizes/stride needed
-      pos += memref_sz(1);
-      // lastly extract local offsets
-      uint64_t *lo_allocated = reinterpret_cast<uint64_t *>(output[pos]);
-      uint64_t *lo_aligned = reinterpret_cast<uint64_t *>(output[pos + 1]);
-      intptr_t lo_offset = output[pos + 2];
-      // no sizes/stride needed
-      pos += memref_sz(1);
-      // call finalization
-      v.second(reinterpret_cast<Transceiver *>(team), rank, t_allocated,
-               t_aligned, t_offset, t_sizes, t_stride, // tensor
-               gs_allocated,
-               gs_aligned + gs_offset, // global shape is 1d tensor of uint64_t
-               lo_allocated,
-               lo_aligned + lo_offset, // local offset is 1d tensor of uint64_t
-               balanced);
-    } else { // 0d tensor
-      v.second(reinterpret_cast<Transceiver *>(team), rank, t_allocated,
-               t_aligned, t_offset, t_sizes, t_stride, nullptr, nullptr,
-               nullptr, nullptr, 1);
+
+  // _ivm defines the order of return values
+  for (auto &r : _ivm) {
+    auto guid = r.first;
+    if (auto v = _icm.find(guid); v != _icm.end()) {
+      assert(v->first == guid);
+      auto rank = _irm[guid];
+      // first extract the tensor
+      void *t_allocated = reinterpret_cast<void *>(output[pos]);
+      void *t_aligned = reinterpret_cast<void *>(output[pos + 1]);
+      intptr_t t_offset = output[pos + 2];
+      intptr_t *t_sizes = output + pos + 3;
+      intptr_t *t_stride = output + pos + 3 + rank;
+      pos += memref_sz(rank);
+      // second is the team
+      auto team = output[pos];
+      pos += 1;
+      // third is balanced
+      auto balanced = output[pos];
+      pos += 1;
+      if (rank > 0) {
+        // next extract the global shape
+        uint64_t *gs_allocated = reinterpret_cast<uint64_t *>(output[pos]);
+        uint64_t *gs_aligned = reinterpret_cast<uint64_t *>(output[pos + 1]);
+        intptr_t gs_offset = output[pos + 2];
+        // no sizes/stride needed
+        pos += memref_sz(1);
+        // finally extract the local offsets
+        uint64_t *lo_allocated = reinterpret_cast<uint64_t *>(output[pos]);
+        uint64_t *lo_aligned = reinterpret_cast<uint64_t *>(output[pos + 1]);
+        intptr_t lo_offset = output[pos + 2];
+        // no sizes/stride needed
+        pos += memref_sz(1);
+        // call finalization
+        v->second(
+            reinterpret_cast<Transceiver *>(team), rank, t_allocated, t_aligned,
+            t_offset, t_sizes, t_stride, // tensor
+            gs_allocated,
+            gs_aligned + gs_offset, // global shape is 1d tensor of uint64_t
+            lo_allocated,
+            lo_aligned + lo_offset, // local offset is 1d tensor of uint64_t
+            balanced);
+      } else { // 0d tensor
+        v->second(reinterpret_cast<Transceiver *>(team), rank, t_allocated,
+                  t_aligned, t_offset, t_sizes, t_stride, nullptr, nullptr,
+                  nullptr, nullptr, 1);
+      }
+    } else {
+      assert(false); // every guid in _ivm must have a callback in _icm
     }
   }
-  for (auto &v : _icr) {
-    for (auto cb : v.second) {
-      cb(v.first);
+
+  // ready signals are always sent; at this point they are not linked to a
+  // return value
+  for (auto &readyV : _icr) {
+    for (auto cb : readyV.second) {
+      cb(readyV.first);
     }
   }
 }
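
For context on the sizes being walked here: the bodies of memref_sz and dtensor_sz are not part of this diff, but their slot counts can be read off the unpacking logic above. A minimal sketch consistent with that logic (inferred, not copied from the source) is:

    #include <cstdint>

    // Slots in a ranked memref descriptor: allocated ptr, aligned ptr,
    // offset, then rank sizes and rank strides.
    static inline uint64_t memref_sz(int rank) { return 3 + 2 * rank; }

    // Slots for one dtensor as returned by the jitted function:
    // data memref + team + balanced, plus (for rank > 0) two 1d memrefs
    // carrying the global shape and the local offsets.
    static inline uint64_t dtensor_sz(int rank) {
      uint64_t sz = memref_sz(rank) + 2; // data memref + team + balanced
      if (rank > 0)
        sz += 2 * memref_sz(1); // global shape + local offsets
      return sz;
    }
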
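The flat intptr_t buffer unpacked in deliver() follows MLIR's standard lowering of a ranked memref to (allocated ptr, aligned ptr, offset, sizes[rank], strides[rank]). As an illustration only: the struct below is MLIR's StridedMemRefType from mlir/ExecutionEngine/CRunnerUtils.h, while the viewing helper is hypothetical and assumes 64-bit intptr_t so that pointer and index slots have the same width:

    #include <cstdint>
    #include "mlir/ExecutionEngine/CRunnerUtils.h"

    // Hypothetical helper (not in the PR): view the first 3 + 2 * N slots of
    // the packed output buffer as MLIR's strided memref descriptor.
    template <typename T, int N>
    StridedMemRefType<T, N> *asMemRef(intptr_t *slots) {
      // StridedMemRefType<T, N> is laid out as:
      //   T *basePtr; T *data; int64_t offset;
      //   int64_t sizes[N]; int64_t strides[N];
      // which matches the slot layout walked by deliver().
      return reinterpret_cast<StridedMemRefType<T, N> *>(slots);
    }
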
@@ -327,6 +343,7 @@ std::vector<intptr_t> JIT::run(::mlir::ModuleOp &module,
       module = cached;
       std::cerr << "using cached module" << std::endl;
     } else {
+      std::cerr << "compiling..." << std::endl;
       cache.push_back(std::make_pair(cksm, module));
     }
   }
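
This hunk sits inside a checksum-keyed module cache: on a hit the cached module is reused, on a miss the freshly built module is remembered and compiled. The surrounding lookup is outside the hunk; a minimal sketch of the assumed shape (the cache container type and the type of cksm are guesses, only the names cksm and cache come from the diff) is:

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    #include "mlir/IR/BuiltinOps.h"

    // Assumed cache shape: (checksum, module) pairs searched linearly per run.
    static std::vector<std::pair<uint64_t, ::mlir::ModuleOp>> cache;

    // Hypothetical lookup mirroring the hunk: reuse on a hit, remember on a miss.
    ::mlir::ModuleOp lookupOrInsert(uint64_t cksm, ::mlir::ModuleOp module) {
      for (auto &[key, cached] : cache) {
        if (key == cksm) {
          std::cerr << "using cached module" << std::endl;
          return cached;
        }
      }
      std::cerr << "compiling..." << std::endl;
      cache.push_back(std::make_pair(cksm, module));
      return module;
    }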