Skip to content

Commit fd70514

Browse files
committed
Added Affine pass code
1 parent 37800da commit fd70514

1 file changed

Lines changed: 21 additions & 1 deletion

File tree

mlir/cuda-tile/Toy/toyc.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ enum Action {
8181
RunJIT,
8282
DumpGpuIR,
8383
DumpCudaTileIR,
84+
DumpGpuAffine,
8485
DumpGPULLVMIR,
8586
RunNVGPUJIT
8687
};
@@ -101,6 +102,9 @@ static cl::opt<enum Action> emitAction(
101102
"output the GPU dialect MLIR dump")),
102103
cl::values(clEnumValN(DumpCudaTileIR, "cuda-tile-ir",
103104
"output the Cuda Tile dialect MLIR dump")),
105+
cl::values(clEnumValN(DumpGpuAffine, "gpu-affine",
106+
"output the GPU dialect MLIR dump after affine "
107+
"lowering")),
104108
cl::values(clEnumValN(DumpGPULLVMIR, "gpu-llvm",
105109
"output the GPU LLVM dialect MLIR dump")),
106110
cl::values(clEnumValN(RunNVGPUJIT, "nv-gpu-jit",
@@ -329,8 +333,24 @@ static int loadAndProcessMLIRGPU(mlir::MLIRContext &context,
329333

330334
// Now process the toy mlir with gpu outline pass.
331335
optPM.addPass(mlir::toy::createGpuOutlinePass(assignGrid));
332-
// pm.addPass(mlir::toy::createCudaTileLoweringPass(assignGrid));
336+
// pm.addPass(mlir::toy::createCudaTileLoweringPass());
333337
// pm.addPass(mlir::toy::createLowerGpuHostToLLVMPass());
338+
// bool isLoweringToAffine = emitAction >= Action::DumpGpuAffine;
339+
// if (isLoweringToAffine) {
340+
// // Partially lower the toy dialect.
341+
// optPM.addPass(mlir::toy::createLowerToAffinePass());
342+
343+
// // Add a few cleanups post lowering.
344+
// mlir::OpPassManager &optPM = pm.nest<mlir::func::FuncOp>();
345+
// optPM.addPass(mlir::createCanonicalizerPass());
346+
// optPM.addPass(mlir::createCSEPass());
347+
348+
// // Add optimizations if enabled.
349+
// if (enableOpt) {
350+
// optPM.addPass(mlir::affine::createLoopFusionPass());
351+
// optPM.addPass(mlir::affine::createAffineScalarReplacementPass());
352+
// }
353+
// }
334354

335355
if (mlir::failed(pm.run(*module)))
336356
return 4;

0 commit comments

Comments
 (0)