@@ -20,6 +20,8 @@ include "mlir/Interfaces/CastInterfaces.td"
2020include "mlir/Interfaces/SideEffectInterfaces.td"
2121include "toy/ShapeInferenceInterface.td"
2222
23+ def F32ElementsAttr : FloatElementsAttr<32>;
24+
2325// Provide a definition of the 'toy' dialect in the ODS framework so that we
2426// can define our operations.
2527def Toy_Dialect : Dialect {
@@ -57,15 +59,15 @@ def ConstantOp : Toy_Op<"constant", [Pure]> {
5759
5860 ```mlir
5961 %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>
60- : tensor<2x3xf64 >
62+ : tensor<2x3xf32 >
6163 ```
6264 }];
6365
6466 // The constant operation takes an attribute as the only input.
65- let arguments = (ins F64ElementsAttr :$value);
67+ let arguments = (ins F32ElementsAttr :$value);
6668
6769 // The constant operation returns a single value of TensorType.
68- let results = (outs F64Tensor );
70+ let results = (outs F32Tensor );
6971
7072 // Indicate that the operation has a custom parser and printer method.
7173 let hasCustomAssemblyFormat = 1;
@@ -80,7 +82,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> {
8082 }]>,
8183
8284 // Build a constant with a given constant floating-point value.
83- OpBuilder<(ins "double ":$value)>
85+ OpBuilder<(ins "float ":$value)>
8486 ];
8587
8688 // Indicate that additional verification for this operation is necessary.
@@ -99,8 +101,8 @@ def AddOp : Toy_Op<"add",
99101 The shapes of the tensor operands are expected to match.
100102 }];
101103
102- let arguments = (ins F64Tensor :$lhs, F64Tensor :$rhs);
103- let results = (outs F64Tensor );
104+ let arguments = (ins F32Tensor :$lhs, F32Tensor :$rhs);
105+ let results = (outs F32Tensor );
104106
105107 // Indicate that the operation has a custom parser and printer method.
106108 let hasCustomAssemblyFormat = 1;
@@ -130,8 +132,8 @@ def CastOp : Toy_Op<"cast", [
130132 mismatching constant dimension.
131133 }];
132134
133- let arguments = (ins F64Tensor :$input);
134- let results = (outs F64Tensor :$output);
135+ let arguments = (ins F32Tensor :$input);
136+ let results = (outs F32Tensor :$output);
135137
136138 let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
137139}
@@ -152,9 +154,9 @@ def FuncOp : Toy_Op<"func", [
152154
153155 ```mlir
154156 toy.func @main() {
155- %0 = toy.constant dense<5.500000e+00> : tensor<f64 >
156- %1 = toy.reshape(%0 : tensor<f64 >) to tensor<2x2xf64 >
157- toy.print %1 : tensor<2x2xf64 >
157+ %0 = toy.constant dense<5.500000e+00> : tensor<f32 >
158+ %1 = toy.reshape(%0 : tensor<f32 >) to tensor<2x2xf32 >
159+ toy.print %1 : tensor<2x2xf32 >
158160 toy.return
159161 }
160162 ```
@@ -205,7 +207,7 @@ def GenericCallOp : Toy_Op<"generic_call",
205207
206208 ```mlir
207209 %4 = toy.generic_call @my_func(%1, %3)
208- : (tensor<2x3xf64 >, tensor<2x3xf64 >) -> tensor<*xf64 >
210+ : (tensor<2x3xf32 >, tensor<2x3xf32 >) -> tensor<*xf32 >
209211 ```
210212
211213 This is only valid if a function named "my_func" exists and takes two
@@ -216,13 +218,13 @@ def GenericCallOp : Toy_Op<"generic_call",
216218 // callee, and inputs for the call.
217219 let arguments = (ins
218220 FlatSymbolRefAttr:$callee,
219- Variadic<F64Tensor >:$inputs,
221+ Variadic<F32Tensor >:$inputs,
220222 OptionalAttr<DictArrayAttr>:$arg_attrs,
221223 OptionalAttr<DictArrayAttr>:$res_attrs
222224 );
223225
224226 // The generic call operation returns a single value of TensorType.
225- let results = (outs F64Tensor );
227+ let results = (outs F32Tensor );
226228
227229 // Specialize assembly printing and parsing using a declarative format.
228230 let assemblyFormat = [{
@@ -247,8 +249,8 @@ def MulOp : Toy_Op<"mul",
247249 tensors. The shapes of the tensor operands are expected to match.
248250 }];
249251
250- let arguments = (ins F64Tensor :$lhs, F64Tensor :$rhs);
251- let results = (outs F64Tensor );
252+ let arguments = (ins F32Tensor :$lhs, F32Tensor :$rhs);
253+ let results = (outs F32Tensor );
252254
253255 // Indicate that the operation has a custom parser and printer method.
254256 let hasCustomAssemblyFormat = 1;
@@ -271,8 +273,8 @@ def PrintOp : Toy_Op<"print"> {
271273 }];
272274
273275 // The print operation takes an input tensor to print.
274- // We also allow a F64MemRef to enable interop during partial lowering.
275- let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef ]>:$input);
276+ // We also allow a F32MemRef to enable interop during partial lowering.
277+ let arguments = (ins AnyTypeOf<[F32Tensor, F32MemRef ]>:$input);
276278
277279 let assemblyFormat = "$input attr-dict `:` type($input)";
278280}
@@ -288,11 +290,11 @@ def ReshapeOp : Toy_Op<"reshape", [Pure]> {
288290 the same number of elements but different shapes. For example:
289291
290292 ```mlir
291- %0 = toy.reshape (%arg1 : tensor<10xf64 >) to tensor<5x2xf64 >
293+ %0 = toy.reshape (%arg1 : tensor<10xf32 >) to tensor<5x2xf32 >
292294 ```
293295 }];
294296
295- let arguments = (ins F64Tensor :$input);
297+ let arguments = (ins F32Tensor :$input);
296298
297299 let assemblyFormat = [{
298300 `(` $input `:` type($input) `)` attr-dict `to` type(results)
@@ -302,7 +304,7 @@ def ReshapeOp : Toy_Op<"reshape", [Pure]> {
302304 let hasCanonicalizer = 1;
303305
304306 // We expect that the reshape operation returns a statically shaped tensor.
305- let results = (outs StaticShapeTensorOf<[F64 ]>);
307+ let results = (outs StaticShapeTensorOf<[F32 ]>);
306308}
307309
308310//===----------------------------------------------------------------------===//
@@ -319,16 +321,16 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">,
319321 the operation. For example:
320322
321323 ```mlir
322- toy.func @foo() -> tensor<2xf64 > {
324+ toy.func @foo() -> tensor<2xf32 > {
323325 ...
324- toy.return %0 : tensor<2xf64 >
326+ toy.return %0 : tensor<2xf32 >
325327 }
326328 ```
327329 }];
328330
329331 // The return operation takes an optional input operand to return. This
330332 // value must match the return type of the enclosing function.
331- let arguments = (ins Variadic<F64Tensor >:$input);
333+ let arguments = (ins Variadic<F32Tensor >:$input);
332334
333335 // The return operation only emits the input in the format if it is present.
334336 let assemblyFormat = "($input^ `:` type($input))? attr-dict ";
@@ -355,8 +357,8 @@ def TransposeOp : Toy_Op<"transpose",
355357 [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
356358 let summary = "transpose operation";
357359
358- let arguments = (ins F64Tensor :$input);
359- let results = (outs F64Tensor );
360+ let arguments = (ins F32Tensor :$input);
361+ let results = (outs F32Tensor );
360362
361363 let assemblyFormat = [{
362364 `(` $input `:` type($input) `)` attr-dict `to` type(results)
@@ -386,8 +388,8 @@ def MatMulOp : Toy_Op<"matmul",
386388 tensors. The shapes of the tensor operands are expected to match.
387389 }];
388390
389- let arguments = (ins F64Tensor :$lhs, F64Tensor :$rhs);
390- let results = (outs Res<F64Tensor , "",
391+ let arguments = (ins F32Tensor :$lhs, F32Tensor :$rhs);
392+ let results = (outs Res<F32Tensor , "",
391393 [MemWrite<DefaultResource>,
392394 MemAlloc<DefaultResource>]>:$output);
393395
0 commit comments