@@ -3369,77 +3369,6 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
3369
3369
return success ();
3370
3370
}
3371
3371
3372
- static std::optional<Value>
3373
- approximateErfOp (ConversionPatternRewriter &rewriter, Operation *op, Value x,
3374
- Type dtype) {
3375
- // Using:
3376
- // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with
3377
- // maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 =
3378
- // 0.000972, a4 = 0.078108.
3379
- //
3380
- // Erf = 1 - 1 / (1 + a1X + a2X + a3X + a4X)^4
3381
-
3382
- auto outType = cast<TensorType>(x.getType ());
3383
- auto loc = op->getLoc ();
3384
- auto absX = rewriter.create <tosa::AbsOp>(loc, outType, x);
3385
- auto zero = tosa::getConstTensor<float >(rewriter, op, 0 , {}, dtype).value ();
3386
- auto one = tosa::getConstTensor<float >(rewriter, op, 1 , {}, dtype).value ();
3387
- auto a1 =
3388
- tosa::getConstTensor<float >(rewriter, op, 0 .278393f , {}, dtype).value ();
3389
- auto a2 =
3390
- tosa::getConstTensor<float >(rewriter, op, 0 .230389f , {}, dtype).value ();
3391
- auto a3 =
3392
- tosa::getConstTensor<float >(rewriter, op, 0 .000972f , {}, dtype).value ();
3393
- auto a4 =
3394
- tosa::getConstTensor<float >(rewriter, op, 0 .078108f , {}, dtype).value ();
3395
-
3396
- if (mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), x, zero).failed () ||
3397
- mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), x, one).failed () ||
3398
- mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), x, a1).failed () ||
3399
- mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), x, a2).failed () ||
3400
- mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), x, a3).failed () ||
3401
- mlir::tosa::EqualizeRanks (rewriter, op->getLoc (), x, a4).failed ())
3402
- return std::nullopt;
3403
-
3404
- auto a1X =
3405
- tosa::createMulOpAndCast (rewriter, op, outType, a1, absX, /* shift=*/ 0 );
3406
- auto sum = rewriter.create <tosa::AddOp>(loc, outType, a1X, one);
3407
-
3408
- auto x2 =
3409
- tosa::createMulOpAndCast (rewriter, op, outType, absX, absX, /* shift=*/ 0 );
3410
- auto a2X =
3411
- tosa::createMulOpAndCast (rewriter, op, outType, a2, x2, /* shift=*/ 0 );
3412
- sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a2X);
3413
-
3414
- auto x3 =
3415
- tosa::createMulOpAndCast (rewriter, op, outType, x2, absX, /* shift=*/ 0 );
3416
- auto a3X =
3417
- tosa::createMulOpAndCast (rewriter, op, outType, a3, x3, /* shift=*/ 0 );
3418
- sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a3X);
3419
-
3420
- auto x4 =
3421
- tosa::createMulOpAndCast (rewriter, op, outType, x3, absX, /* shift=*/ 0 );
3422
- auto a4X =
3423
- tosa::createMulOpAndCast (rewriter, op, outType, a4, x4, /* shift=*/ 0 );
3424
- sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a4X);
3425
-
3426
- auto rcprl = rewriter.create <tosa::ReciprocalOp>(loc, outType, sum);
3427
- auto rcprl2 = tosa::createMulOpAndCast (rewriter, op, outType, rcprl, rcprl,
3428
- /* shift=*/ 0 );
3429
- auto rcprl4 = tosa::createMulOpAndCast (rewriter, op, outType, rcprl2, rcprl2,
3430
- /* shift=*/ 0 );
3431
- auto erf = rewriter.create <tosa::SubOp>(loc, outType, one, rcprl4);
3432
-
3433
- // Deal with negative x.
3434
- auto cond = rewriter.create <tosa::GreaterEqualOp>(
3435
- loc,
3436
- RankedTensorType::get (outType.getShape (), rewriter.getIntegerType (1 )), x,
3437
- zero);
3438
- auto negateErf = rewriter.create <tosa::NegateOp>(loc, outType, erf);
3439
-
3440
- return rewriter.create <tosa::SelectOp>(loc, outType, cond, erf, negateErf);
3441
- }
3442
-
3443
3372
static std::optional<Value>
3444
3373
buildUnitNormalCdf (ConversionPatternRewriter &rewriter, Operation *op, Value x,
3445
3374
Type dtype) {
@@ -3467,7 +3396,7 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
3467
3396
Value erfArg =
3468
3397
tosa::createMulOpAndCast (rewriter, op, outType, xMinusMean, rsqrt2,
3469
3398
/* shift=*/ 0 );
3470
- Value erf = approximateErfOp ( rewriter, op , erfArg, dtype). value ( );
3399
+ Value erf = rewriter. create <tosa::ErfOp>(loc, outType , erfArg);
3471
3400
Value erfPlus1 = rewriter.create <tosa::AddOp>(loc, outType, one, erf);
3472
3401
3473
3402
Value normalCdf = tosa::createMulOpAndCast (rewriter, op, outType, oneHalf,
0 commit comments