|
| 1 | +// SPDX-License-Identifier: BSD-3-Clause |
| 2 | + |
| 3 | +#include "mlir/IR/MLIRContext.h" |
| 4 | +#include "mlir/InitAllDialects.h" |
| 5 | + |
| 6 | +static mlir::Type makeSignlessType(mlir::Type type) |
| 7 | +{ |
| 8 | + if (auto shaped = type.dyn_cast<mlir::ShapedType>()) { |
| 9 | + auto origElemType = shaped.getElementType(); |
| 10 | + auto signlessElemType = makeSignlessType(origElemType); |
| 11 | + return shaped.clone(signlessElemType); |
| 12 | + } else if (auto intType = type.dyn_cast<mlir::IntegerType>()) { |
| 13 | + if (!intType.isSignless()) |
| 14 | + return mlir::IntegerType::get(intType.getContext(), intType.getWidth()); |
| 15 | + } |
| 16 | + return type; |
| 17 | +} |
| 18 | + |
| 19 | +auto getInt(const mlir::Location & loc, mlir::OpBuilder & builder, int64_t val) |
| 20 | +{ |
| 21 | + auto attr = builder.getI64IntegerAttr(val); |
| 22 | + return builder.create<mlir::arith::ConstantOp>(loc, attr); |
| 23 | + // auto intType = builder.getIntegerType(64, true); |
| 24 | + // return builder.create<plier::SignCastOp>(loc, intType, res); |
| 25 | +} |
| 26 | + |
| 27 | +void ttt() |
| 28 | +{ |
| 29 | + std::vector<int> shape = {16, 16}; |
| 30 | + std::string fname("ttt_mlir"); |
| 31 | + |
| 32 | + mlir::MLIRContext context; |
| 33 | + context.getOrLoadDialect<mlir::arith::ArithmeticDialect>(); |
| 34 | + context.getOrLoadDialect<mlir::linalg::LinalgDialect>(); |
| 35 | + mlir::OpBuilder builder(&context); |
| 36 | + auto theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); |
| 37 | + auto loc = builder.getUnknownLoc(); |
| 38 | + |
| 39 | + // Create a func prototype |
| 40 | + llvm::SmallVector<mlir::Type, 4> argTypes(0); |
| 41 | + auto funcType = builder.getFunctionType(argTypes, llvm::None); |
| 42 | + auto fproto = mlir::FuncOp::create(loc, fname, funcType); |
| 43 | + |
| 44 | + // Create an MLIR function for the given prototype. |
| 45 | + mlir::FuncOp function(fproto); |
| 46 | + assert(function); |
| 47 | + |
| 48 | + // Let's start the body of the function now! |
| 49 | + // In MLIR the entry block of the function is special: it must have the same |
| 50 | + // argument list as the function itself. |
| 51 | + auto &entryBlock = *function.addEntryBlock(); |
| 52 | + |
| 53 | + // Set the insertion point in the builder to the beginning of the function |
| 54 | + // body, it will be used throughout the codegen to create operations in this |
| 55 | + // function. |
| 56 | + builder.setInsertionPointToStart(&entryBlock); |
| 57 | + |
| 58 | + auto elemType = builder.getF64Type(); |
| 59 | + auto signlessElemType = makeSignlessType(elemType); |
| 60 | + auto indexType = builder.getIndexType(); |
| 61 | + auto count = shape.size(); |
| 62 | + llvm::SmallVector<mlir::Value> shapeVal(count); |
| 63 | + llvm::SmallVector<int64_t> staticShape(count); // mlir::ShapedType::kDynamicSize); |
| 64 | + |
| 65 | + for(auto it : llvm::enumerate(shape)) { |
| 66 | + auto i = it.index(); |
| 67 | + auto elem = it.value(); |
| 68 | + auto elemVal = getInt(loc, builder, elem); |
| 69 | + staticShape[i] = elem; |
| 70 | + shapeVal[i] = elemVal; |
| 71 | + } |
| 72 | + |
| 73 | + mlir::Value init; |
| 74 | + if(true) { //initVal.is_none()) { |
| 75 | + init = builder.create<mlir::linalg::InitTensorOp>(loc, shapeVal, signlessElemType); |
| 76 | + }// else { |
| 77 | + // auto val = doCast(builder, loc, ctx.context.unwrapVal(loc, builder, initVal), signlessElemType); |
| 78 | + // llvm::SmallVector<int64_t> shape(count, mlir::ShapedType::kDynamicSize); |
| 79 | + // auto type = mlir::RankedTensorType::get(shape, signlessElemType); |
| 80 | + // auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/) { |
| 81 | + // builder.create<mlir::tensor::YieldOp>(loc, val); |
| 82 | + // }; |
| 83 | + // init = builder.create<mlir::tensor::GenerateOp>(loc, type, shapeVal, body); |
| 84 | + // } |
| 85 | + if (llvm::any_of(staticShape, [](auto val) { return val >= 0; })) { |
| 86 | + auto newType = mlir::RankedTensorType::get(staticShape, signlessElemType); |
| 87 | + init = builder.create<mlir::tensor::CastOp>(loc, newType, init); |
| 88 | + } |
| 89 | + auto resTensorTypeSigness = init.getType().cast<mlir::RankedTensorType>(); |
| 90 | + auto resTensorType = mlir::RankedTensorType::get(resTensorTypeSigness.getShape(), elemType, resTensorTypeSigness.getEncoding()); |
| 91 | +} |
0 commit comments