Skip to content

Commit 676f9b5

Browse files
committed
添加异步局部调度实现及相关测试
1 parent bb633c5 commit 676f9b5

6 files changed

Lines changed: 399 additions & 1 deletion

File tree

mlir/optimization/scheduler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_executable(
1717
lib/MemrefLifetime.cpp
1818
lib/FusionFeasibility.cpp
1919
lib/LivenessAdapter.cpp
20+
lib/LocalListScheduling.cpp
2021
)
2122

2223
# add_dependencies(lab-scheduler ToyCh6ShapeInferenceInterfaceIncGen

mlir/optimization/scheduler/include/lab/LabPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ std::unique_ptr<Pass> createLabPipelinePlanPass();
1515
std::unique_ptr<Pass> createLabLivenessPass();
1616
std::unique_ptr<Pass> createLabMemrefLifetimePass();
1717
std::unique_ptr<Pass> createLabFusionFeasibilityPass();
18+
std::unique_ptr<Pass> createAsyncLocalSchedulePass();
1819

1920
} // namespace mlir

mlir/optimization/scheduler/lab-opt.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "lab/LabPasses.h"
22
#include "mlir/Dialect/Affine/IR/AffineOps.h"
33
#include "mlir/Dialect/Arith/IR/Arith.h"
4+
#include "mlir/Dialect/Async/IR/Async.h"
45
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
56
#include "mlir/Dialect/Func/IR/FuncOps.h"
67
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -19,7 +20,8 @@ int main(int argc, char **argv) {
1920
registry.insert<mlir::func::FuncDialect, mlir::linalg::LinalgDialect,
2021
mlir::arith::ArithDialect, mlir::tensor::TensorDialect,
2122
mlir::memref::MemRefDialect, mlir::scf::SCFDialect,
22-
mlir::affine::AffineDialect, mlir::cf::ControlFlowDialect>();
23+
mlir::affine::AffineDialect, mlir::cf::ControlFlowDialect,
24+
mlir::async::AsyncDialect>();
2325

2426
mlir::registerAllPasses();
2527
mlir::PassPipelineRegistration<>("lab-op-stats", "Lab Op Stats Pass",
@@ -45,6 +47,12 @@ int main(int argc, char **argv) {
4547
pm.addPass(mlir::createLabFusionFeasibilityPass());
4648
});
4749

50+
mlir::PassPipelineRegistration<>(
51+
"lab-async-local-schedule", "Lab Async Local Schedule Pass",
52+
[](mlir::OpPassManager &pm) {
53+
pm.addPass(mlir::createAsyncLocalSchedulePass());
54+
});
55+
4856
return mlir::asMainReturnCode(
4957
mlir::MlirOptMain(argc, argv, "Lab optimizer\n", registry));
5058
}
mlir/optimization/scheduler/lib/LocalListScheduling.cpp

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
#include "mlir/Dialect/Arith/IR/Arith.h"
2+
#include "mlir/Dialect/Async/IR/Async.h"
3+
#include "mlir/Dialect/Func/IR/FuncOps.h"
4+
#include "mlir/IR/Block.h"
5+
#include "mlir/IR/BuiltinTypes.h"
6+
#include "mlir/IR/Operation.h"
7+
#include "mlir/IR/Value.h"
8+
#include "mlir/Interfaces/SideEffectInterfaces.h"
9+
#include "mlir/Pass/Pass.h"
10+
#include "llvm/ADT/DenseMap.h"
11+
#include "llvm/ADT/STLExtras.h"
12+
#include "llvm/ADT/SetVector.h"
13+
#include "llvm/ADT/SmallVector.h"
14+
15+
using namespace mlir;
16+
17+
namespace {
18+
19+
/// Coarse category assigned to every operation handled by the scheduler.
/// The category is what the ready-list ranking is based on.
enum class NodeKind {
  /// An `async.execute` operation.
  AsyncExecute,
  /// An `async.await` operation.
  AsyncAwait,
  /// An operation without memory effects.
  PureCompute,
  /// An operation that must not be moved (currently: terminators).
  BarrierLike,
  /// Anything that fits none of the categories above.
  Other
};
26+
27+
struct SchedNode {
28+
Operation *op = nullptr;
29+
NodeKind kind = NodeKind::Other;
30+
SmallVector<int> preds;
31+
SmallVector<int> succs;
32+
int indegree = 0;
33+
int originalOrder = -1;
34+
};
35+
36+
/// Returns true when `ty` is an async token type or an async value type.
static bool isAsyncType(Type ty) {
  return isa<async::TokenType, async::ValueType>(ty);
}
39+
40+
/// Maps `op` to the scheduling category used for ranking.
///
/// The checks are ordered: async.execute, async.await, terminator, and
/// finally memory-effect-free ops. Whatever is left becomes `Other`.
static NodeKind classifyOp(Operation *op) {
  NodeKind kind = NodeKind::Other;

  if (isa<async::ExecuteOp>(op)) {
    kind = NodeKind::AsyncExecute;
  } else if (isa<async::AwaitOp>(op)) {
    kind = NodeKind::AsyncAwait;
  } else if (op->hasTrait<OpTrait::IsTerminator>()) {
    // Terminators are treated as barriers.
    kind = NodeKind::BarrierLike;
  } else if (isMemoryEffectFree(op)) {
    // Ordinary ops without side effects count as pure computation.
    kind = NodeKind::PureCompute;
  }

  return kind;
}
56+
57+
/// Returns true when `op` must end the current scheduling window.
///
/// Terminators are always barriers. async.execute / async.await and
/// memory-effect-free ops may take part in window scheduling; every other
/// operation is conservatively treated as a barrier.
static bool isBarrier(Operation *op) {
  if (op->hasTrait<OpTrait::IsTerminator>())
    return true;

  const bool schedulable =
      isa<async::ExecuteOp, async::AwaitOp>(op) || isMemoryEffectFree(op);
  return !schedulable;
}
72+
73+
/// Builds a lookup table from each operation in `ops` to its position in
/// that array.
static DenseMap<Operation *, int> buildOpIndex(ArrayRef<Operation *> ops) {
  DenseMap<Operation *, int> indexOf;
  int position = 0;
  for (Operation *op : ops)
    indexOf[op] = position++;
  return indexOf;
}
79+
80+
/// Records the dependency "node `u` must precede node `v`" in `nodes`.
/// Self-loops and edges that already exist are ignored.
static void addEdge(SmallVectorImpl<SchedNode> &nodes, int u, int v) {
  const bool selfLoop = (u == v);
  if (selfLoop || llvm::is_contained(nodes[u].succs, v))
    return;

  nodes[u].succs.push_back(v);
  nodes[v].preds.push_back(u);
}
91+
92+
/// Adds a def -> use edge for every SSA value that is produced by one
/// operation of the window and consumed by another.
///
/// Uses are collected from the operation itself and from every operation
/// nested in its regions. This matters for ops such as `async.execute`,
/// whose body may capture values defined earlier in the window without
/// listing them as operands. Without these edges the scheduler could hoist
/// the op above the definition it relies on and break SSA dominance.
///
/// \param ops    operations of the window, in their original order.
/// \param nodes  graph nodes, parallel to `ops`; edges are added in place.
static void buildSSADependencies(ArrayRef<Operation *> ops,
                                 SmallVectorImpl<SchedNode> &nodes) {
  const DenseMap<Operation *, int> opToIdx = buildOpIndex(ops);

  for (int user = 0, e = static_cast<int>(ops.size()); user < e; ++user) {
    // The walk visits `ops[user]` itself as well as all nested operations.
    ops[user]->walk([&](Operation *nested) {
      for (Value operand : nested->getOperands()) {
        // Block arguments have no defining op and impose no ordering here.
        Operation *def = operand.getDefiningOp();
        if (!def)
          continue;

        // Definitions outside the window (or inside a nested region) are not
        // graph nodes; nothing to record for them.
        auto it = opToIdx.find(def);
        if (it == opToIdx.end())
          continue;

        addEdge(nodes, it->second, user);
      }
    });
  }
}
110+
111+
/// Decides whether `a` and `b` must keep their relative order even when no
/// SSA dependency links them.
///
/// An operation is freely orderable when it is an async.execute, an
/// async.await, or has no memory effects; in this first version such ops are
/// constrained by SSA dependencies only. An ordering constraint is needed
/// as soon as either operand falls outside that set.
static bool needsConservativeOrder(Operation *a, Operation *b) {
  auto freelyOrderable = [](Operation *op) {
    return isa<async::ExecuteOp, async::AwaitOp>(op) || isMemoryEffectFree(op);
  };

  return !(freelyOrderable(a) && freelyOrderable(b));
}
129+
130+
/// Adds an earlier -> later edge for every pair of window operations whose
/// relative order has to be preserved according to needsConservativeOrder().
static void buildConservativeOrderEdges(ArrayRef<Operation *> ops,
                                        SmallVectorImpl<SchedNode> &nodes) {
  const int count = static_cast<int>(ops.size());
  for (int earlier = 0; earlier < count; ++earlier) {
    for (int later = earlier + 1; later < count; ++later) {
      if (!needsConservativeOrder(ops[earlier], ops[later]))
        continue;
      addEdge(nodes, earlier, later);
    }
  }
}
139+
140+
/// Returns the scheduling priority of a node category; larger values are
/// picked first. async.execute ranks highest, followed by pure computation,
/// async.await and other ops; barrier-like ops rank lowest.
static int priorityOf(NodeKind kind) {
  if (kind == NodeKind::AsyncExecute)
    return 300;
  if (kind == NodeKind::PureCompute)
    return 200;
  if (kind == NodeKind::AsyncAwait)
    return 100;
  if (kind == NodeKind::Other)
    return 50;
  // NodeKind::BarrierLike and any unexpected value.
  return 0;
}
155+
156+
/// List-schedules the dependency graph of one window.
///
/// Works on a private copy of `inputNodes`. Among all nodes whose
/// predecessors have been emitted, the one with the highest priorityOf()
/// is chosen; ties go to the node with the smaller `originalOrder`.
///
/// \returns the node indices in their new order, or an empty vector when not
///          every node could be scheduled (the graph has a cycle).
static SmallVector<int> scheduleWindow(ArrayRef<SchedNode> inputNodes) {
  SmallVector<SchedNode> work(inputNodes.begin(), inputNodes.end());
  const int total = static_cast<int>(work.size());

  // Initialise the in-degrees and seed the ready list with the graph roots.
  SmallVector<int> frontier;
  for (int idx = 0; idx < total; ++idx) {
    work[idx].indegree = static_cast<int>(work[idx].preds.size());
    if (work[idx].indegree == 0)
      frontier.push_back(idx);
  }

  // True when `candidate` should be scheduled ahead of `incumbent`.
  auto outranks = [&work](int candidate, int incumbent) {
    const int candPrio = priorityOf(work[candidate].kind);
    const int incPrio = priorityOf(work[incumbent].kind);
    if (candPrio != incPrio)
      return candPrio > incPrio;
    return work[candidate].originalOrder < work[incumbent].originalOrder;
  };

  SmallVector<int> schedule;
  schedule.reserve(work.size());

  while (!frontier.empty()) {
    size_t pickPos = 0;
    for (size_t pos = 1; pos < frontier.size(); ++pos) {
      if (outranks(frontier[pos], frontier[pickPos]))
        pickPos = pos;
    }

    const int chosen = frontier[pickPos];
    frontier.erase(frontier.begin() + pickPos);
    schedule.push_back(chosen);

    // Release successors whose last pending predecessor was just emitted.
    for (int succ : work[chosen].succs) {
      if (--work[succ].indegree == 0)
        frontier.push_back(succ);
    }
  }

  // Nodes left over mean there is a cycle: give up on this window.
  if (schedule.size() != work.size())
    schedule.clear();

  return schedule;
}
205+
206+
/// Splits `block` into scheduling windows: maximal runs of consecutive
/// operations that are not barriers. Barrier operations themselves belong
/// to no window, and empty runs are dropped.
static SmallVector<SmallVector<Operation *>> collectWindows(Block &block) {
  SmallVector<SmallVector<Operation *>> windows;
  SmallVector<Operation *> pending;

  // Closes the run collected so far, if there is one.
  auto flush = [&windows, &pending] {
    if (pending.empty())
      return;
    windows.push_back(std::move(pending));
    pending.clear();
  };

  for (Operation &op : block) {
    if (isBarrier(&op))
      flush();
    else
      pending.push_back(&op);
  }
  flush();

  return windows;
}
226+
227+
/// Reschedules the operations of one window in place.
///
/// Builds the dependency graph of `ops`, list-schedules it, and, if the
/// resulting order differs from the current one, moves the operations so
/// that they appear in the new order at the window's original location.
///
/// \param ops  consecutive operations of a single block, in program order.
/// \returns true when at least one operation was moved; false when the
///          window has fewer than two ops, the graph could not be scheduled,
///          or the schedule equals the original order.
static bool reorderWindow(ArrayRef<Operation *> ops) {
  if (ops.size() < 2)
    return false;

  // One node per operation. Fields are assigned one by one rather than with
  // designated initialisers, which are a C++20 feature.
  SmallVector<SchedNode> nodes;
  nodes.reserve(ops.size());
  for (Operation *op : ops) {
    const int position = static_cast<int>(nodes.size());
    SchedNode &node = nodes.emplace_back();
    node.op = op;
    node.kind = classifyOp(op);
    node.originalOrder = position;
  }

  buildSSADependencies(ops, nodes);
  buildConservativeOrderEdges(ops, nodes);

  const SmallVector<int> newOrder = scheduleWindow(nodes);
  if (newOrder.empty())
    return false;

  // Nothing to do when the schedule is the identity permutation.
  bool changed = false;
  for (int pos = 0, e = static_cast<int>(newOrder.size()); pos < e; ++pos) {
    if (newOrder[pos] != pos) {
      changed = true;
      break;
    }
  }
  if (!changed)
    return false;

  // Anchor: the position right after the window. When the window reaches the
  // end of the block the anchor is the block's end iterator (a
  // default-constructed iterator is not a valid insertion point).
  Block *block = ops.back()->getBlock();
  Operation *afterWindow = ops.back()->getNextNode();
  const Block::iterator insertPt =
      afterWindow ? afterWindow->getIterator() : block->end();

  // Moving every op in schedule order to just before the anchor lays the
  // window out in the new order.
  for (int idx : newOrder)
    nodes[idx].op->moveBefore(block, insertPt);

  return true;
}
275+
276+
struct AsyncLocalSchedulePass
277+
: public PassWrapper<AsyncLocalSchedulePass, OperationPass<func::FuncOp>> {
278+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AsyncLocalSchedulePass)
279+
280+
StringRef getArgument() const final { return "lab-async-local-schedule"; }
281+
StringRef getDescription() const final {
282+
return "Locally reorder async.execute/async.await inside a block";
283+
}
284+
285+
void runOnOperation() override;
286+
};
287+
288+
void AsyncLocalSchedulePass::runOnOperation() {
289+
func::FuncOp func = getOperation();
290+
291+
bool changed = false;
292+
for (Block &block : func.getBody()) {
293+
auto windows = collectWindows(block);
294+
for (auto &window : windows) {
295+
changed |= reorderWindow(window);
296+
}
297+
}
298+
299+
(void)changed;
300+
}
301+
302+
} // namespace
303+
304+
namespace mlir {
305+
std::unique_ptr<Pass> createAsyncLocalSchedulePass() {
306+
return std::make_unique<AsyncLocalSchedulePass>();
307+
}
308+
} // namespace mlir
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Two independent async.execute / async.await pairs feeding one arith.addi.
// With the scheduler's ranking (execute > pure compute > await) the expected
// order after the lab-async-local-schedule pipeline is: both async.execute
// ops, then both async.await ops, then the arith.addi.
func.func @test(%lhs : i32, %rhs : i32) -> i32 {
  %tok0, %fut0 = async.execute -> !async.value<i32> {
    async.yield %lhs : i32
  }
  %val0 = async.await %fut0 : !async.value<i32>

  %tok1, %fut1 = async.execute -> !async.value<i32> {
    async.yield %rhs : i32
  }
  %val1 = async.await %fut1 : !async.value<i32>

  %sum = arith.addi %val0, %val1 : i32
  return %sum : i32
}

0 commit comments

Comments
 (0)