Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit f97f1e8

Browse files
committed
fixed result handling: always send ready signals, store inputs before output
1 parent 1edc643 commit f97f1e8

File tree

3 files changed

+65
-50
lines changed

3 files changed

+65
-50
lines changed

src/Deferred.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ void process_promises() {
123123
}
124124

125125
if (!runables.empty()) {
126+
// get input buffers (before results!)
127+
auto input = std::move(dm.store_inputs());
126128
// create return statement and adjust function type
127129
uint64_t osz = dm.handleResult(builder);
128130
// also request generation of c-wrapper function
@@ -132,9 +134,6 @@ void process_promises() {
132134
// add the function to the module
133135
module.push_back(function);
134136

135-
// get input buffers (before results!)
136-
auto input = std::move(dm.store_inputs());
137-
138137
if (osz > 0 || !input.empty()) {
139138
// compile and run the module
140139
auto output = jit.run(module, fname, input, osz);

src/jit/mlir.cpp

Lines changed: 62 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@
7171
#include "mlir/Transforms/Passes.h"
7272
// #include <mlir/InitAllPasses.h>
7373

74+
#include "mlir/Parser/Parser.h"
75+
7476
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
7577
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
7678
#include "mlir/ExecutionEngine/ExecutionEngine.h"
@@ -242,66 +244,80 @@ uint64_t DepManager::handleResult(::mlir::OpBuilder &builder) {
242244
_irm[v.first] = rank;
243245
// add sizes of dtensor (3 memrefs + team + balanced) to sz
244246
sz += dtensor_sz(rank);
247+
// clear reference to MLIR value
248+
v.second = nullptr;
245249
++idx;
246250
}
247251

248252
// add return statement
249253
auto ret_value = builder.create<::mlir::func::ReturnOp>(
250254
builder.getUnknownLoc(), ret_values);
251255

252-
// clear any reference to MLIR values
253-
_ivm.clear();
256+
// _ivm defines the order of return values -> do not clear
257+
254258
return 2 * sz;
255259
}
256260

257261
void DepManager::deliver(std::vector<intptr_t> &outputV, uint64_t sz) {
258262
auto output = outputV.data();
259263
size_t pos = 0;
260-
for (auto &v : _icm) {
261-
auto rank = _irm[v.first];
262-
// first extract tensor
263-
void *t_allocated = reinterpret_cast<void *>(output[pos]);
264-
void *t_aligned = reinterpret_cast<void *>(output[pos + 1]);
265-
intptr_t t_offset = output[pos + 2];
266-
intptr_t *t_sizes = output + pos + 3;
267-
intptr_t *t_stride = output + pos + 3 + rank;
268-
pos += memref_sz(rank);
269-
// second is the team
270-
auto team = output[pos];
271-
pos += 1;
272-
// third is balanced
273-
auto balanced = output[pos];
274-
pos += 1;
275-
if (rank > 0) {
276-
// third extract global shape
277-
uint64_t *gs_allocated = reinterpret_cast<uint64_t *>(output[pos]);
278-
uint64_t *gs_aligned = reinterpret_cast<uint64_t *>(output[pos + 1]);
279-
intptr_t gs_offset = output[pos + 2];
280-
// no sizes/stride needed
281-
pos += memref_sz(1);
282-
// lastly extract local offsets
283-
uint64_t *lo_allocated = reinterpret_cast<uint64_t *>(output[pos]);
284-
uint64_t *lo_aligned = reinterpret_cast<uint64_t *>(output[pos + 1]);
285-
intptr_t lo_offset = output[pos + 2];
286-
// no sizes/stride needed
287-
pos += memref_sz(1);
288-
// call finalization
289-
v.second(reinterpret_cast<Transceiver *>(team), rank, t_allocated,
290-
t_aligned, t_offset, t_sizes, t_stride, // tensor
291-
gs_allocated,
292-
gs_aligned + gs_offset, // global shape is 1d tensor of uint64_t
293-
lo_allocated,
294-
lo_aligned + lo_offset, // local offset is 1d tensor of uint64_t
295-
balanced);
296-
} else { // 0d tensor
297-
v.second(reinterpret_cast<Transceiver *>(team), rank, t_allocated,
298-
t_aligned, t_offset, t_sizes, t_stride, nullptr, nullptr,
299-
nullptr, nullptr, 1);
264+
265+
// _ivm defines the order of return values
266+
for (auto &r : _ivm) {
267+
auto guid = r.first;
268+
if (auto v = _icm.find(guid); v != _icm.end()) {
269+
assert(v->first == guid);
270+
auto rank = _irm[guid];
271+
// first extract tensor
272+
void *t_allocated = reinterpret_cast<void *>(output[pos]);
273+
void *t_aligned = reinterpret_cast<void *>(output[pos + 1]);
274+
intptr_t t_offset = output[pos + 2];
275+
intptr_t *t_sizes = output + pos + 3;
276+
intptr_t *t_stride = output + pos + 3 + rank;
277+
pos += memref_sz(rank);
278+
// second is the team
279+
auto team = output[pos];
280+
pos += 1;
281+
// third is balanced
282+
auto balanced = output[pos];
283+
pos += 1;
284+
if (rank > 0) {
285+
// third extract global shape
286+
uint64_t *gs_allocated = reinterpret_cast<uint64_t *>(output[pos]);
287+
uint64_t *gs_aligned = reinterpret_cast<uint64_t *>(output[pos + 1]);
288+
intptr_t gs_offset = output[pos + 2];
289+
// no sizes/stride needed
290+
pos += memref_sz(1);
291+
// lastly extract local offsets
292+
uint64_t *lo_allocated = reinterpret_cast<uint64_t *>(output[pos]);
293+
uint64_t *lo_aligned = reinterpret_cast<uint64_t *>(output[pos + 1]);
294+
intptr_t lo_offset = output[pos + 2];
295+
// no sizes/stride needed
296+
pos += memref_sz(1);
297+
// call finalization
298+
v->second(
299+
reinterpret_cast<Transceiver *>(team), rank, t_allocated, t_aligned,
300+
t_offset, t_sizes, t_stride, // tensor
301+
gs_allocated,
302+
gs_aligned + gs_offset, // global shape is 1d tensor of uint64_t
303+
lo_allocated,
304+
lo_aligned + lo_offset, // local offset is 1d tensor of uint64_t
305+
balanced);
306+
} else { // 0d tensor
307+
v->second(reinterpret_cast<Transceiver *>(team), rank, t_allocated,
308+
t_aligned, t_offset, t_sizes, t_stride, nullptr, nullptr,
309+
nullptr, nullptr, 1);
310+
}
311+
} else {
312+
assert(false);
300313
}
301314
}
302-
for (auto &v : _icr) {
303-
for (auto cb : v.second) {
304-
cb(v.first);
315+
316+
// ready signals will always be sent, at this point they are not linked to a
317+
// return value
318+
for (auto &readyV : _icr) {
319+
for (auto cb : readyV.second) {
320+
cb(readyV.first);
305321
}
306322
}
307323
}
@@ -327,6 +343,7 @@ std::vector<intptr_t> JIT::run(::mlir::ModuleOp &module,
327343
module = cached;
328344
std::cerr << "using cached module" << std::endl;
329345
} else {
346+
std::cerr << "compiling..." << std::endl;
330347
cache.push_back(std::make_pair(cksm, module));
331348
}
332349
}

test/stencil-2d.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,7 @@ def main():
183183
for s in range(-r, r + 1):
184184
for t in range(-r, r + 1):
185185
B[r:b, r:b] += W[r + t, r + s] * A[r + t : b + t, r + s : b + s]
186-
A[0:n, 0:n] = A + 1.0
187-
# A += 1.0
186+
A = A + 1.0
188187

189188
np.sync()
190189
t1 = timer()

0 commit comments

Comments
 (0)