Commit 29858b4

Adjust tolerance for fp16 exp & gelu ops test to handle reasonable calculation discrepancies (#12150)
### Summary

This PR improves the exp_fp16 and gelu_fp16 tests by using a dynamic tolerance strategy similar to the tolerance calculation XNNPACK itself uses when validating float16 exponential kernels. Instead of relying on fixed absolute and relative tolerances, the tests now calculate acceptable error bounds based on the output magnitude and float16 precision constraints. This change ensures correctness while accommodating the inherent limitations of float16 arithmetic.

### Problem

While testing the float16 exponential kernel from XNNPACK against PyTorch's eager-mode implementation, sporadic failures occurred. They were caused by small mismatches between the output values, often in the range of ~0.01 to ~0.015. These discrepancies appeared even though both outputs were reasonably close when viewed through the lens of float16 precision. The original test used fixed tolerance values (atol=1e-3, rtol=1e-3), which are too strict for float16 results, particularly for inputs that produce large exponentials.

### Investigation

To understand the failures, I traced specific cases where discrepancies occurred. For example, for the input 2.2715, PyTorch computes exp(2.2715) in float32 and rounds the result to float16, yielding 9.6953. In contrast, XNNPACK uses float16-only arithmetic throughout its kernel and computes a slightly lower value, 9.6797. The difference between the two outputs is 0.0156, which is two ULPs (units in the last place) at that magnitude in float16, where one ULP is 2^-7 ≈ 0.0078. This led me to examine the structure of float16 and its numerical limits in detail.

IEEE 754 half-precision floating point (float16) has limited resolution (only 10 significand bits), so the spacing between representable values grows with magnitude: near 1.0 the ULP is about 0.00098, while near 9.7 it is about 0.0078. Given this, small absolute differences in the output are not only expected but within the bounds of what float16 can actually represent.

To confirm the root cause, I reviewed the XNNPACK source code and documentation. Its float16 exponential kernel uses a 2^z * 2^r decomposition and evaluates a degree-3 polynomial entirely in float16 arithmetic, which accumulates rounding error over several steps. More importantly, XNNPACK's own test infrastructure accepts outputs within a mixed tolerance of 2 × ε absolute and 6 × ε relative error, where ε ≈ 9.77e-4 is the machine epsilon for float16. This tolerance model is defined by their TolMixed function and effectively allows on the order of 6 ULPs of error, depending on the output value.

### Solution

This PR updates the exp_fp16 and gelu_fp16 tests to use the same tolerance policy as XNNPACK. For float16 inputs, the test now computes the reference output in float32 precision and then applies the following tolerance calculation (a worked example follows this description):

- Absolute tolerance: 2 × ε ≈ 0.00195
- Relative tolerance: 6 × ε ≈ 0.00586
- Final tolerance per output: max(atol, rtol × |y_ref|)

### Test plan

I added the new rtol and atol values to the test suite and ran the tests with various random inputs to confirm that they pass.
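For illustration only (not part of the commit), here is a minimal standalone sketch of the mixed-tolerance policy above, applied to the exp(2.2715) case from the investigation. The helper name `mixed_fp16_tolerance` and the hard-coded kernel output are assumptions made for this example:

```python
import torch

FP16_EPS = torch.finfo(torch.float16).eps  # 2**-10 ≈ 9.77e-4

def mixed_fp16_tolerance(ref: torch.Tensor) -> torch.Tensor:
    # Per-element bound: max(2 * eps, 6 * eps * |y_ref|), mirroring the TolMixed policy described above.
    return torch.maximum(
        torch.full_like(ref, 2 * FP16_EPS),
        ref.abs() * (6 * FP16_EPS),
    )

# float16 spacing grows with magnitude: ~0.00098 just above 1.0, 0.0078125 in [8, 16).
print(FP16_EPS, FP16_EPS * 8)

# Reference path used by the tests: compute exp in float32, then round to float16.
x = torch.tensor([2.2715], dtype=torch.float16)
ref = torch.exp(x.to(torch.float32)).to(torch.float16)   # ~9.6953

# Output reported in the description for the float16-only XNNPACK kernel (hard-coded here).
actual = torch.tensor([9.6797], dtype=torch.float16)

diff = (ref.float() - actual.float()).abs()
tol = mixed_fp16_tolerance(ref.float())
print(diff.item(), tol.item(), bool((diff <= tol).all()))
# diff ≈ 0.0156 (two fp16 steps), tol ≈ 6 * 9.77e-4 * 9.6953 ≈ 0.057, so the check passes,
# whereas the old fixed atol=rtol=1e-3 would have flagged it as a failure.
```

In the diffs below, `calculate_fp16_exp_tolerance` and `calculate_fp16_gelu_tolerance` collapse this per-element bound into a single `atol` (the maximum over all reference outputs) and pass `rel_tol` through as `rtol` to `run_method_and_compare_outputs`.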
1 parent f11e4d3 commit 29858b4

File tree

2 files changed (+56, -5 lines)

backends/xnnpack/test/ops/test_exp.py

Lines changed: 28 additions & 4 deletions
@@ -10,6 +10,23 @@
 from executorch.backends.xnnpack.test.tester import Tester


+def calculate_fp16_exp_tolerance(ref_output_tensor):
+    # Calculate mixed tolerance for float16 used in XNNPACK's float16 policy
+    fp16_epsilon = 9.77e-4
+    abs_tol = 2 * fp16_epsilon
+    rel_tol = 6 * fp16_epsilon
+
+    ref_abs = ref_output_tensor.abs()
+    mixed_tol = torch.maximum(
+        torch.full_like(ref_abs, abs_tol),
+        ref_abs * rel_tol,
+    )
+
+    final_atol = mixed_tol.max().item()
+
+    return final_atol, rel_tol
+
+
 class TestExp(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
@@ -22,6 +39,16 @@ def forward(self, x):
         return torch.exp(x)

     def run_exp_test(self, inputs):
+        input_tensor = inputs[0]
+
+        if input_tensor.dtype == torch.float16:
+            with torch.no_grad():
+                ref_output = torch.exp(input_tensor.to(torch.float32)).to(torch.float16)
+            atol, rtol = calculate_fp16_exp_tolerance(ref_output)
+        else:
+            atol = 1e-03
+            rtol = 1e-03
+
         (
             Tester(self.Exp(), inputs)
             .export()
@@ -31,12 +58,9 @@ def run_exp_test(self, inputs):
             .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )

-    # TODO (leafs1): Fix flaky tests. Land fix asap
-    # and cherry-pick onto release/0.7 branch
-    @unittest.skip(reason="For float16, numerical discepancies are too high")
     def test_fp16_exp(self):
         inputs = (torch.randn(20).to(torch.float16),)
         self.run_exp_test(inputs)
backends/xnnpack/test/ops/test_gelu.py

Lines changed: 28 additions & 1 deletion
@@ -10,6 +10,21 @@
 from executorch.backends.xnnpack.test.tester import Tester


+def calculate_fp16_gelu_tolerance(ref_output_tensor):
+    fp16_epsilon = 9.77e-4
+    abs_tol = 2 * fp16_epsilon
+    rel_tol = 6 * fp16_epsilon
+
+    ref_abs = ref_output_tensor.abs()
+    mixed_tol = torch.maximum(
+        torch.full_like(ref_abs, abs_tol),
+        ref_abs * rel_tol,
+    )
+
+    final_atol = mixed_tol.max().item()
+    return final_atol, rel_tol
+
+
 class TestGelu(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
@@ -23,6 +38,18 @@ def forward(self, x):
         return self.gelu(x)

     def run_gelu_test(self, inputs):
+        input_tensor = inputs[0]
+
+        if input_tensor.dtype == torch.float16:
+            with torch.no_grad():
+                ref_output = torch.nn.functional.gelu(
+                    input_tensor.to(torch.float32)
+                ).to(torch.float16)
+            atol, rtol = calculate_fp16_gelu_tolerance(ref_output)
+        else:
+            atol = 1e-03
+            rtol = 1e-03
+
         (
             Tester(self.Gelu(), inputs)
             .export()
@@ -32,7 +59,7 @@ def run_gelu_test(self, inputs):
             .check_not(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )

     def test_fp16_gelu(self):
