Skip to content

Commit f27ad3f

Browse files
committed
添加活跃性和内存生命周期分析的实现
1 parent b9e5a7f commit f27ad3f

8 files changed

Lines changed: 468 additions & 8 deletions

File tree

mlir/optimization/scheduler/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ add_executable(
1313
lab-opt.cpp
1414
lib/OpStatsPass.cpp
1515
lib/BufferAnalysisPass.cpp
16+
lib/LivenessAnalysisPass.cpp
17+
lib/MemrefLifetime.cpp
1618
)
1719

1820
# add_dependencies(lab-scheduler ToyCh6ShapeInferenceInterfaceIncGen

mlir/optimization/scheduler/include/lab/LabPasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,7 @@ std::unique_ptr<Pass> createLabBufferStatsPass();
1212
std::unique_ptr<Pass> createLabFusionFeasibilityPass();
1313
std::unique_ptr<Pass> createLabMatmulTilePass();
1414
std::unique_ptr<Pass> createLabPipelinePlanPass();
15+
std::unique_ptr<Pass> createLabLivenessPass();
16+
std::unique_ptr<Pass> createLabMemrefLifetimePass();
1517

1618
} // namespace mlir

mlir/optimization/scheduler/lab-opt.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
#include "lab/LabPasses.h"
22
#include "mlir/Dialect/Affine/IR/AffineOps.h"
33
#include "mlir/Dialect/Arith/IR/Arith.h"
4+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
45
#include "mlir/Dialect/Func/IR/FuncOps.h"
56
#include "mlir/Dialect/Linalg/IR/Linalg.h"
67
#include "mlir/Dialect/MemRef/IR/MemRef.h"
78
#include "mlir/Dialect/SCF/IR/SCF.h"
89
#include "mlir/Dialect/Tensor/IR/Tensor.h"
910
#include "mlir/InitAllDialects.h"
1011
#include "mlir/InitAllPasses.h"
12+
#include "mlir/Pass/Pass.h"
1113
#include "mlir/Pass/PassManager.h"
1214
#include "mlir/Pass/PassRegistry.h"
1315
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
@@ -17,7 +19,8 @@ int main(int argc, char **argv) {
1719
registry.insert<mlir::func::FuncDialect, mlir::linalg::LinalgDialect,
1820
mlir::arith::ArithDialect, mlir::tensor::TensorDialect,
1921
mlir::memref::MemRefDialect, mlir::scf::SCFDialect,
20-
mlir::affine::AffineDialect>();
22+
mlir::affine::AffineDialect,
23+
mlir::cf::ControlFlowDialect>();
2124

2225
mlir::registerAllPasses();
2326
mlir::PassPipelineRegistration<>("lab-op-stats", "Lab Op Stats Pass",
@@ -28,6 +31,14 @@ int main(int argc, char **argv) {
2831
[](mlir::OpPassManager &pm) {
2932
pm.addPass(mlir::createLabBufferStatsPass());
3033
});
34+
mlir::PassPipelineRegistration<>("lab-liveness", "Lab Liveness Pass",
35+
[](mlir::OpPassManager &pm) {
36+
pm.addPass(mlir::createLabLivenessPass());
37+
});
38+
mlir::PassPipelineRegistration<>("lab-memref-lifetime", "Lab Memref Lifetime Pass",
39+
[](mlir::OpPassManager &pm) {
40+
pm.addPass(mlir::createLabMemrefLifetimePass());
41+
});
3142

3243
return mlir::asMainReturnCode(
3344
mlir::MlirOptMain(argc, argv, "Lab optimizer\n", registry));

mlir/optimization/scheduler/lib/BufferAnalysisPass.cpp

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "mlir/Analysis/AliasAnalysis.h"
2+
#include "mlir/Analysis/Liveness.h"
23
#include "mlir/IR/AsmState.h"
34
#include "mlir/Dialect/Func/IR/FuncOps.h"
45
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -90,6 +91,7 @@ class LabBufferStatsAnalysis {
9091
// Step 2: 找 buffer owner,并计算 def/lastUse/size
9192
func.walk([&](Operation *op) {
9293
Value result;
94+
mlir::Liveness liveness(func);
9395

9496
if (auto allocOp = dyn_cast<memref::AllocOp>(op)) {
9597
result = allocOp.getResult();
@@ -108,14 +110,22 @@ class LabBufferStatsAnalysis {
108110
return; // 教学版:先跳过动态 shape
109111

110112
int def = opIndex[op];
111-
int lastUse = def;
112113

113-
for (OpOperand &use : result.getUses()) {
114-
Operation *user = use.getOwner();
115-
auto it = opIndex.find(user);
116-
if (it != opIndex.end())
117-
lastUse = std::max(lastUse, it->second);
118-
}
114+
// First version: 直接找最后一个使用点
115+
// int lastUse = def;
116+
117+
// for (OpOperand &use : result.getUses()) {
118+
// Operation *user = use.getOwner();
119+
// auto it = opIndex.find(user);
120+
// if (it != opIndex.end())
121+
// lastUse = std::max(lastUse, it->second);
122+
// }
123+
124+
auto indexOp = findLastSemanticUser(func, liveness, result);
125+
if (!indexOp)
126+
return; // 没有语义使用点?先跳过
127+
int lastUse = opIndex[indexOp];
128+
119129

120130
buffers.push_back(BufferRecord{
121131
result,
@@ -168,6 +178,31 @@ class LabBufferStatsAnalysis {
168178
return numElems * elemBytes;
169179
}
170180

181+
static Operation *findLastSemanticUser(func::FuncOp funcOp,
182+
mlir::Liveness &liveness,
183+
Value value) {
184+
Operation *lastUser = nullptr;
185+
186+
funcOp.walk([&](Operation *op) {
187+
bool usesValue = false;
188+
for (Value operand : op->getOperands()) {
189+
if (operand == value) {
190+
usesValue = true;
191+
break;
192+
}
193+
}
194+
if (!usesValue)
195+
return;
196+
197+
// 如果这个 op 使用了 value,并且 op 之后 value 已死,
198+
// 就把它视为“最后语义使用点”的候选。
199+
if (liveness.isDeadAfter(value, op))
200+
lastUser = op;
201+
});
202+
203+
return lastUser;
204+
}
205+
171206
static bool isBufferOwner(Value v) {
172207
Operation *defOp = v.getDefiningOp();
173208
return isa_and_nonnull<memref::AllocOp, memref::AllocaOp>(defOp);
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#include "mlir/Analysis/Liveness.h"
2+
#include "mlir/Dialect/Func/IR/FuncOps.h"
3+
#include "mlir/IR/AsmState.h"
4+
#include "mlir/IR/Block.h"
5+
#include "mlir/IR/Types.h"
6+
#include "mlir/IR/Value.h"
7+
#include "mlir/Pass/Pass.h"
8+
#include "llvm/Support/raw_ostream.h"
9+
10+
using namespace mlir;
11+
12+
namespace {
13+
14+
struct LabLivenessPass
15+
: public PassWrapper<LabLivenessPass, OperationPass<func::FuncOp>> {
16+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LabLivenessPass)
17+
18+
StringRef getArgument() const final { return "lab-liveness"; }
19+
StringRef getDescription() const final {
20+
return "Example pass that prints MLIR liveness information";
21+
}
22+
23+
void runOnOperation() override {
24+
func::FuncOp funcOp = getOperation();
25+
26+
// 1) 构建 liveness 分析
27+
Liveness liveness(funcOp);
28+
29+
AsmState asmState(funcOp);
30+
31+
llvm::outs() << "=== Liveness for function @" << funcOp.getName()
32+
<< " ===\n";
33+
34+
// 2) 遍历每个 block,打印 live-in / live-out
35+
for (Block &block : funcOp.getBlocks()) {
36+
llvm::outs() << "\nBlock ";
37+
if (block.getParentOp() == funcOp)
38+
llvm::outs() << "(top-level)";
39+
llvm::outs() << " {\n";
40+
41+
llvm::outs() << " live-in : ";
42+
printValueSet(liveness.getLiveIn(&block), asmState);
43+
llvm::outs() << "\n";
44+
45+
llvm::outs() << " live-out: ";
46+
printValueSet(liveness.getLiveOut(&block), asmState);
47+
llvm::outs() << "\n";
48+
49+
// 3) 遍历 block 内 op,检查 operand 在该 op 后是否 dead
50+
for (Operation &op : block.getOperations()) {
51+
llvm::outs() << " op: " << op.getName() << "\n";
52+
53+
for (Value operand : op.getOperands()) {
54+
llvm::outs() << " operand=";
55+
56+
operand.printAsOperand(llvm::outs(), asmState);
57+
58+
bool deadAfter = liveness.isDeadAfter(operand, &op);
59+
llvm::outs() << " dead_after_op=" << (deadAfter ? "true" : "false")
60+
<< "\n";
61+
}
62+
}
63+
64+
llvm::outs() << "}\n";
65+
}
66+
67+
// 4) 也可以直接整体打印
68+
llvm::outs() << "\n--- Full liveness dump ---\n";
69+
liveness.print(llvm::outs());
70+
llvm::outs() << "\n";
71+
}
72+
73+
static void printValueSet(const Liveness::ValueSetT &values,
74+
AsmState &asmState) {
75+
llvm::outs() << "{";
76+
bool first = true;
77+
for (Value v : values) {
78+
if (!first)
79+
llvm::outs() << ", ";
80+
v.printAsOperand(llvm::outs(), asmState);
81+
first = false;
82+
}
83+
llvm::outs() << "}";
84+
}
85+
};
86+
87+
} // namespace
88+
89+
namespace mlir {
90+
std::unique_ptr<Pass> createLabLivenessPass() {
91+
return std::make_unique<LabLivenessPass>();
92+
}
93+
} // namespace mlir

0 commit comments

Comments
 (0)