4
4
5
5
// CHECK-LABEL: contraction_dot
6
6
func @contraction_dot (%A: memref <1584 xf32 >, %B: memref <1584 xf32 >, %C: memref <f32 >) {
7
- // CHECK: vector.contract
8
- // CHECK-SAME: vector<1584xf32>, vector<1584xf32> into f32
7
+
8
+ // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
9
+ // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [0] : vector<1584xf32> to f32
10
+ // CHECK: arith.addf %{{.*}}, %{{.*}} : f32
9
11
linalg.dot ins (%A , %B: memref <1584 xf32 >, memref <1584 xf32 >)
10
12
outs (%C: memref <f32 >)
11
13
return
@@ -15,8 +17,10 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32
15
17
16
18
// CHECK-LABEL: contraction_matvec
17
19
func @contraction_matvec (%A: memref <1584 x1584 xf32 >, %B: memref <1584 xf32 >, %C: memref <1584 xf32 >) {
18
- // CHECK: vector.contract
19
- // CHECK-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
20
+
21
+ // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
22
+ // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
23
+ // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
20
24
linalg.matvec ins (%A , %B: memref <1584 x1584 xf32 >, memref <1584 xf32 >)
21
25
outs (%C: memref <1584 xf32 >)
22
26
return
@@ -26,8 +30,9 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: me
26
30
27
31
// CHECK-LABEL: contraction_matmul
28
32
func @contraction_matmul (%A: memref <1584 x1584 xf32 >, %B: memref <1584 x1584 xf32 >, %C: memref <1584 x1584 xf32 >) {
29
- // CHECK: vector.contract
30
- // CHECK-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
33
+ // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
34
+ // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
35
+ // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
31
36
linalg.matmul ins (%A , %B: memref <1584 x1584 xf32 >, memref <1584 x1584 xf32 >)
32
37
outs (%C: memref <1584 x1584 xf32 >)
33
38
return
@@ -37,8 +42,9 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %
37
42
38
43
// CHECK-LABEL: contraction_batch_matmul
39
44
func @contraction_batch_matmul (%A: memref <1584 x1584 x1584 xf32 >, %B: memref <1584 x1584 x1584 xf32 >, %C: memref <1584 x1584 x1584 xf32 >) {
40
- // CHECK: vector.contract
41
- // CHECK-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
45
+ // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
46
+ // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
47
+ // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
42
48
linalg.batch_matmul
43
49
ins (%A , %B: memref <1584 x1584 x1584 xf32 >, memref <1584 x1584 x1584 xf32 >)
44
50
outs (%C: memref <1584 x1584 x1584 xf32 >)
@@ -58,19 +64,15 @@ func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1
58
64
iterator_types = [" parallel" , " parallel" , " reduction" ]
59
65
}
60
66
61
- // CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
62
- // CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
63
- // CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
64
- // CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
65
-
66
67
// CHECK-LABEL: func @vectorization_test
67
68
func @vectorization_test (%A: memref <8 x16 xf32 >, %B: memref <16 x32 xf32 >,
68
69
%C: memref <8 x32 xf32 >) {
69
- // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32 >
70
- // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32 >
70
+ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32 >
71
+ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32 >
71
72
// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
72
- // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
73
- // CHECK-SAME: vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
73
+ // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
74
+ // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
75
+ // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
74
76
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
75
77
linalg.generic #matmul_trait
76
78
ins (%A , %B : memref <8 x16 xf32 >, memref <16 x32 xf32 >)
@@ -96,19 +98,15 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
96
98
iterator_types = [" parallel" , " parallel" , " reduction" ]
97
99
}
98
100
99
- // CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
100
- // CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
101
- // CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
102
- // CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
103
-
104
101
// CHECK-LABEL: func @generic_output_transpose
105
102
func @generic_output_transpose (%A: memref <8 x16 xf32 >, %B: memref <16 x32 xf32 >,
106
103
%C: memref <32 x8 xf32 >) {
107
- // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32 >
108
- // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32 >
104
+ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32 >
105
+ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32 >
109
106
// CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
110
- // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]]
111
- // CHECK-SAME: vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32>
107
+ // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
108
+ // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
109
+ // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
112
110
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
113
111
linalg.generic #matmul_transpose_out_trait
114
112
ins (%A , %B : memref <8 x16 xf32 >, memref <16 x32 xf32 >)
@@ -134,19 +132,16 @@ func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
134
132
iterator_types = [" parallel" , " parallel" , " reduction" ]
135
133
}
136
134
137
- // CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)>
138
- // CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
139
- // CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
140
- // CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
141
-
142
135
// CHECK-LABEL: func @vectorization_test_integer
143
136
func @vectorization_test_integer (%A: memref <8 x16 xi32 >, %B: memref <16 x32 xi32 >,
144
137
%C: memref <8 x32 xi32 >) {
145
- // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32 >
146
- // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<32x16xi32 >
138
+ // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32 >
139
+ // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32 >
147
140
// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
148
- // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]],
149
- // CHECK-SAME: vector<8x16xi32>, vector<32x16xi32> into vector<8x32xi32>
141
+ // CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
142
+ // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
143
+ // CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
144
+
150
145
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
151
146
linalg.generic #matmul_trait
152
147
ins (%A , %B : memref <8 x16 xi32 >, memref <16 x32 xi32 >)
@@ -164,8 +159,9 @@ func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
164
159
// CHECK-LABEL: func @vectorization_test_2
165
160
func @vectorization_test_2 (%A: memref <8 x16 xf32 >, %B: memref <16 x32 xf32 >,
166
161
%C: memref <8 x32 xf32 >) {
167
- // CHECK: vector.contract {{.*}} :
168
- // vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
162
+ // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
163
+ // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
164
+ // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
169
165
linalg.matmul
170
166
ins (%A , %B: memref <8 x16 xf32 >, memref <16 x32 xf32 >)
171
167
outs (%C: memref <8 x32 xf32 >)
@@ -520,19 +516,16 @@ func @matmul_tensors(
520
516
%arg0: tensor <8 x4 xf32 >, %arg1: tensor <4 x12 xf32 >, %arg2: tensor <8 x12 xf32 >)
521
517
-> tensor <8 x12 xf32 > {
522
518
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
523
- // CHECK-DAG: %[[VEC_C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x12xf32>
524
- // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32>
525
- // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<12x4xf32>
519
+ // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32>
520
+ // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<8x12x4xf32>
526
521
// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32>
527
522
//
528
- // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
529
- // a later canonicalization fuses the add into vector.contract.
530
- // CHECK: %[[C:.*]] = vector.contract
531
- // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
532
- // CHECK-SAME: %[[V0]], %[[V1]], %[[VEC_C0]] :
533
- // CHECK-SAME: vector<8x4xf32>, vector<12x4xf32> into vector<8x12xf32>
534
- // CHECK: %[[C2:.*]] = arith.addf %[[V2]], %[[C]] : vector<8x12xf32>
535
- // CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
523
+ // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
524
+ // convert it to a 2D contract.
525
+ // CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
526
+ // CHECK: %[[R:.*]] = vector.multi_reduction #vector.kind<add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
527
+ // CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
528
+ // CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
536
529
%0 = linalg.matmul ins (%arg0 , %arg1: tensor <8 x4 xf32 >, tensor <4 x12 xf32 >)
537
530
outs (%arg2: tensor <8 x12 xf32 >)
538
531
-> tensor <8 x12 xf32 >
0 commit comments