Skip to content

Commit f2d06f6

Browse files
committed
Add a pass that compiles CUDA Tile IR
1 parent 0a3ccf4 commit f2d06f6

5 files changed

Lines changed: 225 additions & 21 deletions

File tree

mlir/cuda-tile/Toy/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_executable(
2727
mlir/ToyCombine.cpp
2828
mlir/LowerToGpu.cpp
2929
mlir/LowerToCudaTile.cpp
30+
mlir/EmitCudaTile.cpp
3031
)
3132

3233
add_dependencies(toy-cuda
@@ -62,5 +63,8 @@ target_link_libraries(toy-cuda
6263
MLIRTargetLLVMIRExport
6364
MLIRTransforms
6465
CudaTileDialect
66+
CudaTileTransforms
67+
CudaTileBytecodeWriter
68+
CudaTileBytecodeCommon
6569
cuda_shim
6670
)

mlir/cuda-tile/Toy/include/toy/Passes.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,14 @@ std::unique_ptr<mlir::Pass> createLowerToAffinePass();
3030
/// well as `Affine` and `Std`, to the LLVM dialect for codegen.
3131
std::unique_ptr<mlir::Pass> createLowerToLLVMPass();
3232

33-
std::unique_ptr<mlir::Pass> createGpuOutlinePass(std::string grid="1,1,1");
33+
std::unique_ptr<mlir::Pass> createGpuOutlinePass(std::string grid = "1,1,1");
3434

3535
std::unique_ptr<mlir::Pass> createCudaTileLoweringPass();
3636

37+
std::unique_ptr<mlir::Pass>
38+
createEmbedCudaTileBinaryPass(std::string tileirasExe = "tileiras",
39+
std::string gpuName = "sm_120");
40+
3741
} // namespace toy
3842
} // namespace mlir
3943

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
#include "mlir/IR/Builders.h"
2+
#include "mlir/IR/BuiltinOps.h"
3+
#include "mlir/Pass/Pass.h"
4+
5+
#include "cuda_tile/Bytecode/Writer/BytecodeWriter.h"
6+
#include "cuda_tile/Dialect/CudaTile/IR/Ops.h"
7+
#include "toy/Dialect.h"
8+
#include "llvm/ADT/SmallVector.h"
9+
#include "llvm/ADT/StringRef.h"
10+
#include "llvm/Support/FileSystem.h"
11+
#include "llvm/Support/MemoryBuffer.h"
12+
#include "llvm/Support/Program.h"
13+
#include "llvm/Support/raw_ostream.h"
14+
#include <system_error>
15+
16+
using namespace llvm;
17+
using namespace mlir;
18+
19+
namespace {
20+
21+
/// Read file contents as raw bytes.
22+
static FailureOr<std::vector<int8_t>> readFileBytes(StringRef path) {
23+
auto bufOrErr = MemoryBuffer::getFile(path, /*IsText=*/false);
24+
if (!bufOrErr)
25+
return failure();
26+
auto &buf = *bufOrErr.get();
27+
std::vector<int8_t> out(buf.getBufferSize());
28+
memcpy(out.data(), buf.getBufferStart(), buf.getBufferSize());
29+
return out;
30+
}
31+
32+
/// Write raw bytes to a file.
33+
static LogicalResult writeFileBytes(StringRef path, ArrayRef<char> bytes) {
34+
std::error_code ec;
35+
raw_fd_ostream os(path, ec, sys::fs::OF_None);
36+
if (ec)
37+
return failure();
38+
os.write(bytes.data(), bytes.size());
39+
os.flush();
40+
return success();
41+
}
42+
43+
/// Execute external tileiras to assemble tilebc into a binary.
44+
static LogicalResult runTileIRAS(Operation *anchor, StringRef tileirasExe,
45+
StringRef gpuName, StringRef inTilebc,
46+
StringRef outBin) {
47+
SmallVector<StringRef, 16> args;
48+
args.push_back(tileirasExe);
49+
args.push_back("--gpu-name");
50+
args.push_back(gpuName);
51+
args.push_back(inTilebc);
52+
args.push_back("-o");
53+
args.push_back(outBin);
54+
55+
std::string errMsg;
56+
int rc = sys::ExecuteAndWait(tileirasExe, args,
57+
/*env=*/std::nullopt,
58+
/*redirects=*/{},
59+
/*secondsToWait=*/0,
60+
/*memoryLimit=*/0, &errMsg);
61+
if (rc != 0) {
62+
return anchor->emitError() << "tileiras failed, rc=" << rc << "\n"
63+
<< errMsg;
64+
}
65+
return success();
66+
}
67+
68+
/// Create a uniquely-named temporary file and immediately close it, leaving
/// only its path in `inPath` for a later writer/reader to reopen.
std::error_code createTemporaryFile(SmallVectorImpl<char> &inPath,
                                    StringRef prefix, StringRef suffix) {
  int fd = -1;
  if (std::error_code ec =
          sys::fs::createTemporaryFile(prefix, suffix, fd, inPath))
    return ec;
  // The caller only needs the path, not the descriptor; release it now so
  // the file can be reopened by name later.
  return sys::fs::closeFile(fd);
}
81+
82+
struct EmbedCudaTileBinaryPass
83+
: public PassWrapper<EmbedCudaTileBinaryPass, OperationPass<ModuleOp>> {
84+
85+
std::string tileirasExe;
86+
std::string gpuName;
87+
88+
EmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName)
89+
: tileirasExe(std::move(tileirasExe)), gpuName(std::move(gpuName)) {}
90+
91+
void runOnOperation() override {
92+
ModuleOp top = getOperation();
93+
MLIRContext *ctx = top.getContext();
94+
95+
SmallString<256> cudaBinPath;
96+
97+
top.walk([&](Operation *op) {
98+
// we assume the MLIR only have one cuda tile module.
99+
if (op->getName().getStringRef() != "cuda_tile.module")
100+
return;
101+
102+
auto cudaMod = dyn_cast<cuda_tile::ModuleOp>(op);
103+
if (!cudaMod)
104+
return;
105+
106+
// ---- Step B: generate tilebc bytes in-process ----
107+
SmallVector<char, 0> tilebcBytes;
108+
raw_svector_ostream tilebcOS(tilebcBytes);
109+
110+
// Using writeBytecode API: writeBytecode(output, moduleOp,
111+
// BytecodeVersion::kCurrentVersion)
112+
if (failed(writeBytecode(tilebcOS, cudaMod,
113+
cuda_tile::BytecodeVersion::kCurrentVersion))) {
114+
op->emitError() << "writeBytecode(tilebc) failed";
115+
signalPassFailure();
116+
return;
117+
}
118+
119+
// ---- Step C: create temp files and invoke tileiras ----
120+
SmallString<256> inPath;
121+
122+
if (std::error_code ec =
123+
createTemporaryFile(inPath, "cuda_tile", "tilebc")) {
124+
op->emitError() << "failed to create temp in tilebc: " << ec.message();
125+
signalPassFailure();
126+
return;
127+
}
128+
129+
if (std::error_code ec =
130+
createTemporaryFile(cudaBinPath, "cuda_tile", "bin")) {
131+
op->emitError() << "failed to create temp out bin: " << ec.message();
132+
signalPassFailure();
133+
return;
134+
}
135+
136+
if (failed(writeFileBytes(inPath, tilebcBytes))) {
137+
op->emitError() << "failed to write temp tilebc";
138+
signalPassFailure();
139+
return;
140+
}
141+
142+
if (failed(runTileIRAS(op, tileirasExe, gpuName, inPath, cudaBinPath))) {
143+
signalPassFailure();
144+
return;
145+
}
146+
});
147+
148+
top->walk([&](toy::LaunchGpuOp launchOp) {
149+
// ---- Step D: read cuda binary bytes ----
150+
auto binBytesOrErr = readFileBytes(cudaBinPath);
151+
if (failed(binBytesOrErr)) {
152+
launchOp.emitError() << "failed to read cuda binary file";
153+
signalPassFailure();
154+
return;
155+
}
156+
auto binBytes = *binBytesOrErr;
157+
158+
// ---- Step E: embed binary as LaunchGpuOp attributes ----
159+
llvm::SmallVector<uint8_t, 0> binU8Bytes;
160+
binU8Bytes.reserve(binBytes.size());
161+
for (auto b : binBytes)
162+
binU8Bytes.push_back(static_cast<uint8_t>(b));
163+
164+
auto byteAttr = mlir::DenseIntElementsAttr::get(
165+
mlir::RankedTensorType::get({static_cast<int64_t>(binU8Bytes.size())},
166+
mlir::IntegerType::get(ctx, 8)),
167+
binU8Bytes);
168+
169+
// launchOp->setAttr("cuda_binary", byteAttr);
170+
launchOp->setAttr("cuda_binary_size",
171+
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64),
172+
binU8Bytes.size()));
173+
launchOp->setAttr("cuda_binary_path",
174+
mlir::StringAttr::get(ctx, cudaBinPath.str()));
175+
launchOp->setAttr("cuda_arch", mlir::StringAttr::get(ctx, gpuName));
176+
});
177+
178+
// ---- Step F: Delete the cuda_tile.module ops ----
179+
llvm::SmallVector<mlir::Operation *, 32> toErase;
180+
top->walk([&](cuda_tile::ModuleOp op) { toErase.push_back(op); });
181+
182+
for (auto op : toErase) {
183+
op->erase();
184+
}
185+
};
186+
};
187+
} // namespace
188+
189+
namespace mlir::toy {
190+
191+
std::unique_ptr<mlir::Pass>
192+
createEmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName) {
193+
return std::make_unique<EmbedCudaTileBinaryPass>(tileirasExe, gpuName);
194+
};
195+
196+
}; // namespace mlir::toy

mlir/cuda-tile/Toy/mlir/LowerToCudaTile.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
#include "mlir/IR/BuiltinTypeInterfaces.h"
77
#include "mlir/IR/BuiltinTypes.h"
88
#include "mlir/IR/DialectRegistry.h"
9-
#include "mlir/IR/IRMapping.h"
109
#include "mlir/IR/Operation.h"
11-
#include "mlir/IR/SymbolTable.h"
1210
#include "mlir/IR/Types.h"
1311
#include "mlir/IR/Value.h"
1412
#include "mlir/Pass/Pass.h"
@@ -19,8 +17,6 @@
1917
#include "toy/Passes.h"
2018
#include "llvm/ADT/ArrayRef.h"
2119
#include "llvm/ADT/STLExtras.h"
22-
#include "llvm/ADT/SmallPtrSet.h"
23-
#include "llvm/ADT/SmallSet.h"
2420
#include "llvm/ADT/SmallVector.h"
2521
#include "llvm/ADT/StringExtras.h"
2622
#include "llvm/ADT/StringRef.h"

mlir/cuda-tile/Toy/toyc.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -335,23 +335,27 @@ static int loadAndProcessMLIRGPU(mlir::MLIRContext &context,
335335
optPM.addPass(mlir::toy::createGpuOutlinePass(assignGrid));
336336
// mlir::OpPassManager &gpuOptPM = pm.nest<mlir::toy::GPUFuncOp>();
337337
pm.addPass(mlir::toy::createCudaTileLoweringPass());
338+
pm.addPass(mlir::createCSEPass());
339+
338340
// pm.addPass(mlir::toy::createLowerGpuHostToLLVMPass());
339-
// bool isLoweringToAffine = emitAction >= Action::DumpGpuAffine;
340-
// if (isLoweringToAffine) {
341-
// // Partially lower the toy dialect.
342-
// optPM.addPass(mlir::toy::createLowerToAffinePass());
343-
344-
// // Add a few cleanups post lowering.
345-
// mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();
346-
// optPM.addPass(mlir::createCanonicalizerPass());
347-
// optPM.addPass(mlir::createCSEPass());
348-
349-
// // Add optimizations if enabled.
350-
// if (enableOpt) {
351-
// optPM.addPass(mlir::affine::createLoopFusionPass());
352-
// optPM.addPass(mlir::affine::createAffineScalarReplacementPass());
353-
// }
354-
// }
341+
bool isLoweringToAffine = emitAction >= Action::DumpGpuAffine;
342+
if (isLoweringToAffine) {
343+
pm.addPass(mlir::toy::createEmbedCudaTileBinaryPass(
344+
"/usr/local/cuda/bin/tileiras", "sm_120"));
345+
// // Partially lower the toy dialect.
346+
// optPM.addPass(mlir::toy::createLowerToAffinePass());
347+
348+
// // Add a few cleanups post lowering.
349+
// mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();
350+
// optPM.addPass(mlir::createCanonicalizerPass());
351+
// optPM.addPass(mlir::createCSEPass());
352+
353+
// // Add optimizations if enabled.
354+
// if (enableOpt) {
355+
// optPM.addPass(mlir::affine::createLoopFusionPass());
356+
// optPM.addPass(mlir::affine::createAffineScalarReplacementPass());
357+
// }
358+
}
355359

356360
if (mlir::failed(pm.run(*module)))
357361
return 4;

0 commit comments

Comments
 (0)