Skip to content

Commit 2b310d3

Browse files
committed
Added Matmul Toy Op
1 parent e3204ff commit 2b310d3

File tree

6 files changed

+213
-2
lines changed

6 files changed

+213
-2
lines changed

mlir/cuda-tile/Toy/include/toy/Ops.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,4 +374,33 @@ def TransposeOp : Toy_Op<"transpose",
374374
let hasVerifier = 1;
375375
}
376376

377+
//===----------------------------------------------------------------------===//
// MatMul Op
//===----------------------------------------------------------------------===//

def MatMulOp : Toy_Op<"matmul",
    [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, MemoryEffectsOpInterface]> {
  let summary = "matrix multiplication operation";
  let description = [{
    The "matmul" operation performs matrix multiplication between two 2-D
    tensors. The inner dimensions of the operands must match: for an lhs of
    shape [m, k] and an rhs of shape [k, n] the result has shape [m, n].
  }];

  // NOTE(review): `Pure` implies the op has no memory effects, which
  // conflicts with the MemWrite/MemAlloc effects declared on the result
  // below — confirm which of the two is actually intended.
  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs Res<F64Tensor, "",
                          [MemWrite<DefaultResource>,
                           MemAlloc<DefaultResource>]>:$output);

  let assemblyFormat = [{
    `(` $lhs `:` type($lhs) `,` $rhs `:` type($rhs) `)` attr-dict `to` type(results)
  }];

  // Allow building a MatMulOp from the two input operands.
  let builders = [
    OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
  ];

  let hasVerifier = 1;
}
405+
377406
#endif // TOY_OPS

mlir/cuda-tile/Toy/mlir/Dialect.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,58 @@ llvm::LogicalResult TransposeOp::verify() {
437437
return mlir::success();
438438
}
439439

440+
//===----------------------------------------------------------------------===//
441+
// MatMulOp
442+
//===----------------------------------------------------------------------===//
443+
444+
void MatMulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
445+
mlir::Value lhs, mlir::Value rhs) {
446+
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
447+
state.addOperands({lhs, rhs});
448+
}
449+
450+
/// Infer the output shape of the MatMulOp, this is required by the shape
451+
/// inference interface.
452+
void MatMulOp::inferShapes() {
453+
RankedTensorType lhsType =
454+
llvm::dyn_cast<RankedTensorType>(getLhs().getType());
455+
RankedTensorType rhsType =
456+
llvm::dyn_cast<RankedTensorType>(getRhs().getType());
457+
auto lhsShape = lhsType.getShape();
458+
auto rhsShape = rhsType.getShape();
459+
RankedTensorType res_type = RankedTensorType::get({lhsShape[0], rhsShape[1]},
460+
lhsType.getElementType());
461+
getResult().setType(res_type);
462+
}
463+
464+
/// Verify that the operand and result shapes of a MatMulOp are consistent:
/// both operands must be 2-D and the inner dimensions must agree. Unranked
/// types are accepted (shapes are checked again after inference).
llvm::LogicalResult MatMulOp::verify() {
  auto lhsType = llvm::dyn_cast<RankedTensorType>(getLhs().getType());
  auto rhsType = llvm::dyn_cast<RankedTensorType>(getRhs().getType());
  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());

  // If any type is still unranked there is nothing to check yet.
  if (!lhsType || !rhsType || !resultType)
    return mlir::success();

  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.size() != 2 || rhsShape.size() != 2) {
    return emitOpError() << "expected 2D matrix";
  }

  // [m, k] x [k, n]: the inner dimensions must agree.
  // Fixed: the original diagnostic was missing separators ("matchthe shape",
  // "3!=2") and ended with a stray '\n'.
  if (lhsShape[1] != rhsShape[0]) {
    return emitOpError() << "expected dimensions to match: "
                         << "the shape of lhs is [" << lhsShape[0] << ", "
                         << lhsShape[1] << "] "
                         << "the shape of rhs is [" << rhsShape[0] << ", "
                         << rhsShape[1] << "] "
                         << "but the dimension " << lhsShape[1]
                         << " != " << rhsShape[0];
  }

  return mlir::success();
}
491+
440492
//===----------------------------------------------------------------------===//
441493
// TableGen'd op method definitions
442494
//===----------------------------------------------------------------------===//

