Skip to content

Commit 8a0ba30

Browse files
committed
Change float type to FP32
1 parent 88ce121 commit 8a0ba30

File tree

4 files changed

+56
-44
lines changed

4 files changed

+56
-44
lines changed

mlir/cuda-tile/Toy/include/toy/Ops.td

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ include "mlir/Interfaces/CastInterfaces.td"
2020
include "mlir/Interfaces/SideEffectInterfaces.td"
2121
include "toy/ShapeInferenceInterface.td"
2222

23+
def F32ElementsAttr : FloatElementsAttr<32>;
24+
2325
// Provide a definition of the 'toy' dialect in the ODS framework so that we
2426
// can define our operations.
2527
def Toy_Dialect : Dialect {
@@ -57,15 +59,15 @@ def ConstantOp : Toy_Op<"constant", [Pure]> {
5759

5860
```mlir
5961
%0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>
60-
: tensor<2x3xf64>
62+
: tensor<2x3xf32>
6163
```
6264
}];
6365

6466
// The constant operation takes an attribute as the only input.
65-
let arguments = (ins F64ElementsAttr:$value);
67+
let arguments = (ins F32ElementsAttr:$value);
6668

6769
// The constant operation returns a single value of TensorType.
68-
let results = (outs F64Tensor);
70+
let results = (outs F32Tensor);
6971

7072
// Indicate that the operation has a custom parser and printer method.
7173
let hasCustomAssemblyFormat = 1;
@@ -80,7 +82,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> {
8082
}]>,
8183

8284
// Build a constant with a given constant floating-point value.
83-
OpBuilder<(ins "double":$value)>
85+
OpBuilder<(ins "float":$value)>
8486
];
8587

8688
// Indicate that additional verification for this operation is necessary.
@@ -99,8 +101,8 @@ def AddOp : Toy_Op<"add",
99101
The shapes of the tensor operands are expected to match.
100102
}];
101103

102-
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
103-
let results = (outs F64Tensor);
104+
let arguments = (ins F32Tensor:$lhs, F32Tensor:$rhs);
105+
let results = (outs F32Tensor);
104106

105107
// Indicate that the operation has a custom parser and printer method.
106108
let hasCustomAssemblyFormat = 1;
@@ -130,8 +132,8 @@ def CastOp : Toy_Op<"cast", [
130132
mismatching constant dimension.
131133
}];
132134

133-
let arguments = (ins F64Tensor:$input);
134-
let results = (outs F64Tensor:$output);
135+
let arguments = (ins F32Tensor:$input);
136+
let results = (outs F32Tensor:$output);
135137

136138
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
137139
}
@@ -152,9 +154,9 @@ def FuncOp : Toy_Op<"func", [
152154

153155
```mlir
154156
toy.func @main() {
155-
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
156-
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
157-
toy.print %1 : tensor<2x2xf64>
157+
%0 = toy.constant dense<5.500000e+00> : tensor<f32>
158+
%1 = toy.reshape(%0 : tensor<f32>) to tensor<2x2xf32>
159+
toy.print %1 : tensor<2x2xf32>
158160
toy.return
159161
}
160162
```
@@ -205,7 +207,7 @@ def GenericCallOp : Toy_Op<"generic_call",
205207

206208
```mlir
207209
%4 = toy.generic_call @my_func(%1, %3)
208-
: (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
210+
: (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32>
209211
```
210212

211213
This is only valid if a function named "my_func" exists and takes two
@@ -216,13 +218,13 @@ def GenericCallOp : Toy_Op<"generic_call",
216218
// callee, and inputs for the call.
217219
let arguments = (ins
218220
FlatSymbolRefAttr:$callee,
219-
Variadic<F64Tensor>:$inputs,
221+
Variadic<F32Tensor>:$inputs,
220222
OptionalAttr<DictArrayAttr>:$arg_attrs,
221223
OptionalAttr<DictArrayAttr>:$res_attrs
222224
);
223225

224226
// The generic call operation returns a single value of TensorType.
225-
let results = (outs F64Tensor);
227+
let results = (outs F32Tensor);
226228

227229
// Specialize assembly printing and parsing using a declarative format.
228230
let assemblyFormat = [{
@@ -247,8 +249,8 @@ def MulOp : Toy_Op<"mul",
247249
tensors. The shapes of the tensor operands are expected to match.
248250
}];
249251

250-
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
251-
let results = (outs F64Tensor);
252+
let arguments = (ins F32Tensor:$lhs, F32Tensor:$rhs);
253+
let results = (outs F32Tensor);
252254

253255
// Indicate that the operation has a custom parser and printer method.
254256
let hasCustomAssemblyFormat = 1;
@@ -271,8 +273,8 @@ def PrintOp : Toy_Op<"print"> {
271273
}];
272274

