Skip to content

Commit c7517b8

Browse files
author
jax authors
committed
Merge pull request #20825 from jakevdp:gammasgn
PiperOrigin-RevId: 626347182
2 parents 41fa67c + 568db10 commit c7517b8

File tree

4 files changed

+23
-3
lines changed

4 files changed

+23
-3
lines changed

docs/jax.scipy.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ jax.scipy.special
146146
gammainc
147147
gammaincc
148148
gammaln
149+
gammasgn
149150
hyp1f1
150151
i0
151152
i0e

jax/_src/scipy/special.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,22 @@ def gammaln(x: ArrayLike) -> Array:
4444
return lax.lgamma(x)
4545

4646

47-
def _gamma_sign(x: Array) -> Array:
47+
def gammasgn(x: ArrayLike) -> Array:
48+
"""Sign of the gamma function.
49+
50+
JAX implementation of :func:`scipy.special.gammasgn`.
51+
52+
Args:
53+
x: arraylike, real valued.
54+
55+
Returns:
56+
array containing 1.0 where gamma(x) is positive, and -1.0 where
57+
gamma(x) is negative.
58+
59+
See Also:
60+
:func:`jax.scipy.special.gamma`
61+
"""
62+
x, = promote_args_inexact("gammasgn", x)
4863
floor_x = lax.floor(x)
4964
return jnp.where((x > 0) | (x == floor_x) | (floor_x % 2 == 0), 1.0, -1.0)
5065

@@ -53,7 +68,7 @@ def _gamma_sign(x: Array) -> Array:
5368
The JAX version only accepts real-valued inputs.""")
5469
def gamma(x: ArrayLike) -> Array:
5570
x, = promote_args_inexact("gamma", x)
56-
return _gamma_sign(x) * lax.exp(lax.lgamma(x))
71+
return gammasgn(x) * lax.exp(lax.lgamma(x))
5772

5873
betaln = implements(
5974
osp_special.betaln,
@@ -73,7 +88,7 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array:
7388
@implements(osp_special.beta, module='scipy.special')
7489
def beta(x: ArrayLike, y: ArrayLike) -> Array:
7590
x, y = promote_args_inexact("beta", x, y)
76-
sign = _gamma_sign(x) * _gamma_sign(y) * _gamma_sign(x + y)
91+
sign = gammasgn(x) * gammasgn(y) * gammasgn(x + y)
7792
return sign * lax.exp(betaln(x, y))
7893

7994

jax/scipy/special.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
gammainc as gammainc,
3535
gammaincc as gammaincc,
3636
gammaln as gammaln,
37+
gammasgn as gammasgn,
3738
gamma as gamma,
3839
i0 as i0,
3940
i0e as i0e,

tests/lax_scipy_special_functions_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
7272
op_record(
7373
"gammaincc", 2, float_dtypes, jtu.rand_positive, True
7474
),
75+
op_record(
76+
"gammasgn", 1, float_dtypes, jtu.rand_default, True
77+
),
7578
op_record(
7679
"erf", 1, float_dtypes, jtu.rand_small_positive, True
7780
),

0 commit comments

Comments
 (0)