|
| 1 | +#include "cuda_tile/Dialect/CudaTile/IR/Types.h" |
| 2 | +#include "mlir/IR/Attributes.h" |
| 3 | +#include "mlir/IR/Block.h" |
| 4 | +#include "mlir/IR/Builders.h" |
| 5 | +#include "mlir/IR/BuiltinOps.h" |
| 6 | +#include "mlir/IR/BuiltinTypes.h" |
| 7 | +#include "mlir/IR/DialectRegistry.h" |
| 8 | +#include "mlir/IR/IRMapping.h" |
| 9 | +#include "mlir/IR/Operation.h" |
| 10 | +#include "mlir/IR/SymbolTable.h" |
| 11 | +#include "mlir/IR/Types.h" |
| 12 | +#include "mlir/IR/Value.h" |
| 13 | +#include "mlir/Pass/Pass.h" |
| 14 | +#include "mlir/Support/LLVM.h" |
| 15 | +#include "mlir/Support/TypeID.h" |
| 16 | +#include "toy/Dialect.h" |
| 17 | +#include "toy/Passes.h" |
| 18 | +#include "llvm/ADT/STLExtras.h" |
| 19 | +#include "llvm/ADT/SmallPtrSet.h" |
| 20 | +#include "llvm/ADT/SmallSet.h" |
| 21 | +#include "llvm/ADT/SmallVector.h" |
| 22 | +#include "llvm/ADT/StringExtras.h" |
| 23 | +#include "llvm/ADT/StringRef.h" |
| 24 | +#include "llvm/Support/Casting.h" |
| 25 | +#include "llvm/Support/DebugLog.h" |
| 26 | + |
| 27 | +#include "cuda_tile/Dialect/CudaTile/IR/Dialect.h" |
| 28 | +#include "cuda_tile/Dialect/CudaTile/IR/Ops.h" |
| 29 | + |
| 30 | +#include <memory> |
| 31 | +#include <string> |
| 32 | + |
| 33 | +#define DEBUG_TYPE "toy-to-cuda-tile" |
| 34 | + |
| 35 | +//===----------------------------------------------------------------------===// |
| 36 | +// ToyToCudaTileLoweringPass |
| 37 | +//===----------------------------------------------------------------------===// |
| 38 | + |
| 39 | +namespace { |
| 40 | +struct ToyToCudaTileLoweringPass |
| 41 | + : public mlir::PassWrapper<ToyToCudaTileLoweringPass, |
| 42 | + mlir::OperationPass<mlir::ModuleOp>> { |
| 43 | + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToCudaTileLoweringPass) |
| 44 | + |
| 45 | + llvm::StringRef getArgument() const override { return "toy-to-cuda-tile"; } |
| 46 | + |
| 47 | + void getDependentDialects(mlir::DialectRegistry ®istry) const override { |
| 48 | + registry.insert<mlir::cuda_tile::CudaTileDialect>(); |
| 49 | + } |
| 50 | + |
| 51 | + void runOnOperation() final; |
| 52 | +}; |
| 53 | +}; // namespace |
| 54 | + |
| 55 | +mlir::cuda_tile::ModuleOp createCudaModuleOp(mlir::OpBuilder &builder, |
| 56 | + mlir::ModuleOp &moduleOp) { |
| 57 | + mlir::OpBuilder::InsertionGuard guard(builder); |
| 58 | + |
| 59 | + builder.setInsertionPoint(moduleOp.getBody(), moduleOp.getBody()->end()); |
| 60 | + auto cudaTileModuleOp = mlir::cuda_tile::ModuleOp::create( |
| 61 | + builder, moduleOp.getLoc(), "cuda_tile_module"); |
| 62 | + |
| 63 | + LDBG() << "Created CudaTile Module: \n" << cudaTileModuleOp << "\n"; |
| 64 | + return cudaTileModuleOp; |
| 65 | +} |
| 66 | + |
| 67 | +void ToyToCudaTileLoweringPass::runOnOperation() { |
| 68 | + auto moduleOp = getOperation(); |
| 69 | + |
| 70 | + // Here we would implement the actual lowering logic from Toy GPUFuncOp |
| 71 | + // to CudaTile operations. For now, we just log that the pass is running. |
| 72 | + // LDBG() << "Running Toy to CudaTile lowering on GPUFuncOp: " << moduleOp |
| 73 | + // << "\n"; |
| 74 | + |
| 75 | + mlir::OpBuilder builder(moduleOp.getContext()); |
| 76 | + // 1. Create new cuda_tile.module Op in the last section. |
| 77 | + auto cudaTileModuleOp = createCudaModuleOp(builder, moduleOp); |
| 78 | + // mlir::SymbolTable cudaTileSymbolTable(cudaTileModuleOp); |
| 79 | + |
| 80 | + moduleOp->walk([&](mlir::toy::GPUFuncOp gfunOp) { |
| 81 | + mlir::OpBuilder::InsertionGuard guard(builder); |
| 82 | + // setInsertionPointToEnd expects a Block*, so take the address of the |
| 83 | + // single block inside the cuda_tile.module region. |
| 84 | + builder.setInsertionPointToEnd(&cudaTileModuleOp.getBodyRegion().front()); |
| 85 | + auto gfunc_name = |
| 86 | + gfunOp->getAttrOfType<mlir::StringAttr>("sym_name").getValue(); |
| 87 | + llvm::SmallVector<mlir::Type, 8> newArgTypes; |
| 88 | + |
| 89 | + LDBG() << "Lowering GPU function: " << gfunc_name << "\n"; |
| 90 | + LDBG() << "Converting input type into cuda tile type" << "\n"; |
| 91 | + |
| 92 | + for (mlir::Type t : gfunOp.getFunctionType().getInputs()) { |
| 93 | + LDBG() << "Original arg type: " << t << "\n"; |
| 94 | + auto tt = llvm::dyn_cast<mlir::TensorType>(t); |
| 95 | + auto elemType = tt.getElementType(); |
| 96 | + auto ptrElem = mlir::cuda_tile::PointerType::get(elemType); |
| 97 | + auto newType = mlir::cuda_tile::TileType::get({}, ptrElem); |
| 98 | + LDBG() << "The new arg type for cuda tile: " << newType << "\n"; |
| 99 | + newArgTypes.push_back(newType); |
| 100 | + } |
| 101 | + |
| 102 | + LDBG() << "Converting result type into cuda tile type" << "\n"; |
| 103 | + for (mlir::Type t : gfunOp.getFunctionType().getResults()) { |
| 104 | + LDBG() << "Original result type: " << t << "\n"; |
| 105 | + auto tt = llvm::dyn_cast<mlir::TensorType>(t); |
| 106 | + auto elemType = tt.getElementType(); |
| 107 | + auto ptrElem = mlir::cuda_tile::PointerType::get(elemType); |
| 108 | + auto newType = mlir::cuda_tile::TileType::get({}, ptrElem); |
| 109 | + LDBG() << "The new arg type for cuda tile: " << newType << "\n"; |
| 110 | + newArgTypes.push_back(newType); |
| 111 | + } |
| 112 | + |
| 113 | + auto newFnType = builder.getFunctionType(newArgTypes, {}); |
| 114 | + auto fname = builder.getStringAttr(gfunc_name); |
| 115 | + auto argTypes = builder.getTypeArrayAttr(newArgTypes); |
| 116 | + auto cudaEntryOp = mlir::cuda_tile::EntryOp::create( |
| 117 | + builder, gfunOp.getLoc(), fname, newFnType, |
| 118 | + /*arg_attrs=*/{}, /*res_attrs=*/{}, {}); |
| 119 | + auto bb = cudaEntryOp.addEntryBlock(); |
| 120 | + builder.setInsertionPointToStart(bb); |
| 121 | + auto retOp = mlir::cuda_tile::ReturnOp::create(builder, gfunOp.getLoc()); |
| 122 | + |
| 123 | + LDBG() << "Created CudaTile Entry Op: \n" << cudaEntryOp << "\n"; |
| 124 | + }); |
| 125 | +} |
| 126 | + |
| 127 | +namespace mlir::toy { |
| 128 | + |
| 129 | +std::unique_ptr<mlir::Pass> createCudaTileLoweringPass() { |
| 130 | + return std::make_unique<ToyToCudaTileLoweringPass>(); |
| 131 | +}; |
| 132 | + |
| 133 | +}; // namespace mlir::toy |
0 commit comments