@@ -130,41 +130,6 @@ func.func @donot_replace_leakyrelu(%arg0 : tensor<1x104x104x128xf32, #zhigh.layo
130
130
131
131
// -----
132
132
133
- func.func @replace_sqrt (%arg0 : tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> (tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) {
134
- %0 = " zhigh.Unstick" (%arg0 ) : (tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> tensor <4 x256 x1 xf32 >
135
- %1 = " onnx.Sqrt" (%0 ) : (tensor <4 x256 x1 xf32 >) -> tensor <4 x256 x1 xf32 >
136
- %2 = " zhigh.Stick" (%1 ) {layout = " 3D" } : (tensor <4 x256 x1 xf32 >) -> tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>
137
- return %2 : tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>
138
-
139
- // CHECK-LABEL: func.func @replace_sqrt
140
- // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>> {
141
- // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Log"([[PARAM_0_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
142
- // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<5.000000e-01> : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
143
- // CHECK: [[VAR_2_:%.+]] = "zhigh.Stick"([[VAR_1_]]) {layout = "3D"} : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
144
- // CHECK: [[VAR_3_:%.+]] = "zhigh.Mul"([[VAR_0_]], [[VAR_2_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>, tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
145
- // CHECK: [[VAR_4_:%.+]] = "zhigh.Exp"([[VAR_3_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
146
- // CHECK: return [[VAR_4_]] : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
147
- // CHECK: }
148
- }
149
-
150
- // -----
151
-
152
- // Do not replace square root because of unknown dimension.
153
- // In this case, there is no static shape to create a constant of 2.
154
- func.func @donot_replace_sqrt (%arg0 : tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> (tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) {
155
- %0 = " zhigh.Unstick" (%arg0 ) : (tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> tensor <?x256 x1 xf32 >
156
- %1 = " onnx.Sqrt" (%0 ) : (tensor <?x256 x1 xf32 >) -> tensor <?x256 x1 xf32 >
157
- %2 = " zhigh.Stick" (%1 ) {layout = " 3D" } : (tensor <?x256 x1 xf32 >) -> tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>
158
- return %2 : tensor <?x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>
159
-
160
- // CHECK-LABEL: func.func @donot_replace_sqrt
161
- // CHECK: zhigh.Unstick
162
- // CHECK: onnx.Sqrt
163
- // CHECK: zhigh.Stick
164
- }
165
-
166
- // -----
167
-
168
133
func.func @replace_reciprocal_sqrt (%arg0 : tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> (tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) {
169
134
%0 = " zhigh.Unstick" (%arg0 ) : (tensor <4 x256 x1 xf32 , #zhigh.layout <{dataLayout = " 3D" }>>) -> tensor <4 x256 x1 xf32 >
170
135
%1 = " onnx.Sqrt" (%0 ) : (tensor <4 x256 x1 xf32 >) -> tensor <4 x256 x1 xf32 >
0 commit comments