
Commit b66526a

Authored by Vinit Deodhar (vinitdeodhar) and a co-author
Emit tosa::erf during lowering gelu op (llvm#4151)
This PR adds support for emitting tosa::erf when lowering the aten.gelu op. Currently, aten.gelu uses its own custom implementation of erf, which was written before TOSA gained support for tosa::erf: llvm/llvm-project@1fef1f9. aten.erf is already mapped to tosa::ErfOp: https://github.com/vinitdeodhar/torch-mlir/blob/c785435a049a694a1814a7304ed0c34abe8b2580/lib/Conversion/TorchToTosa/TorchToTosa.cpp#L9176

The two implementations (the existing aten.gelu erf and tosa::erf) produce numerically different results; tosa::erf is the more precise approximation of the erf function. The tosa::erf approximation is based on the following Stack Overflow post: https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf which is in turn based on:

M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36, No. 153, January 1981, pp. 249-253.

Maximum error: 2.65 ulps.

---------

Co-authored-by: Vinit Deodhar <vdeodhar@mathworks.com>
1 parent 4bd7d03 commit b66526a
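
For context, the removed implementation is the four-coefficient approximation from the Wikipedia article cited in its comment: erf(x) ≈ 1 - 1/(1 + a1*x + a2*x^2 + a3*x^3 + a4*x^4)^4 for x >= 0, extended to negative x by odd symmetry (the removed comment writes the terms loosely as "a1X + a2X + a3X + a4X"; the code builds the powers x, x^2, x^3, x^4). A minimal pure-Python sketch of that formula, not part of the patch, measuring its error against the exact math.erf:

# Pure-Python sketch (illustrative only) of the erf approximation removed
# by this commit; the real code built these steps as TOSA ops.
import math

A1, A2, A3, A4 = 0.278393, 0.230389, 0.000972, 0.078108

def approximate_erf(x: float) -> float:
    ax = abs(x)
    denom = 1.0 + A1 * ax + A2 * ax**2 + A3 * ax**3 + A4 * ax**4
    erf = 1.0 - 1.0 / denom**4
    return erf if x >= 0 else -erf  # erf is odd; mirror for negative x

# Maximum absolute error against the exact erf over a sample grid:
worst = max(abs(approximate_erf(i / 100.0) - math.erf(i / 100.0))
            for i in range(-500, 501))
print(f"max abs error: {worst:.2e}")  # about 5e-4, per the removed comment

The roughly 5e-4 bound matches the comment in the removed code, against the 2.65-ulp bound cited for the tosa::erf approximation.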

File tree

4 files changed: +62 additions, -72 deletions

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 1 addition & 72 deletions

@@ -3369,77 +3369,6 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
   return success();
 }
 
-static std::optional<Value>
-approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x,
-                 Type dtype) {
-  // Using:
-  // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with
-  // maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 =
-  // 0.000972, a4 = 0.078108.
-  //
-  // Erf = 1 - 1 / (1 + a1X + a2X + a3X + a4X)^4
-
-  auto outType = cast<TensorType>(x.getType());
-  auto loc = op->getLoc();
-  auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
-  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
-  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
-  auto a1 =
-      tosa::getConstTensor<float>(rewriter, op, 0.278393f, {}, dtype).value();
-  auto a2 =
-      tosa::getConstTensor<float>(rewriter, op, 0.230389f, {}, dtype).value();
-  auto a3 =
-      tosa::getConstTensor<float>(rewriter, op, 0.000972f, {}, dtype).value();
-  auto a4 =
-      tosa::getConstTensor<float>(rewriter, op, 0.078108f, {}, dtype).value();
-
-  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, zero).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, one).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a1).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a2).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a3).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a4).failed())
-    return std::nullopt;
-
-  auto a1X =
-      tosa::createMulOpAndCast(rewriter, op, outType, a1, absX, /*shift=*/0);
-  auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
-
-  auto x2 =
-      tosa::createMulOpAndCast(rewriter, op, outType, absX, absX, /*shift=*/0);
-  auto a2X =
-      tosa::createMulOpAndCast(rewriter, op, outType, a2, x2, /*shift=*/0);
-  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
-
-  auto x3 =
-      tosa::createMulOpAndCast(rewriter, op, outType, x2, absX, /*shift=*/0);
-  auto a3X =
-      tosa::createMulOpAndCast(rewriter, op, outType, a3, x3, /*shift=*/0);
-  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
-
-  auto x4 =
-      tosa::createMulOpAndCast(rewriter, op, outType, x3, absX, /*shift=*/0);
-  auto a4X =
-      tosa::createMulOpAndCast(rewriter, op, outType, a4, x4, /*shift=*/0);
-  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
-
-  auto rcprl = rewriter.create<tosa::ReciprocalOp>(loc, outType, sum);
-  auto rcprl2 = tosa::createMulOpAndCast(rewriter, op, outType, rcprl, rcprl,
-                                         /*shift=*/0);
-  auto rcprl4 = tosa::createMulOpAndCast(rewriter, op, outType, rcprl2, rcprl2,
-                                         /*shift=*/0);
-  auto erf = rewriter.create<tosa::SubOp>(loc, outType, one, rcprl4);
-
-  // Deal with negative x.
-  auto cond = rewriter.create<tosa::GreaterEqualOp>(
-      loc,
-      RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), x,
-      zero);
-  auto negateErf = rewriter.create<tosa::NegateOp>(loc, outType, erf);
-
-  return rewriter.create<tosa::SelectOp>(loc, outType, cond, erf, negateErf);
-}
-
 static std::optional<Value>
 buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
                    Type dtype) {
@@ -3467,7 +3396,7 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
   Value erfArg =
       tosa::createMulOpAndCast(rewriter, op, outType, xMinusMean, rsqrt2,
                                /*shift=*/0);
-  Value erf = approximateErfOp(rewriter, op, erfArg, dtype).value();
+  Value erf = rewriter.create<tosa::ErfOp>(loc, outType, erfArg);
   Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
 
   Value normalCdf = tosa::createMulOpAndCast(rewriter, op, outType, oneHalf,
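
With approximateErfOp gone, buildUnitNormalCdf is left computing only the standard identity Phi(x) = 0.5 * (1 + erf((x - mean) / sqrt(2))), with erf itself delegated to tosa.erf. A stdlib-only Python sketch of the same arithmetic (illustrative, not the converter's code), cross-checked against Python's exact normal CDF:

# Sketch of the math buildUnitNormalCdf emits as TOSA ops after this change:
#   cdf(x) = 0.5 * (1 + erf((x - mean) * rsqrt2))   with mean = 0
import math
from statistics import NormalDist

def unit_normal_cdf(x: float, mean: float = 0.0) -> float:
    erf_arg = (x - mean) * (1.0 / math.sqrt(2.0))  # xMinusMean * rsqrt2
    return 0.5 * (1.0 + math.erf(erf_arg))         # oneHalf * (one + erf)

for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
    assert math.isclose(unit_normal_cdf(v), NormalDist().cdf(v), rel_tol=1e-12)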

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions

@@ -2197,6 +2197,7 @@
     "ElementwiseGeIntScalarModule_basic",
     "ElementwiseGeMixedIntScalarModule_basic",
     "ElementwiseGeluModule_basic",
+    "ElementwiseGeluTosaModule_basic",
     "ElementwiseGtFloatScalarModule_basic",
     "ElementwiseGtFloatTensorModule_basic",
     "ElementwiseGtIntScalarModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 24 additions & 0 deletions

@@ -1374,6 +1374,30 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseGeluTosaModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        x = torch.ops.aten.gelu(x)
+        return x
+
+
+@register_test_case(module_factory=lambda: ElementwiseGeluTosaModule())
+def ElementwiseGeluTosaModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(50, 30, low=-2.7, high=2.7))
+
+
+# ==============================================================================
+
+
 class ElementwiseGeluApproximateTanhModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 36 additions & 0 deletions

@@ -3071,6 +3071,42 @@ func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.gelu$none(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1500,1536],f32>) -> !torch.vtensor<[1,1500,1536],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1500,1536],f32> -> tensor<1x1500x1536xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.str "none"
+// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.707106769> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_7]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_9]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_11]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_13]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<1x1500x1536xf32>, tensor<1x1x1xf32>) -> tensor<1x1500x1536xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_15]], %[[VAL_14]], %[[VAL_16]] : (tensor<1x1500x1536xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<1x1500x1536xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.erf %[[VAL_17]] : (tensor<1x1500x1536xf32>) -> tensor<1x1500x1536xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.add %[[VAL_10]], %[[VAL_18]] : (tensor<1x1x1xf32>, tensor<1x1500x1536xf32>) -> tensor<1x1500x1536xf32>
+// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_12]], %[[VAL_19]], %[[VAL_20]] : (tensor<1x1x1xf32>, tensor<1x1500x1536xf32>, tensor<1xi8>) -> tensor<1x1500x1536xf32>
+// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_1]], %[[VAL_21]], %[[VAL_22]] : (tensor<1x1500x1536xf32>, tensor<1x1500x1536xf32>, tensor<1xi8>) -> tensor<1x1500x1536xf32>
+// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<1x1500x1536xf32> -> !torch.vtensor<[1,1500,1536],f32>
+// CHECK: return %[[VAL_24]] : !torch.vtensor<[1,1500,1536],f32>
+// CHECK: }
+func.func @torch.aten.gelu$none(%arg0: !torch.vtensor<[1,1500,1536],f32>) -> !torch.vtensor<[1,1500,1536],f32> {
+  %str = torch.constant.str "none"
+  %0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[1,1500,1536],f32>, !torch.str -> !torch.vtensor<[1,1500,1536],f32>
+  return %0 : !torch.vtensor<[1,1500,1536],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.gelu$tanh(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,3],f32> -> tensor<5x3xf32>
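
Read in order, the CHECK lines for @torch.aten.gelu$none spell out the exact ("none") form of GELU: subtract the zero mean, scale by 1/sqrt(2) (the 0.707106769 constant), apply tosa.erf, form 0.5 * (1 + erf(...)), and finally multiply by the input. A PyTorch trace of the same op sequence, a sketch for illustration with an arbitrary small input shape, not the FileCheck test itself:

# Op-by-op trace of the lowered sequence in the CHECK lines above,
# with PyTorch ops standing in for the TOSA ones.
import torch

x = torch.randn(1, 4, 8)  # stand-in for the 1x1500x1536 input

# Scalar constants, reshaped to rank 3 for broadcasting (tosa.reshape).
zero = torch.tensor(0.0).reshape(1, 1, 1)
one = torch.tensor(1.0).reshape(1, 1, 1)
half = torch.tensor(0.5).reshape(1, 1, 1)
rsqrt2 = torch.tensor(0.707106769).reshape(1, 1, 1)

t = (x - zero) * rsqrt2   # tosa.sub, tosa.mul
t = torch.erf(t)          # tosa.erf (the op this commit starts emitting)
t = half * (one + t)      # tosa.add, tosa.mul -> unit normal CDF
y = x * t                 # final tosa.mul -> gelu(x)

assert torch.allclose(y, torch.nn.functional.gelu(x), atol=1e-6)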
