1
+ // RUN: mlir-opt %s | FileCheck %s
2
+
3
+ // CHECK-LABEL: test_collapse(
4
+ func.func @test_collapse (%arg0: memref <1 x?xf32 , strided <[5 , 1 ]>>) {
5
+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 ]] : memref <1 x?xf32 , strided <[5 , 1 ]>> into memref <?xf32 , strided <[1 ]>>
6
+ return
7
+ }
8
+
9
+ // CHECK-LABEL: test_collapse_5d_middle_dynamic(
10
+ func.func @test_collapse_5d_middle_dynamic (%arg0: memref <1 x5 x1 x?x1 xf32 , strided <[540 , 108 , 18 , 2 , 1 ]>>) {
11
+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 , 2 , 3 , 4 ]]
12
+ : memref <1 x5 x1 x?x1 xf32 , strided <[540 , 108 , 18 , 2 , 1 ]>> into memref <?xf32 , strided <[?]>>
13
+ return
14
+ }
15
+
16
+ // CHECK-LABEL: test_collapse_5d_mostly_units(
17
+ func.func @test_collapse_5d_mostly_units (%arg0: memref <1 x1 x1 x?x1 xf32 , strided <[320 , 80 , 16 , 2 , 1 ]>>) {
18
+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 , 2 , 3 , 4 ]]
19
+ : memref <1 x1 x1 x?x1 xf32 , strided <[320 , 80 , 16 , 2 , 1 ]>> into memref <?xf32 , strided <[2 ]>>
20
+ return
21
+ }
22
+
23
+ // CHECK-LABEL: test_partial_collapse_6d(
24
+ func.func @test_partial_collapse_6d (%arg0: memref <1 x?x1 x1 x5 x1 xf32 , strided <[3360 , 420 , 140 , 35 , 7 , 1 ]>>) {
25
+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 , 2 , 3 ], [4 , 5 ]]
26
+ : memref <1 x?x1 x1 x5 x1 xf32 , strided <[3360 , 420 , 140 , 35 , 7 , 1 ]>> into memref <?x5 xf32 , strided <[420 , 7 ]>>
27
+ return
28
+ }
29
+
30
+ // CHECK-LABEL: test_collapse_5d_grouped(
31
+ func.func @test_collapse_5d_grouped (%arg0: memref <1 x5 x1 x?x1 xf32 , strided <[540 , 108 , 18 , 2 , 1 ]>>) {
32
+ %collapse_shape = memref.collapse_shape %arg0 [[0 ], [1 , 2 , 3 , 4 ]]
33
+ : memref <1 x5 x1 x?x1 xf32 , strided <[540 , 108 , 18 , 2 , 1 ]>> into memref <1 x?xf32 , strided <[540 , ?]>>
34
+ return
35
+ }
36
+
37
+ // CHECK-LABEL: test_collapse_all_units(
38
+ func.func @test_collapse_all_units (%arg0: memref <1 x1 x1 x1 x1 xf32 , strided <[100 , 50 , 25 , 10 , 1 ]>>) {
39
+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 , 2 , 3 , 4 ]]
40
+ : memref <1 x1 x1 x1 x1 xf32 , strided <[100 , 50 , 25 , 10 , 1 ]>> into memref <1 xf32 , strided <[100 ]>>
41
+ return
42
+ }
0 commit comments