mlir/cuda-tile/Toy/mlir/LowerToAffineLoops.cpp

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/BuiltinAttributes.h"
1616
#include "mlir/IR/BuiltinDialect.h"
1717
#include "mlir/IR/BuiltinOps.h"
18+
#include "mlir/IR/BuiltinTypeInterfaces.h"
1819
#include "mlir/IR/BuiltinTypes.h"
1920
#include "mlir/IR/Diagnostics.h"
2021
#include "mlir/IR/DialectRegistry.h"
@@ -31,9 +32,11 @@
3132
#include "mlir/Dialect/MemRef/IR/MemRef.h"
3233
#include "mlir/Pass/Pass.h"
3334
#include "mlir/Transforms/DialectConversion.h"
35+
#include "llvm/ADT/APFloat.h"
3436
#include "llvm/ADT/ArrayRef.h"
3537
#include "llvm/ADT/STLExtras.h"
3638
#include "llvm/ADT/Sequence.h"
39+
#include "llvm/ADT/StringExtras.h"
3740
#include "llvm/Support/Casting.h"
3841
#include <algorithm>
3942
#include <cstdint>
@@ -299,6 +302,94 @@ struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
299302
}
300303
};
301304

305+
//===----------------------------------------------------------------------===//
306+
// ToyToAffine RewritePatterns: MatMul operations
307+
//===----------------------------------------------------------------------===//
308+
309+
struct MatMulOpLowering : public ConversionPattern {
310+
MatMulOpLowering(MLIRContext *ctx)
311+
: ConversionPattern(toy::MatMulOp::getOperationName(), 1, ctx) {}
312+
313+
LogicalResult
314+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
315+
ConversionPatternRewriter &rewriter) const final {
316+
auto loc = op->getLoc();
317+
318+
RankedTensorType lhsType =
319+
llvm::dyn_cast<RankedTensorType>(op->getOperand(0).getType());
320+
RankedTensorType rhsType =
321+
llvm::dyn_cast<RankedTensorType>(op->getOperand(1).getType());
322+
auto lhsShape = lhsType.getShape();
323+
auto rhsShape = rhsType.getShape();
324+
325+
auto tensorType =
326+
llvm::dyn_cast<RankedTensorType>((*op->result_type_begin()));
327+
328+
auto elemType = llvm::dyn_cast<FloatType>(tensorType.getElementType());
329+
330+
// Insert an allocation and deallocation for the result of this operation.
331+
auto memRefType = convertTensorToMemRef(tensorType);
332+
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
333+
334+
SmallVector<int64_t, 4> lowerBounds(tensorType.getRank() + 1, /*Value=*/0);
335+
SmallVector<int64_t, 4> steps(tensorType.getRank() + 1, /*Value=*/1);
336+
SmallVector<int64_t, 4> upperBounds{lhsShape[0], rhsShape[0], rhsShape[1]};
337+
338+
// add initialization of result tensor.
339+
// Create a nest of affine loops to initialize the result tensor to 0.
340+
affine::buildAffineLoopNest(
341+
rewriter, loc, {0, 0}, tensorType.getShape(), {1, 1},
342+
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
343+
// Create a constant float value of 0.0.
344+
auto valueToStore = arith::ConstantFloatOp::create(
345+
nestedBuilder, loc, elemType,
346+
llvm::APFloat::getZero(elemType.getFloatSemantics()));
347+
348+
// Store the constant value into the allocated memory.
349+
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
350+
ivs);
351+
});
352+
353+
// Create a nest of affine loops for matrix multiplication.
354+
affine::buildAffineLoopNest(
355+
rewriter, loc, lowerBounds, upperBounds, steps,
356+
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
357+
// Extract loop induction variables.
358+
Value m = ivs[0];
359+
Value k = ivs[1];
360+
Value n = ivs[2];
361+
362+
// Create an adaptor for the remapped operands of the MatMulOp.
363+
toy::MatMulOpAdaptor matmulAdaptor(operands);
364+
365+
// Load elements from the left-hand side and right-hand side matrices.
366+
auto loadedLhs = affine::AffineLoadOp::create(
367+
nestedBuilder, loc, matmulAdaptor.getLhs(), ValueRange{m, k});
368+
369+
auto loadedRhs = affine::AffineLoadOp::create(
370+
nestedBuilder, loc, matmulAdaptor.getRhs(), ValueRange{k, n});
371+
// Load elements from the result tensor from initial process above.
372+
auto loadedRes = affine::AffineLoadOp::create(
373+
nestedBuilder, loc, alloc, ValueRange{m, n});
374+
375+
// Perform the multiplication and addition operations.
376+
auto mulop =
377+
arith::MulFOp::create(nestedBuilder, loc, loadedLhs, loadedRhs);
378+
auto valueToStore =
379+
arith::AddFOp::create(nestedBuilder, loc, loadedRes, mulop);
380+
381+
// Store the result back into the allocated memory.
382+
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
383+
ValueRange{m, n});
384+
});
385+
386+
// Replace this operation with the generated alloc.
387+
rewriter.replaceOp(op, alloc);
388+
389+
return success();
390+
}
391+
};
392+
302393
} // namespace
303394

