
Commit 4e1c970

tungld and cjvolzka authored
Disable the rewriting of Sqrt into Exp and Log (#2088)
Signed-off-by: Tung D. Le <tung@jp.ibm.com>
Co-authored-by: Charles Volzka <42243335+cjvolzka@users.noreply.github.com>
1 parent df58b4d · commit 4e1c970

File tree

3 files changed: +0 −47 lines

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp

Lines changed: 0 additions & 1 deletion
@@ -136,7 +136,6 @@ void ZHighStickOp::getCanonicalizationPatterns(
   results.insert<StickUnstickSameLayoutRemovalPattern>(context);
   results.insert<StickUnstickDiffLayoutRemovalPattern>(context);
   results.insert<ReplaceONNXLeakyReluPattern>(context);
-  results.insert<ReplaceONNXSqrtPattern>(context);
   results.insert<ReplaceONNXReciprocalSqrtPattern>(context);
   results.insert<ReshapeTransposeReshape2DTo3DSPattern>(context);
   results.insert<ReshapeTransposeReshape3DSTo2DPattern>(context);

src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td

Lines changed: 0 additions & 11 deletions
@@ -80,17 +80,6 @@ def ReplaceONNXLeakyReluPattern: Pat<
    (SameLayout $X, $stickout)]
 >;
 
-// Since zDNN does not support Sqrt(X), we calculate it by using zDNN-supported
-// operations, i.e. Exp and Log.
-// Formulas: `sqrt(X) = exp(log(x)/2) = exp(0.5 * log(x))`
-def ReplaceONNXSqrtPattern: Pat<
-  (ZHighStickOp:$stick (ONNXSqrtOp (ZHighUnstickOp $X)), $layout),
-  (ZHighExpOp (ZHighMulOp (ZHighLogOp $X, (returnType $X)),
-                (ZHighStickOp (GetConstantOfType<"0.5"> $X), $layout),
-                (returnType $X))),
-  [(IsStaticShapeTensor $X), (SameLayout $X, $stick)]
->;
-
 // Calulation of `1/sqrt(X)` or reciprocal square root is often found in
 // deep learning models, but zDNN does not support it. Thus, we rewrite it into
 // zDNN-supported operations.
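
For reference, the rewrite removed above relied on the identity sqrt(x) = exp(0.5 * log(x)), which holds for x > 0. The snippet below is a minimal standalone C++ sketch, not onnx-mlir code; the sample values and output format are illustrative only. It compares the direct square root with the Exp/Log route that the deleted ReplaceONNXSqrtPattern emitted:

#include <cmath>
#include <cstdio>

int main() {
  // Compare sqrt(x) with exp(0.5 * log(x)), the Exp/Log route the deleted
  // pattern used, on a few positive sample values.
  const double samples[] = {0.25, 1.0, 2.0, 1e-6, 1e6};
  for (double x : samples) {
    double direct = std::sqrt(x);
    double viaExpLog = std::exp(0.5 * std::log(x));
    std::printf("x=%g  sqrt=%.17g  exp(0.5*log)=%.17g  diff=%g\n", x, direct,
        viaExpLog, direct - viaExpLog);
  }
  return 0;
}

In exact arithmetic the two expressions agree for positive inputs; in finite precision the Exp/Log route composes two transcendental approximations, so the printed results can differ slightly from the direct square root.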

test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir

Lines changed: 0 additions & 35 deletions
@@ -130,41 +130,6 @@ func.func @donot_replace_leakyrelu(%arg0 : tensor<1x104x104x128xf32, #zhigh.layo
 
 // -----
 
-func.func @replace_sqrt(%arg0 : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) {
-  %0 = "zhigh.Unstick"(%arg0) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32>
-  %1 = "onnx.Sqrt"(%0) : (tensor<4x256x1xf32>) -> tensor<4x256x1xf32>
-  %2 = "zhigh.Stick"(%1) {layout = "3D"} : (tensor<4x256x1xf32>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-  return %2 : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-
-// CHECK-LABEL: func.func @replace_sqrt
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>> {
-// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Log"([[PARAM_0_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<5.000000e-01> : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-// CHECK: [[VAR_2_:%.+]] = "zhigh.Stick"([[VAR_1_]]) {layout = "3D"} : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-// CHECK: [[VAR_3_:%.+]] = "zhigh.Mul"([[VAR_0_]], [[VAR_2_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>, tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-// CHECK: [[VAR_4_:%.+]] = "zhigh.Exp"([[VAR_3_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-// CHECK: return [[VAR_4_]] : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-// CHECK: }
-}
-
-// -----
-
-// Do not replace square root because of unknown dimension.
-// In this case, there is no static shape to create a constant of 2.
-func.func @donot_replace_sqrt(%arg0 : tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> (tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) {
-  %0 = "zhigh.Unstick"(%arg0) : (tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<?x256x1xf32>
-  %1 = "onnx.Sqrt"(%0) : (tensor<?x256x1xf32>) -> tensor<?x256x1xf32>
-  %2 = "zhigh.Stick"(%1) {layout = "3D"} : (tensor<?x256x1xf32>) -> tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-  return %2 : tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
-
-// CHECK-LABEL: func.func @donot_replace_sqrt
-// CHECK: zhigh.Unstick
-// CHECK: onnx.Sqrt
-// CHECK: zhigh.Stick
-}
-
-// -----
-
 func.func @replace_reciprocal_sqrt(%arg0 : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) {
   %0 = "zhigh.Unstick"(%arg0) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32>
   %1 = "onnx.Sqrt"(%0) : (tensor<4x256x1xf32>) -> tensor<4x256x1xf32>
