1
1
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
2
2
3
- // +----------------------------------------
4
- // Tests of TransposeToShapeCast
5
- // +----------------------------------------
6
-
7
- // CHECK-LABEL: @transpose_to_shape_cast
8
- // CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
9
- // CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
10
- // CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
11
- func.func @transpose_to_shape_cast (%arg0 : vector <2 x1 x2 xf32 >) -> vector <2 x2 x1 xf32 > {
12
- %0 = vector.transpose %arg0 , [0 , 2 , 1 ] : vector <2 x1 x2 xf32 > to vector <2 x2 x1 xf32 >
13
- return %0 : vector <2 x2 x1 xf32 >
14
- }
15
-
16
-
17
- // -----
18
-
19
- // CHECK-LABEL: @negative_transpose_to_shape_cast
20
- // CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
21
- // CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
22
- // CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
23
- func.func @negative_transpose_to_shape_cast (%arg0 : vector <2 x1 x2 xf32 >) -> vector <2 x2 x1 xf32 > {
24
- %0 = vector.transpose %arg0 , [2 , 0 , 1 ] : vector <2 x1 x2 xf32 > to vector <2 x2 x1 xf32 >
25
- return %0 : vector <2 x2 x1 xf32 >
26
- }
27
-
28
- // -----
29
3
30
4
// +----------------------------------------
31
5
// Tests of BroadcastToShapeCast
@@ -42,16 +16,20 @@ func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
42
16
43
17
// -----
44
18
45
- // CHECK-LABEL: @negative_broadcast_to_shape_cast
19
+ // broadcast can only be transformed to a shape_cast if the number of elements is
20
+ // unchanged by the broadcast
21
+ // CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
46
22
// CHECK-NOT: shape_cast
47
23
// CHECK: return
48
- func.func @negative_broadcast_to_shape_cast (%arg0 : vector <1 x4 xi8 >) -> vector <2 x3 x4 xi8 > {
24
+ func.func @negative_broadcast_increased_elements_to_shape_cast (%arg0 : vector <1 x4 xi8 >) -> vector <2 x3 x4 xi8 > {
49
25
%0 = vector.broadcast %arg0 : vector <1 x4 xi8 > to vector <2 x3 x4 xi8 >
50
26
return %0 : vector <2 x3 x4 xi8 >
51
27
}
52
28
53
29
// -----
54
30
31
+ // shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar
32
+ // cannot be transformed to a shape_cast.
55
33
// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast
56
34
// CHECK-NOT: shape_cast
57
35
// CHECK: return
@@ -62,56 +40,101 @@ func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
62
40
63
41
// -----
64
42
65
- // The conversion transpose(shape_cast) -> shape_cast is currently disabled for scalable
66
- // vectors.
67
- // CHECK-LABEL: @transpose_of_shape_cast_scalable
68
- // CHECK: vector.shape_cast
69
- // CHECK: vector.transpose
70
- func.func @transpose_of_shape_cast_scalable (%arg : vector <[4 ]xi8 >) -> vector <[4 ]x1 xi8 > {
71
- %0 = vector.shape_cast %arg : vector <[4 ]xi8 > to vector <1 x[4 ]xi8 >
72
- %1 = vector.transpose %0 , [1 , 0 ] : vector <1 x[4 ]xi8 > to vector <[4 ]x1 xi8 >
73
- return %1 : vector <[4 ]x1 xi8 >
43
+ // +----------------------------------------
44
+ // Tests of TransposeToShapeCast
45
+ // +----------------------------------------
46
+
47
+ // In this test, the permutation maps the non-unit dimensions (0 and 2) as follows:
48
+ // 0 -> 0
49
+ // 2 -> 1
50
+ // Because 0 < 1, this permutation is order preserving and effectively a shape_cast.
51
+ // CHECK-LABEL: @transpose_to_shape_cast
52
+ // CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
53
+ // CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
54
+ // CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
55
+ func.func @transpose_to_shape_cast (%arg0 : vector <2 x1 x2 xf32 >) -> vector <2 x2 x1 xf32 > {
56
+ %0 = vector.transpose %arg0 , [0 , 2 , 1 ] : vector <2 x1 x2 xf32 > to vector <2 x2 x1 xf32 >
57
+ return %0 : vector <2 x2 x1 xf32 >
74
58
}
75
59
76
60
// -----
77
61
78
- // A transpose that is 'order preserving' can be treated like a shape_cast.
79
- // CHECK-LABEL: @transpose_of_shape_cast
80
- // CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
62
+ // In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
63
+ // 1 -> 0
64
+ // 2 -> 4
65
+ // Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
66
+ // CHECK-LABEL: @shape_cast_of_transpose
67
+ // CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
81
68
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
82
- // CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
83
- // CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
84
- func.func @transpose_of_shape_cast (%arg : vector <2 x3 x1 x1 xi8 >) -> vector <6 x1 x1 xi8 > {
85
- %0 = vector.shape_cast %arg : vector <2 x3 x1 x1 xi8 > to vector <6 x1 x1 xi8 >
86
- %1 = vector.transpose %0 , [0 , 2 , 1 ]
87
- : vector <6 x1 x1 xi8 > to vector <6 x1 x1 xi8 >
88
- return %1 : vector <6 x1 x1 xi8 >
69
+ // CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
70
+ // CHECK: return %[[SHAPE_CAST]]
71
+ func.func @shape_cast_of_transpose (%arg : vector <1 x4 x4 x1 x1 xi8 >) -> vector <4 x1 x1 x1 x4 xi8 > {
72
+ %0 = vector.transpose %arg , [1 , 0 , 3 , 4 , 2 ] : vector <1 x4 x4 x1 x1 xi8 > to vector <4 x1 x1 x1 x4 xi8 >
73
+ return %0 : vector <4 x1 x1 x1 x4 xi8 >
89
74
}
90
75
91
76
// -----
92
77
93
78
// Scalable dimensions should be treated as non-unit dimensions.
94
- // CHECK-LABEL: @transpose_of_shape_cast_scalable
79
+ // CHECK-LABEL: @transpose_scalable_unit
80
+ // CHECK-NOT: shape_cast
81
+ func.func @transpose_scalable_unit (%arg : vector <[1 ]x4 xi8 >) -> vector <4 x[1 ]xi8 > {
82
+ %0 = vector.transpose %arg , [1 , 0 ] : vector <[1 ]x4 xi8 > to vector <4 x[1 ]xi8 >
83
+ return %0 : vector <4 x[1 ]xi8 >
84
+ }
85
+
86
+ // -----
87
+
88
+ // In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
89
+ // 1 -> 2
90
+ // 2 -> 1
91
+ // As this is not increasing (2 > 1), this transpose is not order
92
+ // preserving and cannot be treated as a shape_cast.
93
+ // CHECK-LABEL: @negative_transpose_to_shape_cast
94
+ // CHECK-NOT: shape_cast
95
+ func.func @negative_transpose_to_shape_cast (%arg : vector <1 x4 x4 x1 xi8 >) -> vector <1 x4 x4 x1 xi8 > {
96
+ %0 = vector.transpose %arg , [0 , 2 , 1 , 3 ]
97
+ : vector <1 x4 x4 x1 xi8 > to vector <1 x4 x4 x1 xi8 >
98
+ return %0 : vector <1 x4 x4 x1 xi8 >
99
+ }
100
+
101
+ // -----
102
+
103
+ // Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
104
+ // scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
105
+ // CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
106
+ // CHECK: vector.transpose
107
+ // CHECK: vector.shape_cast
108
+ func.func @negative_shape_cast_of_transpose_scalable (%arg : vector <[4 ]x1 xi8 >) -> vector <[4 ]xi8 > {
109
+ %0 = vector.transpose %arg , [1 , 0 ] : vector <[4 ]x1 xi8 > to vector <1 x[4 ]xi8 >
110
+ %1 = vector.shape_cast %0 : vector <1 x[4 ]xi8 > to vector <[4 ]xi8 >
111
+ return %1 : vector <[4 ]xi8 >
112
+ }
113
+
114
+ // -----
115
+
116
+ // The conversion transpose(shape_cast) -> shape_cast is currently disabled for scalable
117
+ // vectors.
118
+ // CHECK-LABEL: @negative_transpose_of_shape_cast_scalable
95
119
// CHECK: vector.shape_cast
96
120
// CHECK: vector.transpose
97
- func.func @transpose_of_shape_cast_scalable_unit (%arg : vector <[1 ]x 4 x 1 x i8 >) -> vector <4 x[ 1 ]x i8 > {
98
- %0 = vector.shape_cast %arg : vector <[1 ]x 4 x 1 x i8 > to vector <[ 1 ]x 4 x i8 >
99
- %1 = vector.transpose %0 , [1 , 0 ] : vector <[ 1 ]x 4 x i8 > to vector <4 x[ 1 ]x i8 >
100
- return %1 : vector <4 x[ 1 ]x i8 >
121
+ func.func @negative_transpose_of_shape_cast_scalable (%arg : vector <[4 ]x i8 >) -> vector <[ 4 ]x 1 x i8 > {
122
+ %0 = vector.shape_cast %arg : vector <[4 ]x i8 > to vector <1 x[ 4 ]x i8 >
123
+ %1 = vector.transpose %0 , [1 , 0 ] : vector <1 x[ 4 ]x i8 > to vector <[ 4 ]x 1 x i8 >
124
+ return %1 : vector <[ 4 ]x 1 x i8 >
101
125
}
102
126
103
127
// -----
104
128
105
- // Test of shape_cast (not) folding.
106
- // CHECK-LABEL: @negative_transpose_of_shape_cast
107
- // CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
108
- // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
109
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
110
- // CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
111
- func.func @negative_transpose_of_shape_cast (%arg : vector <6 xi8 >) -> vector <2 x3 xi8 > {
112
- %0 = vector.shape_cast %arg : vector <6 xi8 > to vector <3 x2 xi8 >
113
- %1 = vector.transpose %0 , [1 , 0 ] : vector <3 x2 xi8 > to vector <2 x3 xi8 >
114
- return %1 : vector <2 x3 xi8 >
129
+ // A test where a transpose cannot be transformed to a shape_cast because it is not order
130
+ // preserving
131
+ // CHECK-LABEL: @negative_transpose_to_shape_cast
132
+ // CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
133
+ // CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
134
+ // CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
135
+ func.func @negative_transpose_to_shape_cast (%arg0 : vector <2 x1 x2 xf32 >) -> vector <2 x2 x1 xf32 > {
136
+ %0 = vector.transpose %arg0 , [2 , 0 , 1 ] : vector <2 x1 x2 xf32 > to vector <2 x2 x1 xf32 >
137
+ return %0 : vector <2 x2 x1 xf32 >
115
138
}
116
139
117
140
// -----
0 commit comments