@@ -100,6 +100,35 @@ func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memre
// -----
+ // CHECK-DAG: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
+ // CHECK-DAG: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+
+ // CHECK-LABEL: func.func @collapsable_memref_projected_ops(
+ // CHECK-SAME: %[[ARG0:.*]]: memref<1x24x32x8xf32>, %[[ARG1:.*]]: memref<1x24x32x8xf32>, %[[ARG2:.*]]: memref<1x24x32x8xf32, #[[$ATTR_0]]>) {
+ // CHECK: %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
+ // CHECK: %[[VAL_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
+ // CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32, #[[$ATTR_0]]> into memref<1x768x8xf32, strided<[7680, 10, 1]>>
+ // CHECK: linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<1x768x8xf32>, memref<1x768x8xf32>) outs(%[[VAL_2]] : memref<1x768x8xf32, strided<[7680, 10, 1]>>) {
+ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+ // CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
+ // CHECK: linalg.yield %[[VAL_6]] : f32
+ // CHECK: }
+ // CHECK: return
+ // CHECK: }
+
+ #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+ #map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
+ func.func @collapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
+   linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %0 = arith.addf %in, %in_0 : f32
+     linalg.yield %0 : f32
+   }
+   return
+ }
+
+ // -----
+
// CHECK-LABEL: func @uncollapsable_strided_memref(
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -119,6 +148,23 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
// -----
+ // CHECK-LABEL: func @uncollapsable_memref_projected_ops(
+ // CHECK: linalg.generic
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+ #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+ #map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 8 + d3)>
+ func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
+   linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
+   ^bb0(%in: f32, %in_0: f32, %out: f32):
+     %0 = arith.addf %in, %in_0 : f32
+     linalg.yield %0 : f32
+   }
+   return
+ }
+
+ // -----
+
// CHECK-LABEL: func.func @linalg_copy(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {