@@ -3616,11 +3616,9 @@ def regions_with_inaccuracies_keep(*to_keep):
3616
3616
elif name == 'log10' :
3617
3617
regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'ninf.imag' , 'pinf.imag' , 'ninfj.imag' , 'pinfj.imag' , 'zero.imag' )
3618
3618
3619
- elif name == 'log1p' :
3620
- regions_with_inaccuracies_keep ('q1.real' , 'q2.real' , 'q3.real' , 'q4.real' , 'neg.real' , 'pos.real' ,
3621
- 'negj.real' , 'posj.real' , 'ninf.real' , 'ninfj.real' , 'pinfj.real' )
3622
- # TODO(pearu): after landing openxla/xla#10503, switch to
3623
- # regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
3619
+ elif name == 'log1p' and xla_extension_version < 254 :
3620
+ regions_with_inaccuracies_keep ('q1.real' , 'q2.real' , 'q3.real' , 'q4.real' , 'neg.real' , 'pos.real' ,
3621
+ 'negj.real' , 'posj.real' , 'ninf.real' , 'ninfj.real' , 'pinfj.real' )
3624
3622
3625
3623
elif name == 'exp' :
3626
3624
regions_with_inaccuracies_keep ('pos.imag' , 'pinf.imag' , 'mpos.imag' )
@@ -3640,9 +3638,10 @@ def regions_with_inaccuracies_keep(*to_keep):
3640
3638
'ninf.imag' , 'pinf.imag' , 'ninfj.real' , 'pinfj.real' )
3641
3639
3642
3640
elif name == 'tan' :
3641
+ # TODO(pearu): eliminate this if-block when openxla/xla#10525 lands
3643
3642
regions_with_inaccuracies_keep ('q1.imag' , 'q2.imag' , 'q3.imag' , 'q4.imag' , 'negj.imag' , 'posj.imag' ,
3644
3643
'ninfj.imag' , 'pinfj.imag' , 'mq1.imag' , 'mq2.imag' , 'mq3.imag' , 'mq4.imag' , 'mnegj.imag' , 'mposj.imag' ,
3645
- 'ninf.imag' , 'pinf.imag' )
3644
+ 'ninf.imag' , 'pinf.imag' , 'ninf.real' , 'pinf.real' , 'ninfj.real' , 'pinfj.real' )
3646
3645
3647
3646
elif name == 'sinh' :
3648
3647
if is_cuda :
@@ -3695,14 +3694,15 @@ def regions_with_inaccuracies_keep(*to_keep):
3695
3694
regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' )
3696
3695
3697
3696
elif name == 'arctanh' :
3698
- regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mpos.imag' )
3699
- # TODO(pearu): after landing openxla/xla#10503, switch to
3700
- # regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
3697
+ if xla_extension_version < 254 :
3698
+ regions_with_inaccuracies_keep ('q1' , 'q2' , 'q3' , 'q4' , 'neg' , 'pos' , 'negj' , 'posj' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mpos.imag' )
3699
+ else :
3700
+ regions_with_inaccuracies_keep ('pos.imag' , 'ninf' , 'pinf' , 'ninfj' , 'pinfj' , 'mpos.imag' )
3701
3701
3702
3702
elif name in {'cos' , 'sin' }:
3703
3703
regions_with_inaccuracies_keep ('ninf.imag' , 'pinf.imag' )
3704
3704
3705
- elif name in {'positive' , 'negative' , 'conjugate' , 'sin' , 'cos' , 'sqrt' , 'expm1' }:
3705
+ elif name in {'positive' , 'negative' , 'conjugate' , 'sin' , 'cos' , 'sqrt' , 'expm1' , 'log1p' }:
3706
3706
regions_with_inaccuracies .clear ()
3707
3707
else :
3708
3708
assert 0 # unreachable
0 commit comments