Skip to content

Commit 2ef5bc6

Browse files
committed
Workaround numpy 1.x assert_allclose false-positive result in comparing complex infinities.
1 parent 026f309 commit 2ef5bc6

File tree

2 files changed

+94
-45
lines changed

2 files changed

+94
-45
lines changed

jax/_src/test_util.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,22 +1725,12 @@ def negative(self, x):
17251725

17261726
def sqrt(self, x):
17271727
ctx = x.context
1728-
# workaround mpmath bugs:
17291728
if isinstance(x, ctx.mpc):
1730-
if ctx.isinf(x.real) and ctx.isinf(x.imag):
1731-
if x.real > 0: return x
1732-
ninf = x.real
1733-
inf = -ninf
1734-
if x.imag > 0: return ctx.make_mpc((inf._mpf_, inf._mpf_))
1735-
return ctx.make_mpc((inf._mpf_, inf._mpf_))
1736-
elif ctx.isfinite(x.real) and ctx.isinf(x.imag):
1737-
if x.imag > 0:
1738-
inf = x.imag
1739-
return ctx.make_mpc((inf._mpf_, inf._mpf_))
1740-
else:
1741-
ninf = x.imag
1742-
inf = -ninf
1743-
return ctx.make_mpc((inf._mpf_, ninf._mpf_))
1729+
# Workaround mpmath 1.3 bug in sqrt(+-inf+-infj) evaluation (see mpmath/mpmath#776).
1730+
# TODO(pearu): remove this function when mpmath 1.4 or newer
1731+
# will be the required test dependency.
1732+
if ctx.isinf(x.imag):
1733+
return ctx.make_mpc((ctx.inf._mpf_, x.imag._mpf_))
17441734
return ctx.sqrt(x)
17451735

17461736
def expm1(self, x):

tests/lax_test.py

Lines changed: 89 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3567,21 +3567,37 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind):
35673567
mposj=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), s1 + 1),
35683568
)
35693569

3570+
# The regions are split to real and imaginary parts (of function
3571+
# return values) to (i) workaround numpy 1.x assert_allclose bug
3572+
# in comparing complex infinities, and (ii) expose more details
3573+
# about failing cases:
3574+
s_dict_parts = dict()
3575+
for k, v in s_dict.items():
3576+
s_dict_parts[k + '.real'] = v
3577+
s_dict_parts[k + '.imag'] = v
3578+
35703579
# Start with an assumption that all regions are problematic for a
35713580
# particular function:
3572-
regions_with_inaccuracies = list(s_dict)
3581+
regions_with_inaccuracies = list(s_dict_parts)
35733582

35743583
# Next, we'll remove non-problematic regions from the
35753584
# regions_with_inaccuracies list by explicitly keeping problematic
35763585
# regions:
35773586
def regions_with_inaccuracies_keep(*to_keep):
3587+
to_keep_parts = []
3588+
for r in to_keep:
3589+
if r.endswith('.real') or r.endswith('.imag'):
3590+
to_keep_parts.append(r)
3591+
else:
3592+
to_keep_parts.append(r + '.real')
3593+
to_keep_parts.append(r + '.imag')
35783594
for item in regions_with_inaccuracies[:]:
3579-
if item not in to_keep:
3595+
if item not in to_keep_parts:
35803596
regions_with_inaccuracies.remove(item)
35813597

35823598
if name == 'absolute':
35833599
if is_cuda and dtype == np.complex128:
3584-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')
3600+
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real')
35853601
else:
35863602
regions_with_inaccuracies.clear()
35873603

@@ -3590,95 +3606,122 @@ def regions_with_inaccuracies_keep(*to_keep):
35903606

35913607
elif name == 'square':
35923608
if is_cuda:
3593-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
3609+
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real')
35943610
if is_cpu:
3595-
regions_with_inaccuracies_keep('ninf', 'pinf')
3611+
regions_with_inaccuracies_keep('ninf.real', 'pinf.real', 'q1.real', 'q2.real', 'q3.real', 'q4.real')
35963612