273275
// The print operation takes an input tensor to print.
274-
// We also allow a F64MemRef to enable interop during partial lowering.
275-
let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
276+
// We also allow a F32MemRef to enable interop during partial lowering.
277+
let arguments = (ins AnyTypeOf<[F32Tensor, F32MemRef]>:$input);
276278

277279
let assemblyFormat = "$input attr-dict `:` type($input)";
278280
}
@@ -288,11 +290,11 @@ def ReshapeOp : Toy_Op<"reshape", [Pure]> {
288290
the same number of elements but different shapes. For example:
289291

290292
```mlir
291-
%0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64>
293+
%0 = toy.reshape (%arg1 : tensor<10xf32>) to tensor<5x2xf32>
292294
```
293295
}];
294296

295-
let arguments = (ins F64Tensor:$input);
297+
let arguments = (ins F32Tensor:$input);
296298

297299
let assemblyFormat = [{
298300
`(` $input `:` type($input) `)` attr-dict `to` type(results)
@@ -302,7 +304,7 @@ def ReshapeOp : Toy_Op<"reshape", [Pure]> {
302304
let hasCanonicalizer = 1;
303305

304306
// We expect that the reshape operation returns a statically shaped tensor.
305-
let results = (outs StaticShapeTensorOf<[F64]>);
307+
let results = (outs StaticShapeTensorOf<[F32]>);
306308
}
307309

308310
//===----------------------------------------------------------------------===//
@@ -319,16 +321,16 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">,
319321
the operation. For example:
320322

321323
```mlir
322-
toy.func @foo() -> tensor<2xf64> {
324+
toy.func @foo() -> tensor<2xf32> {
323325
...
324-
toy.return %0 : tensor<2xf64>
326+
toy.return %0 : tensor<2xf32>
325327
}
326328
```
327329
}];
328330

329331
// The return operation takes an optional input operand to return. This
330332
// value must match the return type of the enclosing function.
331-
let arguments = (ins Variadic<F64Tensor>:$input);
333+
let arguments = (ins Variadic<F32Tensor>:$input);
332334

333335
// The return operation only emits the input in the format if it is present.
334336
let assemblyFormat = "($input^ `:` type($input))? attr-dict ";
@@ -355,8 +357,8 @@ def TransposeOp : Toy_Op<"transpose",
355357
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
356358
let summary = "transpose operation";
357359

358-
let arguments = (ins F64Tensor:$input);
359-
let results = (outs F64Tensor);
360+
let arguments = (ins F32Tensor:$input);
361+
let results = (outs F32Tensor);
360362

361363
let assemblyFormat = [{
362364
`(` $input `:` type($input) `)` attr-dict `to` type(results)
@@ -386,8 +388,8 @@ def MatMulOp : Toy_Op<"matmul",
386388
tensors. The shapes of the tensor operands are expected to match.
387389
}];
388390

389-
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
390-
let results = (outs Res<F64Tensor, "",
391+
let arguments = (ins F32Tensor:$lhs, F32Tensor:$rhs);
392+
let results = (outs Res<F32Tensor, "",
391393
[MemWrite<DefaultResource>,
392394
MemAlloc<DefaultResource>]>:$output);
393395

mlir/cuda-tile/Toy/mlir/Dialect.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
169169
/// The builder is passed as an argument, so is the state that this method is
170170
/// expected to fill in order to build the operation.
171171
void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
172-
double value) {
173-
auto dataType = RankedTensorType::get({}, builder.getF64Type());
172+
float value) {
173+
auto dataType = RankedTensorType::get({}, builder.getF32Type());
174174
auto dataAttribute = DenseElementsAttr::get(dataType, value);
175175
ConstantOp::build(builder, state, dataType, dataAttribute);
176176
}
@@ -238,7 +238,7 @@ llvm::LogicalResult ConstantOp::verify() {
238238

239239
void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
240240
mlir::Value lhs, mlir::Value rhs) {
241-
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
241+
state.addTypes(UnrankedTensorType::get(builder.getF32Type()));
242242
state.addOperands({lhs, rhs});
243243
}
244244