304395
//===----------------------------------------------------------------------===//
@@ -350,8 +441,8 @@ void ToyToAffineLoweringPass::runOnOperation() {
350441
// the set of patterns that will lower the Toy operations.
351442
RewritePatternSet patterns(&getContext());
352443
patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
353-
PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
354-
&getContext());
444+
PrintOpLowering, ReturnOpLowering, TransposeOpLowering,
445+
MatMulOpLowering>(&getContext());
355446

356447
// With the target and rewrite patterns defined, we can now attempt the
357448
// conversion. The conversion will signal failure if any of our `illegal`

mlir/cuda-tile/Toy/mlir/MLIRGen.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,15 @@ class MLIRGenImpl {
331331
return TransposeOp::create(builder, location, operands[0]);
332332
}
333333

334+
if (callee == "matmul") {
335+
if (call.getArgs().size() != 2) {
336+
emitError(location, "MLIR codegen encountered an error: toy.matmul "
337+
"expected 2 arguments");
338+
return nullptr;
339+
}
340+
return MatMulOp::create(builder, location, operands[0], operands[1]);
341+
}
342+
334343
// Otherwise this is a call to a user-defined function. Calls to
335344
// user-defined functions are mapped to a custom call that takes the callee
336345
// name as an attribute.

mlir/cuda-tile/sample/matmul.toy

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  # The shape is inferred from the supplied literal.
  var a = [[1, 2, 3], [4, 5, 6]];

  # b is identical to a, the literal tensor is implicitly reshaped: defining new
  # variables is the way to reshape tensors (element count must match).
  var b<2, 3> = [1, 2, 3, 4, 5, 6];

  # Transpose b to <3, 2> so the inner dimensions line up, then compute the
  # matrix product of a (<2, 3>) and transpose(b) (<3, 2>) with the matmul
  # builtin and print the <2, 2> result.
  print(matmul(a, transpose(b)));
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Toy-dialect IR exercising toy.matmul. The bare "N+" lines below are diff
// markers retained from the commit view this text was captured from.
// Helper: transposes both (unranked) operands, then multiplies them.
toy.func private @matmul_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
2+
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
3+
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
4+
%2 = toy.matmul(%0 : tensor<*xf64>, %1 : tensor<*xf64>) to tensor<*xf64>
5+
toy.return %2 : tensor<*xf64>
6+
}
7+
8+
// Entry point: builds a 2x3 constant and a 3x2 reshaped constant, calls
// @matmul_transpose on them, and prints the (unranked) result.
toy.func @main() {
9+
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
10+
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
11+
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
12+
%3 = toy.reshape(%2 : tensor<6xf64>) to tensor<3x2xf64>
13+
%4 = toy.generic_call @matmul_transpose(%1, %3) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<*xf64>
14+
toy.print %4 : tensor<*xf64>
15+
toy.return
16+
}

0 commit comments

Comments
 (0)