Skip to content

Commit 36bedee

Browse files
author
jax authors
committed
Merge pull request #20688 from pearu:pearu/tan
PiperOrigin-RevId: 623851102
2 parents 2be7205 + fc04ba9 commit 36bedee

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

jax/_src/test_util.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1754,6 +1754,34 @@ def log1p(self, x):
17541754
return ctx.make_mpc(((-x.real)._mpf_, (3 * pi / 4)._mpf_))
17551755
return ctx.log1p(x)
17561756

1757+
def tan(self, x):
1758+
ctx = x.context
1759+
if isinstance(x, ctx.mpc):
1760+
# Workaround mpmath 1.3 bug in tan(+-inf+-infj) evaluation (see mpmath/mpmath#781).
1761+
# TODO(pearu): remove this function when mpmath 1.4 or newer
1762+
# will be the required test dependency.
1763+
if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)):
1764+
if x.imag > 0:
1765+
return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_))
1766+
return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_))
1767+
if ctx.isinf(x.real) and ctx.isfinite(x.imag):
1768+
return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_))
1769+
return ctx.tan(x)
1770+
1771+
def tanh(self, x):
1772+
ctx = x.context
1773+
if isinstance(x, ctx.mpc):
1774+
# Workaround mpmath 1.3 bug in tanh(+-inf+-infj) evaluation (see mpmath/mpmath#781).
1775+
# TODO(pearu): remove this function when mpmath 1.4 or newer
1776+
# will be the required test dependency.
1777+
if ctx.isinf(x.imag) and (ctx.isinf(x.real) or ctx.isfinite(x.real)):
1778+
if x.imag > 0:
1779+
return ctx.make_mpc((ctx.zero._mpf_, ctx.one._mpf_))
1780+
return ctx.make_mpc((ctx.zero._mpf_, (-ctx.one)._mpf_))
1781+
if ctx.isinf(x.real) and ctx.isfinite(x.imag):
1782+
return ctx.make_mpc((ctx.nan._mpf_, ctx.nan._mpf_))
1783+
return ctx.tanh(x)
1784+
17571785
def log2(self, x):
17581786
return x.context.ln(x) / x.context.ln2
17591787

tests/lax_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3616,11 +3616,9 @@ def regions_with_inaccuracies_keep(*to_keep):
36163616
elif name == 'log10':
36173617
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag')
36183618

3619-
elif name == 'log1p':
3620-
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real',
3621-
'negj.real', 'posj.real', 'ninf.real', 'ninfj.real', 'pinfj.real')
3622-
# TODO(pearu): after landing openxla/xla#10503, switch to
3623-
# regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')
3619+
elif name == 'log1p' and xla_extension_version < 254:
3620+
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real',
3621+
'negj.real', 'posj.real', 'ninf.real', 'ninfj.real', 'pinfj.real')
36243622

36253623
elif name == 'exp':
36263624
regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag')
@@ -3640,9 +3638,10 @@ def regions_with_inaccuracies_keep(*to_keep):
36403638
'ninf.imag', 'pinf.imag', 'ninfj.real', 'pinfj.real')
36413639

36423640
elif name == 'tan':
3641+
# TODO(pearu): eliminate this if-block when openxla/xla#10525 lands
36433642
regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', 'negj.imag', 'posj.imag',
36443643
'ninfj.imag', 'pinfj.imag', 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag',
3645-
'ninf.imag', 'pinf.imag')
3644+
'ninf.imag', 'pinf.imag', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real')
36463645

36473646
elif name == 'sinh':
36483647
if is_cuda:
@@ -3695,14 +3694,15 @@ def regions_with_inaccuracies_keep(*to_keep):
36953694
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
36963695

36973696
elif name == 'arctanh':
3698-
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
3699-
# TODO(pearu): after landing openxla/xla#10503, switch to
3700-
# regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
3697+
if xla_extension_version < 254:
3698+
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
3699+
else:
3700+
regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
37013701

37023702
elif name in {'cos', 'sin'}:
37033703
regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')
37043704

3705-
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1'}:
3705+
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p'}:
37063706
regions_with_inaccuracies.clear()
37073707
else:
37083708
assert 0 # unreachable

0 commit comments

Comments
 (0)