35973613
elif name == 'log':
3598-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
3614+
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')
35993615

36003616
elif name == 'log10':
3601-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero')
3617+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag')
36023618

36033619
elif name == 'log1p':
3604-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
3620+
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real',
3621+
'negj.real', 'posj.real', 'ninf.real', 'ninfj.real', 'pinfj.real')
36053622
# TODO(pearu): after landing openxla/xla#10503, switch to
36063623
# regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
36073624

36083625
elif name == 'exp':
3609-
regions_with_inaccuracies_keep('pos', 'pinf', 'mpos')
3626+
regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag')
36103627

36113628
elif name == 'exp2':
36123629
if dtype == np.complex64:
3613-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos', 'mnegj', 'mposj')
3630+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos.imag', 'mnegj', 'mposj')
36143631
if dtype == np.complex128:
3615-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mpos')
3632+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mpos.imag')
36163633

36173634
elif name == 'expm1' and xla_extension_version < 250:
36183635
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
36193636

36203637
elif name == 'sinc':
3621-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
3638+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4',
3639+
'mneg.real', 'mpos.real', 'mnegj', 'mposj',
3640+
'ninf.imag', 'pinf.imag', 'ninfj.real', 'pinfj.real')
36223641

36233642
elif name == 'tan':
3624-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
3643+
regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', 'negj.imag', 'posj.imag',
3644+
'ninfj.imag', 'pinfj.imag', 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag',
3645+
'ninf.imag', 'pinf.imag')
36253646

36263647
elif name == 'sinh':
36273648
if is_cuda:
3628-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
3649+
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg', 'pos',
3650+
'ninf.imag', 'pinf.imag', 'mq1.real', 'mq2.real', 'mq3.real', 'mq4.real', 'mneg', 'mpos',
3651+
'ninfj.real', 'pinfj.real')
36293652
if is_cpu:
3630-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
3631-
3653+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj.imag', 'posj.imag', 'ninf.imag', 'pinf.imag',
3654+
'mq1.real', 'mq2.real', 'mq3.real', 'mq4.real', 'mneg', 'mpos',
3655+
'ninfj.real', 'pinfj.real')
36323656
elif name == 'cosh':
3633-
regions_with_inaccuracies_keep('neg', 'pos', 'ninf', 'pinf', 'mneg', 'mpos')
3657+
regions_with_inaccuracies_keep('neg.imag', 'pos.imag', 'ninf.imag', 'pinf.imag', 'mneg.imag', 'mpos.imag',
3658+
'ninfj.imag', 'pinfj.imag')
36343659

36353660
elif name == 'tanh':
36363661
regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
36373662

36383663
elif name == 'arccos':
36393664
if dtype == np.complex64:
3640-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
3665+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg',
3666+
'mpos.imag', 'mnegj')
36413667
if dtype == np.complex128:
3642-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
3668+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos.imag', 'mnegj')
36433669

36443670
elif name == 'arccosh':
36453671
if dtype == np.complex64:
3646-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
3672+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos.imag', 'mnegj')
36473673
if dtype == np.complex128:
3648-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mnegj')
3674+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos.real', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mnegj')
36493675

36503676
elif name == 'arcsin':
36513677
if dtype == np.complex64:
36523678
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
36533679
if dtype == np.complex128:
3654-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
3680+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj',
3681+
'mq1', 'mq2', 'mq3', 'mq4', 'mneg.imag', 'mpos.imag', 'mnegj', 'mposj')
36553682

36563683
elif name == 'arcsinh':
36573684
if dtype == np.complex64:
3658-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
3685+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg.real', 'pos.real', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj',
3686+
'mq1.real', 'mq2', 'mq3', 'mq4.real', 'mneg.real', 'mpos.real', 'mnegj')
36593687
if dtype == np.complex128:
3660-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mnegj')
3688+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg.real', 'pos.real', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg.real', 'mnegj')
36613689

