Skip to content

Commit 87de451

Browse files
committed
[mlir][Math] Fix NaN handling in ExpM1 approximation.
Differential Revision: https://reviews.llvm.org/D119822
1 parent f35af77 commit 87de451

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,8 +1033,8 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
10331033
Value cstNegOne = bcast(f32Cst(builder, -1.0f));
10341034
Value x = op.getOperand();
10351035
Value u = builder.create<math::ExpOp>(x);
1036-
Value uEqOne =
1037-
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
1036+
Value uEqOneOrNaN =
1037+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
10381038
Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
10391039
Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
10401040
arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
@@ -1050,7 +1050,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
10501050
uMinusOne, builder.create<arith::DivFOp>(x, logU));
10511051
expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
10521052
Value approximation = builder.create<arith::SelectOp>(
1053-
uEqOne, x,
1053+
uEqOneOrNaN, x,
10541054
builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
10551055
rewriter.replaceOp(op, approximation);
10561056
return success();

mlir/test/Dialect/Math/polynomial-approximation.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
163163
// CHECK-NOT: exp
164164
// CHECK-COUNT-3: select
165165
// CHECK: %[[EXP_X:.*]] = arith.select
166-
// CHECK: %[[VAL_58:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32
166+
// CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32
167167
// CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
168168
// CHECK: %[[VAL_60:.*]] = arith.cmpf oeq, %[[VAL_59]], %[[CST_MINUSONE]] : f32
169169
// CHECK-NOT: log
@@ -174,7 +174,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
174174
// CHECK: %[[VAL_106:.*]] = arith.mulf %[[VAL_59]], %[[VAL_105]] : f32
175175
// CHECK: %[[VAL_107:.*]] = arith.select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32
176176
// CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32
177-
// CHECK: %[[VAL_109:.*]] = arith.select %[[VAL_58]], %[[X]], %[[VAL_108]] : f32
177+
// CHECK: %[[VAL_109:.*]] = arith.select %[[IS_ONE_OR_NAN]], %[[X]], %[[VAL_108]] : f32
178178
// CHECK: return %[[VAL_109]] : f32
179179
// CHECK: }
180180
func @expm1_scalar(%arg0: f32) -> f32 {

0 commit comments

Comments
 (0)