@@ -319,7 +319,7 @@ void FuncOp::print(mlir::OpAsmPrinter &p) {
319319
void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
320320
StringRef callee, ArrayRef<mlir::Value> arguments) {
321321
// Generic call always returns an unranked Tensor initially.
322-
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
322+
state.addTypes(UnrankedTensorType::get(builder.getF32Type()));
323323
state.addOperands(arguments);
324324
state.addAttribute("callee",
325325
mlir::SymbolRefAttr::get(builder.getContext(), callee));
@@ -353,7 +353,7 @@ MutableOperandRange GenericCallOp::getArgOperandsMutable() {
353353

354354
void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
355355
mlir::Value lhs, mlir::Value rhs) {
356-
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
356+
state.addTypes(UnrankedTensorType::get(builder.getF32Type()));
357357
state.addOperands({lhs, rhs});
358358
}
359359

@@ -412,7 +412,7 @@ llvm::LogicalResult ReturnOp::verify() {
412412

413413
void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
414414
mlir::Value value) {
415-
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
415+
state.addTypes(UnrankedTensorType::get(builder.getF32Type()));
416416
state.addOperands(value);
417417
}
418418

@@ -443,7 +443,7 @@ llvm::LogicalResult TransposeOp::verify() {
443443

444444
void MatMulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
445445
mlir::Value lhs, mlir::Value rhs) {
446-
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
446+
state.addTypes(UnrankedTensorType::get(builder.getF32Type()));
447447
state.addOperands({lhs, rhs});
448448
}
449449

mlir/cuda-tile/Toy/mlir/LowerToLLVM.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,18 @@ class PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
109109
// Generate a call to printf for the current element of the loop.
110110
auto elementLoad =
111111
memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs);
112+
113+
// Varargs promotion: float -> double
114+
Value arg = elementLoad;
115+
Type t = elementLoad.getType();
116+
if (t.isF32()) {
117+
arg = arith::ExtFOp::create(rewriter, loc, rewriter.getF64Type(), arg);
118+
} else if (!t.isF64()) {
119+
return rewriter.notifyMatchFailure(op, "toy.print only supports f32/f64");
120+
}
121+
112122
LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
113-
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
123+
ArrayRef<Value>({formatSpecifierCst, arg}));
114124

115125
// Notify the rewriter that this operation has been removed.
116126
rewriter.eraseOp(op);

mlir/cuda-tile/Toy/mlir/MLIRGen.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -258,22 +258,22 @@ class MLIRGenImpl {
258258
/// Example, the source level statement:
259259
/// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
260260
/// will be converted to:
261-
/// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
261+
/// %0 = "toy.constant"() {value: dense<tensor<2x3xf32>,
262262
/// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
263-
/// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
263+
/// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf32>
264264
///
265265
mlir::Value mlirGen(LiteralExprAST &lit) {
266266
auto type = getType(lit.getDims());
267267

268268
// The attribute is a vector with a floating point value per element
269269
// (number) in the array, see `collectData()` below for more details.
270-
std::vector<double> data;
270+
std::vector<float> data;
271271
data.reserve(llvm::product_of(lit.getDims()));
272272
collectData(lit, data);
273273

274274
// The type of this attribute is tensor of 32-bit floating-point with the
275275
// shape of the literal.
276-
mlir::Type elementType = builder.getF64Type();
276+
mlir::Type elementType = builder.getF32Type();
277277
auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
278278

279279
// This is the actual attribute that holds the list of values for this
@@ -292,9 +292,9 @@ class MLIRGenImpl {
292292
/// [[1, 2], [3, 4]]
293293
/// we will generate:
294294
/// [ 1, 2, 3, 4 ]
295-
/// Individual numbers are represented as doubles.
295+
/// Individual numbers are represented as floats.
296296
/// Attributes are the way MLIR attaches constant to operations.
297-
void collectData(ExprAST &expr, std::vector<double> &data) {
297+
void collectData(ExprAST &expr, std::vector<float> &data) {
298298
if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
299299
for (auto &value : lit->getValues())
300300
collectData(*value, data);
@@ -444,10 +444,10 @@ class MLIRGenImpl {
444444
mlir::Type getType(ArrayRef<int64_t> shape) {
445445
// If the shape is empty, then this type is unranked.
446446
if (shape.empty())
447-
return mlir::UnrankedTensorType::get(builder.getF64Type());
447+
return mlir::UnrankedTensorType::get(builder.getF32Type());
448448

449449
// Otherwise, we use the given shape.
450-
return mlir::RankedTensorType::get(shape, builder.getF64Type());
450+
return mlir::RankedTensorType::get(shape, builder.getF32Type());
451451
}
452452

453453
/// Build an MLIR type from a Toy AST variable type (forward to the generic

0 commit comments

Comments
 (0)