|
15 | 15 | #include "mlir/IR/BuiltinAttributes.h" |
16 | 16 | #include "mlir/IR/BuiltinDialect.h" |
17 | 17 | #include "mlir/IR/BuiltinOps.h" |
| 18 | +#include "mlir/IR/BuiltinTypeInterfaces.h" |
18 | 19 | #include "mlir/IR/BuiltinTypes.h" |
19 | 20 | #include "mlir/IR/Diagnostics.h" |
20 | 21 | #include "mlir/IR/DialectRegistry.h" |
|
31 | 32 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
32 | 33 | #include "mlir/Pass/Pass.h" |
33 | 34 | #include "mlir/Transforms/DialectConversion.h" |
| 35 | +#include "llvm/ADT/APFloat.h" |
34 | 36 | #include "llvm/ADT/ArrayRef.h" |
35 | 37 | #include "llvm/ADT/STLExtras.h" |
36 | 38 | #include "llvm/ADT/Sequence.h" |
| 39 | +#include "llvm/ADT/StringExtras.h" |
37 | 40 | #include "llvm/Support/Casting.h" |
38 | 41 | #include <algorithm> |
39 | 42 | #include <cstdint> |
@@ -299,6 +302,94 @@ struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> { |
299 | 302 | } |
300 | 303 | }; |
301 | 304 |
|
| 305 | +//===----------------------------------------------------------------------===// |
| 306 | +// ToyToAffine RewritePatterns: MatMul operations |
| 307 | +//===----------------------------------------------------------------------===// |
| 308 | + |
| 309 | +struct MatMulOpLowering : public ConversionPattern { |
| 310 | + MatMulOpLowering(MLIRContext *ctx) |
| 311 | + : ConversionPattern(toy::MatMulOp::getOperationName(), 1, ctx) {} |
| 312 | + |
| 313 | + LogicalResult |
| 314 | + matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 315 | + ConversionPatternRewriter &rewriter) const final { |
| 316 | + auto loc = op->getLoc(); |
| 317 | + |
| 318 | + RankedTensorType lhsType = |
| 319 | + llvm::dyn_cast<RankedTensorType>(op->getOperand(0).getType()); |
| 320 | + RankedTensorType rhsType = |
| 321 | + llvm::dyn_cast<RankedTensorType>(op->getOperand(1).getType()); |
| 322 | + auto lhsShape = lhsType.getShape(); |
| 323 | + auto rhsShape = rhsType.getShape(); |
| 324 | + |
| 325 | + auto tensorType = |
| 326 | + llvm::dyn_cast<RankedTensorType>((*op->result_type_begin())); |
| 327 | + |
| 328 | + auto elemType = llvm::dyn_cast<FloatType>(tensorType.getElementType()); |
| 329 | + |
| 330 | + // Insert an allocation and deallocation for the result of this operation. |
| 331 | + auto memRefType = convertTensorToMemRef(tensorType); |
| 332 | + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); |
| 333 | + |
| 334 | + SmallVector<int64_t, 4> lowerBounds(tensorType.getRank() + 1, /*Value=*/0); |
| 335 | + SmallVector<int64_t, 4> steps(tensorType.getRank() + 1, /*Value=*/1); |
| 336 | + SmallVector<int64_t, 4> upperBounds{lhsShape[0], rhsShape[0], rhsShape[1]}; |
| 337 | + |
| 338 | + // add initialization of result tensor. |
| 339 | + // Create a nest of affine loops to initialize the result tensor to 0. |
| 340 | + affine::buildAffineLoopNest( |
| 341 | + rewriter, loc, {0, 0}, tensorType.getShape(), {1, 1}, |
| 342 | + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { |
| 343 | + // Create a constant float value of 0.0. |
| 344 | + auto valueToStore = arith::ConstantFloatOp::create( |
| 345 | + nestedBuilder, loc, elemType, |
| 346 | + llvm::APFloat::getZero(elemType.getFloatSemantics())); |
| 347 | + |
| 348 | + // Store the constant value into the allocated memory. |
| 349 | + affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, |
| 350 | + ivs); |
| 351 | + }); |
| 352 | + |
| 353 | + // Create a nest of affine loops for matrix multiplication. |
| 354 | + affine::buildAffineLoopNest( |
| 355 | + rewriter, loc, lowerBounds, upperBounds, steps, |
| 356 | + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { |
| 357 | + // Extract loop induction variables. |
| 358 | + Value m = ivs[0]; |
| 359 | + Value k = ivs[1]; |
| 360 | + Value n = ivs[2]; |
| 361 | + |
| 362 | + // Create an adaptor for the remapped operands of the MatMulOp. |
| 363 | + toy::MatMulOpAdaptor matmulAdaptor(operands); |
| 364 | + |
| 365 | + // Load elements from the left-hand side and right-hand side matrices. |
| 366 | + auto loadedLhs = affine::AffineLoadOp::create( |
| 367 | + nestedBuilder, loc, matmulAdaptor.getLhs(), ValueRange{m, k}); |
| 368 | + |
| 369 | + auto loadedRhs = affine::AffineLoadOp::create( |
| 370 | + nestedBuilder, loc, matmulAdaptor.getRhs(), ValueRange{k, n}); |
| 371 | + // Load elements from the result tensor from initial process above. |
| 372 | + auto loadedRes = affine::AffineLoadOp::create( |
| 373 | + nestedBuilder, loc, alloc, ValueRange{m, n}); |
| 374 | + |
| 375 | + // Perform the multiplication and addition operations. |
| 376 | + auto mulop = |
| 377 | + arith::MulFOp::create(nestedBuilder, loc, loadedLhs, loadedRhs); |
| 378 | + auto valueToStore = |
| 379 | + arith::AddFOp::create(nestedBuilder, loc, loadedRes, mulop); |
| 380 | + |
| 381 | + // Store the result back into the allocated memory. |
| 382 | + affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, |
| 383 | + ValueRange{m, n}); |
| 384 | + }); |
| 385 | + |
| 386 | + // Replace this operation with the generated alloc. |
| 387 | + rewriter.replaceOp(op, alloc); |
| 388 | + |
| 389 | + return success(); |
| 390 | + } |
| 391 | +}; |
| 392 | + |
302 | 393 | } // namespace |
303 | 394 |
|
304 | 395 | //===----------------------------------------------------------------------===// |
@@ -350,8 +441,8 @@ void ToyToAffineLoweringPass::runOnOperation() { |
350 | 441 | // the set of patterns that will lower the Toy operations. |
351 | 442 | RewritePatternSet patterns(&getContext()); |
352 | 443 | patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering, |
353 | | - PrintOpLowering, ReturnOpLowering, TransposeOpLowering>( |
354 | | - &getContext()); |
| 444 | + PrintOpLowering, ReturnOpLowering, TransposeOpLowering, |
| 445 | + MatMulOpLowering>(&getContext()); |
355 | 446 |
|
356 | 447 | // With the target and rewrite patterns defined, we can now attempt the |
357 | 448 | // conversion. The conversion will signal failure if any of our `illegal` |
|
0 commit comments