@@ -3567,21 +3567,37 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind):
3567
3567
mposj = (slice (s0 + 3 + s03 , s0 + 3 + 2 * s03 ), s1 + 1 ),
3568
3568
)
3569
3569
3570
+ # The regions are split to real and imaginary parts (of function
3571
+ # return values) to (i) workaround numpy 1.x assert_allclose bug
3572
+ # in comparing complex infinities, and (ii) expose more details
3573
+ # about failing cases:
3574
+ s_dict_parts = dict ()
3575
+ for k , v in s_dict .items ():
3576
+ s_dict_parts [k + '.real' ] = v
3577
+ s_dict_parts [k + '.imag' ] = v
3578
+
3570
3579
# Start with an assumption that all regions are problematic for a
3571
3580
# particular function:
3572
- regions_with_inaccuracies = list (s_dict )
3581
+ regions_with_inaccuracies = list (s_dict_parts )
3573
3582
3574
3583
# Next, we'll remove non-problematic regions from the
3575
3584
# regions_with_inaccuracies list by explicitly keeping problematic
3576
3585
# regions:
3577
3586
def regions_with_inaccuracies_keep (* to_keep ):
3587
+ to_keep_parts = []
3588
+ for r in to_keep :
3589
+ if r .endswith ('.real' ) or r .endswith ('.imag' ):
3590
+ to_keep_parts .append (r )
3591
+ else :
3592
+ to_keep_parts .append (r + '.real' )
3593
+ to_keep_parts .append (r + '.imag' )
3578
3594
for item in regions_with_inaccuracies [:]:
3579
- if item not in to_keep :
3595
+ if item not in to_keep_parts :
3580
3596
regions_with_inaccuracies .remove (item )
3581
3597
3582
3598
if name == 'absolute' :
3583
3599
if is_cuda and dtype == np .complex128 :
3584
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' )
3600
+ regions_with_inaccuracies_keep ('q1.real ' , 'q2.real ' , 'q3.real ' , 'q4.real ' )
3585
3601
else :
3586
3602
regions_with_inaccuracies .clear ()
3587
3603
@@ -3590,95 +3606,122 @@ def regions_with_inaccuracies_keep(*to_keep):
3590
3606
3591
3607
elif name == 'square' :
3592
3608
if is_cuda :
3593
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' )
3609
+ regions_with_inaccuracies_keep ('q1.real ' , 'q2.real ' , 'q3.real ' , 'q4.real ' , 'ninf.real ' , 'pinf.real ' , 'ninfj.real ' , 'pinfj.real ' )
3594
3610
if is_cpu :
3595
- regions_with_inaccuracies_keep ('ninf' , 'pinf' )
3611
+ regions_with_inaccuracies_keep ('ninf.real ' , 'pinf.real' , 'q1.real' , 'q2.real' , 'q3.real' , 'q4.real ' )
3596
3612
3597
3613
elif name == 'log' :
3598
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' )
3614
+ regions_with_inaccuracies_keep ('q1.real ' , 'q2.real ' , 'q3.real ' , 'q4.real ' , 'ninf.imag ' , 'pinf.imag ' , 'ninfj.imag ' , 'pinfj.imag ' )
3599
3615
3600
3616
elif name == 'log10' :
3601
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'zero' )
3617
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'ninf.imag ' , 'pinf.imag ' , 'ninfj.imag ' , 'pinfj.imag ' , 'zero.imag ' )
3602
3618
3603
3619
elif name == 'log1p' :
3604
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' )
3620
+ regions_with_inaccuracies_keep ('q1.real' , 'q2.real' , 'q3.real' , 'q4.real' , 'neg.real' , 'pos.real' ,
3621
+ 'negj.real' , 'posj.real' , 'ninf.real' , 'ninfj.real' , 'pinfj.real' )
3605
3622
# TODO(pearu): after landing openxla/xla#10503, switch to
3606
3623
# regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
3607
3624
3608
3625
elif name == 'exp' :
3609
- regions_with_inaccuracies_keep ('pos' , 'pinf' , 'mpos' )
3626
+ regions_with_inaccuracies_keep ('pos.imag ' , 'pinf.imag ' , 'mpos.imag ' )
3610
3627
3611
3628
elif name == 'exp2' :
3612
3629
if dtype == np .complex64 :
3613
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mpos' , 'mnegj' , 'mposj' )
3630
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'pos.imag ' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mpos.imag ' , 'mnegj' , 'mposj' )
3614
3631
if dtype == np .complex128 :
3615
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'mpos' )
3632
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'pos.imag ' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'mpos.imag ' )
3616
3633
3617
3634
elif name == 'expm1' and xla_extension_version < 250 :
3618
3635
regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'pinf' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mpos' )
3619
3636
3620
3637
elif name == 'sinc' :
3621
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mpos' , 'mnegj' , 'mposj' )
3638
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'mq1' , 'mq2' , 'mq3' , 'mq4' ,
3639
+ 'mneg.real' , 'mpos.real' , 'mnegj' , 'mposj' ,
3640
+ 'ninf.imag' , 'pinf.imag' , 'ninfj.real' , 'pinfj.real' )
3622
3641
3623
3642
elif name == 'tan' :
3624
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'negj' , 'posj' , 'ninfj' , 'pinfj' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mnegj' , 'mposj' )
3643
+ regions_with_inaccuracies_keep ('q1.imag' , 'q2.imag' , 'q3.imag' , 'q4.imag' , 'negj.imag' , 'posj.imag' ,
3644
+ 'ninfj.imag' , 'pinfj.imag' , 'mq1.imag' , 'mq2.imag' , 'mq3.imag' , 'mq4.imag' , 'mnegj.imag' , 'mposj.imag' ,
3645
+ 'ninf.imag' , 'pinf.imag' )
3625
3646
3626
3647
elif name == 'sinh' :
3627
3648
if is_cuda :
3628
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'ninf' , 'pinf' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mpos' )
3649
+ regions_with_inaccuracies_keep ('q1.real' , 'q2.real' , 'q3.real' , 'q4.real' , 'neg' , 'pos' ,
3650
+ 'ninf.imag' , 'pinf.imag' , 'mq1.real' , 'mq2.real' , 'mq3.real' , 'mq4.real' , 'mneg' , 'mpos' ,
3651
+ 'ninfj.real' , 'pinfj.real' )
3629
3652
if is_cpu :
3630
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mpos' )
3631
-
3653
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj.imag' , 'posj.imag' , 'ninf.imag' , 'pinf.imag' ,
3654
+ 'mq1.real' , 'mq2.real' , 'mq3.real' , 'mq4.real' , 'mneg' , 'mpos' ,
3655
+ 'ninfj.real' , 'pinfj.real' )
3632
3656
elif name == 'cosh' :
3633
- regions_with_inaccuracies_keep ('neg' , 'pos' , 'ninf' , 'pinf' , 'mneg' , 'mpos' )
3657
+ regions_with_inaccuracies_keep ('neg.imag' , 'pos.imag' , 'ninf.imag' , 'pinf.imag' , 'mneg.imag' , 'mpos.imag' ,
3658
+ 'ninfj.imag' , 'pinfj.imag' )
3634
3659
3635
3660
elif name == 'tanh' :
3636
3661
regions_with_inaccuracies_keep ('ninf' , 'pinf' , 'ninfj' , 'pinfj' )
3637
3662
3638
3663
elif name == 'arccos' :
3639
3664
if dtype == np .complex64 :
3640
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mneg' , 'mpos' , 'mnegj' )
3665
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mneg' ,
3666
+ 'mpos.imag' , 'mnegj' )
3641
3667
if dtype == np .complex128 :
3642
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mpos' , 'mnegj' )
3668
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mpos.imag ' , 'mnegj' )
3643
3669
3644
3670
elif name == 'arccosh' :
3645
3671
if dtype == np .complex64 :
3646
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mneg' , 'mpos' , 'mnegj' )
3672
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj.real ' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mneg' , 'mpos.imag ' , 'mnegj' )
3647
3673
if dtype == np .complex128 :
3648
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mnegj' )
3674
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos.real ' , 'negj' , 'posj.real ' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mnegj' )
3649
3675
3650
3676
elif name == 'arcsin' :
3651
3677
if dtype == np .complex64 :
3652
3678
regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mpos' , 'mnegj' , 'mposj' )
3653
3679
if dtype == np .complex128 :
3654
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mpos' , 'mnegj' , 'mposj' )
3680
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' ,
3681
+ 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mneg.imag' , 'mpos.imag' , 'mnegj' , 'mposj' )
3655
3682
3656
3683
elif name == 'arcsinh' :
3657
3684
if dtype == np .complex64 :
3658
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mneg' , 'mpos' , 'mnegj' )
3685
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg.real' , 'pos.real' , 'negj' , 'posj.real' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' ,
3686
+ 'mq1.real' , 'mq2' , 'mq3' , 'mq4.real' , 'mneg.real' , 'mpos.real' , 'mnegj' )
3659
3687
if dtype == np .complex128 :
3660
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mneg' , 'mnegj' )
3688
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg.real ' , 'pos.real ' , 'negj' , 'posj.real ' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq2' , 'mq3' , 'mneg.real ' , 'mnegj' )
3661
3689
3662
3690
elif name == 'arctan' :
3663
3691
if dtype == np .complex64 :
3664
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mq1' , 'mq2' , 'mq3' , 'mq4' , 'mnegj' , 'mposj' )
3692
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' ,
3693
+ 'mq1.imag' , 'mq2.imag' , 'mq3.imag' , 'mq4.imag' , 'mnegj.imag' , 'mposj.imag' )
3665
3694
if dtype == np .complex128 :
3666
3695
regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' )
3667
3696
3668
3697
elif name == 'arctanh' :
3669
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mpos' )
3698
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mpos.imag ' )
3670
3699
# TODO(pearu): after landing openxla/xla#10503, switch to
3671
3700
# regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
3701
+
3702
+ elif name in {'cos' , 'sin' }:
3703
+ regions_with_inaccuracies_keep ('ninf.imag' , 'pinf.imag' )
3704
+
3672
3705
elif name in {'positive' , 'negative' , 'conjugate' , 'sin' , 'cos' , 'sqrt' , 'expm1' }:
3673
3706
regions_with_inaccuracies .clear ()
3674
3707
else :
3675
3708
assert 0 # unreachable
3676
3709
3677
3710
# Finally, perform the closeness tests per region:
3678
3711
unexpected_success_regions = []
3679
- for region_name , region_slice in s_dict .items ():
3712
+ for region_name , region_slice in s_dict_parts .items ():
3680
3713
region = args [0 ][region_slice ]
3681
- inexact_indices = np .where (normalized_result [region_slice ] != normalized_expected [region_slice ])
3714
+ if region_name .endswith ('.real' ):
3715
+ result_slice , expected_slice = result [region_slice ].real , expected [region_slice ].real
3716
+ normalized_result_slice , normalized_expected_slice = normalized_result [region_slice ].real , normalized_expected [region_slice ].real
3717
+ elif region_name .endswith ('.imag' ):
3718
+ result_slice , expected_slice = result [region_slice ].imag , expected [region_slice ].imag
3719
+ normalized_result_slice , normalized_expected_slice = normalized_result [region_slice ].imag , normalized_expected [region_slice ].imag
3720
+ else :
3721
+ result_slice , expected_slice = result [region_slice ], expected [region_slice ]
3722
+ normalized_result_slice , normalized_expected_slice = normalized_result [region_slice ], normalized_expected [region_slice ]
3723
+
3724
+ inexact_indices = np .where (normalized_result_slice != normalized_expected_slice )
3682
3725
3683
3726
if inexact_indices [0 ].size == 0 :
3684
3727
inexact_samples = ''
@@ -3697,20 +3740,36 @@ def regions_with_inaccuracies_keep(*to_keep):
3697
3740
if kind == 'success' and region_name not in regions_with_inaccuracies :
3698
3741
with jtu .ignore_warning (category = RuntimeWarning , message = "overflow encountered in.*" ):
3699
3742
self .assertAllClose (
3700
- normalized_result [ region_slice ], normalized_expected [ region_slice ] , atol = atol ,
3743
+ normalized_result_slice , normalized_expected_slice , atol = atol ,
3701
3744
err_msg = f"{ name } in { region_name } , { is_cpu = } { is_cuda = } , { xla_extension_version = } \n { inexact_samples } " )
3702
3745
3703
3746
if kind == 'failure' and region_name in regions_with_inaccuracies :
3704
3747
try :
3705
3748
with self .assertRaises (AssertionError , msg = f"{ name } in { region_name } , { is_cpu = } { is_cuda = } , { xla_extension_version = } " ):
3706
3749
with jtu .ignore_warning (category = RuntimeWarning , message = "overflow encountered in.*" ):
3707
- self .assertAllClose (normalized_result [ region_slice ], normalized_expected [ region_slice ] )
3750
+ self .assertAllClose (normalized_result_slice , normalized_expected_slice )
3708
3751
except AssertionError as msg :
3709
3752
if str (msg ).startswith ('AssertionError not raised' ):
3710
3753
unexpected_success_regions .append (region_name )
3711
3754
else :
3712
3755
raise # something else is wrong..
3713
3756
3757
+ def eliminate_parts (seq ):
3758
+ # replace n.real and n.imag items in seq with n.
3759
+ result = []
3760
+ for part_name in seq :
3761
+ name = part_name .split ('.' )[0 ]
3762
+ if name in result :
3763
+ continue
3764
+ if name + '.real' in seq and name + '.imag' in seq :
3765
+ result .append (name )
3766
+ else :
3767
+ result .append (part_name )
3768
+ return result
3769
+
3770
+ regions_with_inaccuracies = eliminate_parts (regions_with_inaccuracies )
3771
+ unexpected_success_regions = eliminate_parts (unexpected_success_regions )
3772
+
3714
3773
if kind == 'success' and regions_with_inaccuracies :
3715
3774
reason = "xfail: problematic regions: " + ", " .join (regions_with_inaccuracies )
3716
3775
raise unittest .SkipTest (reason )
0 commit comments