36623690
elif name == 'arctan':
36633691
if dtype == np.complex64:
3664-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
3692+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj',
3693+
'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag')
36653694
if dtype == np.complex128:
36663695
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
36673696

36683697
elif name == 'arctanh':
3669-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
3698+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
36703699
# TODO(pearu): after landing openxla/xla#10503, switch to
36713700
# regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
3701+
3702+
elif name in {'cos', 'sin'}:
3703+
regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')
3704+
36723705
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1'}:
36733706
regions_with_inaccuracies.clear()
36743707
else:
36753708
assert 0 # unreachable
36763709

36773710
# Finally, perform the closeness tests per region:
36783711
unexpected_success_regions = []
3679-
for region_name, region_slice in s_dict.items():
3712+
for region_name, region_slice in s_dict_parts.items():
36803713
region = args[0][region_slice]
3681-
inexact_indices = np.where(normalized_result[region_slice] != normalized_expected[region_slice])
3714+
if region_name.endswith('.real'):
3715+
result_slice, expected_slice = result[region_slice].real, expected[region_slice].real
3716+
normalized_result_slice, normalized_expected_slice = normalized_result[region_slice].real, normalized_expected[region_slice].real
3717+
elif region_name.endswith('.imag'):
3718+
result_slice, expected_slice = result[region_slice].imag, expected[region_slice].imag
3719+
normalized_result_slice, normalized_expected_slice = normalized_result[region_slice].imag, normalized_expected[region_slice].imag
3720+
else:
3721+
result_slice, expected_slice = result[region_slice], expected[region_slice]
3722+
normalized_result_slice, normalized_expected_slice = normalized_result[region_slice], normalized_expected[region_slice]
3723+
3724+
inexact_indices = np.where(normalized_result_slice != normalized_expected_slice)
36823725

36833726
if inexact_indices[0].size == 0:
36843727
inexact_samples = ''
@@ -3697,20 +3740,36 @@ def regions_with_inaccuracies_keep(*to_keep):
36973740
if kind == 'success' and region_name not in regions_with_inaccuracies:
36983741
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
36993742
self.assertAllClose(
3700-
normalized_result[region_slice], normalized_expected[region_slice], atol=atol,
3743+
normalized_result_slice, normalized_expected_slice, atol=atol,
37013744
err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{inexact_samples}")
37023745

37033746
if kind == 'failure' and region_name in regions_with_inaccuracies:
37043747
try:
37053748
with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
37063749
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
3707-
self.assertAllClose(normalized_result[region_slice], normalized_expected[region_slice])
3750+
self.assertAllClose(normalized_result_slice, normalized_expected_slice)
37083751
except AssertionError as msg:
37093752
if str(msg).startswith('AssertionError not raised'):
37103753
unexpected_success_regions.append(region_name)
37113754
else:
37123755
raise # something else is wrong..
37133756

3757+
def eliminate_parts(seq):
3758+
# replace n.real and n.imag items in seq with n.
3759+
result = []
3760+
for part_name in seq:
3761+
name = part_name.split('.')[0]
3762+
if name in result:
3763+
continue
3764+
if name + '.real' in seq and name + '.imag' in seq:
3765+
result.append(name)
3766+
else:
3767+
result.append(part_name)
3768+
return result
3769+
3770+
regions_with_inaccuracies = eliminate_parts(regions_with_inaccuracies)
3771+
unexpected_success_regions = eliminate_parts(unexpected_success_regions)
3772+
37143773
if kind == 'success' and regions_with_inaccuracies:
37153774
reason = "xfail: problematic regions: " + ", ".join(regions_with_inaccuracies)
37163775
raise unittest.SkipTest(reason)

0 commit comments

Comments
 (0)