@@ -26,6 +26,7 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
26
26
// Elementwise max with 0 (ReLU).
27
27
%c0f = arith.constant 0.0 : f32
28
28
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
29
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>]
29
30
ins(%biased, %c0f : tensor<512x512xf32>, f32)
30
31
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
31
32
func.return %relued : tensor<512x512xf32>
@@ -95,18 +96,18 @@ $ mlir-opt sequence.mlir --pass-pipeline="
95
96
The ` sequence.mlir ` file contains _ both_ the payload IR function _ and_ the transform IR sequence nested in the same module. The transform interpreter pass will apply the ` @__transform_main ` named sequence to the anchor operation of the pass. In our case, we also asked the interpreter pass to associate the two extra arguments of the top-level sequence with all ` linalg.matmul ` and ` linalg.elementwise ` payload operations through the respective pass options. Running this pass results in the expected remarks:
96
97
97
98
``` sh
98
- sequence.mlir:7 :13: remark: matmul
99
+ sequence.mlir:5 :13: remark: matmul
99
100
%matmul = linalg.matmul ins(%lhs, %rhs: tensor< 512x512xf32> , tensor< 512x512xf32> )
100
101
^
101
- sequence.mlir:7 :13: note: see current operation: %0 = linalg.matmul ins(%arg0, %arg1 : tensor< 512x512xf32> , tensor< 512x512xf32> ) outs(%arg3 : tensor< 512x512xf32> ) -> tensor< 512x512xf32>
102
- sequence.mlir:10 :13: remark: elemwise_binaries
102
+ sequence.mlir:5 :13: note: see current operation: %0 = linalg.matmul ins(%arg0, %arg1 : tensor< 512x512xf32> , tensor< 512x512xf32> ) outs(%arg3 : tensor< 512x512xf32> ) -> tensor< 512x512xf32>
103
+ sequence.mlir:9 :13: remark: elemwise_binaries
103
104
%biased = linalg.elementwise kind=# linalg.elementwise_kind<add>
104
105
^
105
- sequence.mlir:10 :13: note: see current operation: %1 = linalg.elementwise kind=# linalg.elementwise_kind<add> > ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
106
- sequence.mlir:14 :13: remark: elemwise_binaries
106
+ sequence.mlir:9 :13: note: see current operation: %1 = linalg.elementwise kind=# linalg.elementwise_kind<add> ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
107
+ sequence.mlir:15 :13: remark: elemwise_binaries
107
108
%relued = linalg.elementwise kind=# linalg.elementwise_kind<max_signed>
108
109
^
109
- sequence.mlir:14 :13: note: see current operation: %2 = linalg.elementwise kind=# linalg.elementwise_kind<max_signed>> ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
110
+ sequence.mlir:15 :13: note: see current operation: %2 = linalg.elementwise kind=# linalg.elementwise_kind<max_signed> indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>] ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
110
111
```
111
112
112
113
Note that ` %arg2 ` is associated with both elementwise payload operations. Any handle is associated with a list of entities. Individual transformations may or may not care about the order of elements in that list.
@@ -140,33 +141,39 @@ The transformation returns two handles, as indicated in its [documentation](http
140
141
Running this transformation with the same command as above expectedly produces the tiled code.
141
142
142
143
``` mlir
144
+ #map = affine_map<(d0) -> (d0 * 4)>
145
+ #map1 = affine_map<(d0) -> (d0 * 32)>
146
+ #map2 = affine_map<(d0, d1) -> (d0, d1)>
147
+ #map3 = affine_map<(d0, d1) -> ()>
148
+
143
149
func.func @fc_relu(%arg0: tensor<512x512xf32>,
144
150
%arg1: tensor<512x512xf32>,
145
151
%arg2: tensor<512x512xf32>,
146
152
%arg3: tensor<512x512xf32>) -> tensor<512x512xf32> {
147
- %cst = arith.constant 0.000000e+00 : f32
148
153
%0 = scf.forall (%arg4, %arg5) in (128, 16) shared_outs(%arg6 = %arg3) -> (tensor<512x512xf32>) {
149
- %3 = affine.apply affine_map<(d0) -> (d0 * 4)> (%arg4)
150
- %4 = affine.apply affine_map<(d0) -> (d0 * 32)> (%arg5)
154
+ %3 = affine.apply #map (%arg4)
155
+ %4 = affine.apply #map1 (%arg5)
151
156
%extracted_slice = tensor.extract_slice %arg0[%3, 0] [4, 512] [1, 1]
152
157
: tensor<512x512xf32> to tensor<4x512xf32>
153
158
%extracted_slice_0 = tensor.extract_slice %arg1[0, %4] [512, 32] [1, 1]
154
- : tensor<512x512xf32> to tensor<512x32xf32>
159
+ : tensor<512x512xf32> to tensor<512x32xf32>
155
160
%extracted_slice_1 = tensor.extract_slice %arg6[%3, %4] [4, 32] [1, 1]
156
- : tensor<512x512xf32> to tensor<4x32xf32>
161
+ : tensor<512x512xf32> to tensor<4x32xf32>
157
162
%5 = linalg.matmul
158
163
ins(%extracted_slice, %extracted_slice_0
159
- : tensor<4x512xf32>, tensor<512x32xf32>)
164
+ : tensor<4x512xf32>, tensor<512x32xf32>)
160
165
outs(%extracted_slice_1 : tensor<4x32xf32>) -> tensor<4x32xf32>
161
166
scf.forall.in_parallel {
162
167
tensor.parallel_insert_slice %5 into %arg6[%3, %4] [4, 32] [1, 1]
163
- : tensor<4x32xf32> into tensor<512x512xf32>
168
+ : tensor<4x32xf32> into tensor<512x512xf32>
164
169
}
165
170
}
166
- %1 = linalg.elementwise kind=#linalg.elementwise_kind<add>>
167
- ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>)
168
- outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
169
- %2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>>
171
+ %1 = linalg.elementwise kind=#linalg.elementwise_kind<add>
172
+ ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>)
173
+ outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
174
+ %cst = arith.constant 0.000000e+00 : f32
175
+ %2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
176
+ indexing_maps = [#map2, #map3, #map2]
170
177
ins(%1, %cst : tensor<512x512xf32>, f32)
171
178
outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
172
179
return %2 : tensor<512x512xf32>
@@ -216,7 +223,7 @@ One may observe that some operations such as `transform.cast` do not consume the
216
223
217
224
``` mlir
218
225
module attributes {transform.with_named_sequence} {
219
- transform.named_sequence @__transform_main
226
+ transform.named_sequence @__transform_main(
220
227
%arg0: !transform.any_op,
221
228
%arg1: !transform.op<"linalg.matmul">,
222
229
%arg2: !transform.op<"linalg.elementwise">) {
0 commit comments