Skip to content

Commit 6b582f5

Browse files
author
jax authors
committed
Merge pull request #20552 from jakevdp:geomspace-complex
PiperOrigin-RevId: 621348396
2 parents 88dd29a + fd7c85b commit 6b582f5

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ Remember to align the itemized text with the first line of an item within a list
1212
* Added {func}`jax.numpy.trapezoid`, following the addition of this function in
1313
NumPy 2.0.
1414

15+
* Changes
16+
* Complex-valued {func}`jax.numpy.geomspace` now chooses the logarithmic spiral
17+
branch consistent with that of NumPy 2.0.
18+
1519
* Deprecations & Removals
1620
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
1721
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.

jax/_src/numpy/lax_numpy.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,13 +2685,11 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
26852685
util.check_arraylike("geomspace", start, stop)
26862686
start = asarray(start, dtype=computation_dtype)
26872687
stop = asarray(stop, dtype=computation_dtype)
2688-
# follow the numpy geomspace convention for negative and complex endpoints
2689-
signflip = 1 - (1 - ufuncs.sign(ufuncs.real(start))) * (1 - ufuncs.sign(ufuncs.real(stop))) // 2
2690-
signflip = signflip.astype(computation_dtype)
2691-
res = signflip * logspace(ufuncs.log10(signflip * start),
2692-
ufuncs.log10(signflip * stop), num,
2693-
endpoint=endpoint, base=10.0,
2694-
dtype=computation_dtype, axis=0)
2688+
2689+
sign = ufuncs.sign(start)
2690+
res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign),
2691+
num, endpoint=endpoint, base=10.0,
2692+
dtype=computation_dtype, axis=0)
26952693
if axis != 0:
26962694
res = moveaxis(res, 0, axis)
26972695
return lax.convert_element_type(res, dtype)

tests/lax_numpy_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5068,10 +5068,6 @@ def np_op(start, stop):
50685068
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
50695069
def testGeomspace(self, start_shape, stop_shape, num,
50705070
endpoint, dtype, axis):
5071-
# TODO(jakevdp): remove once https://github.com/numpy/numpy/issues/26195 is fixed.
5072-
if jtu.numpy_version() >= (2, 0, 0) and dtypes.issubdtype(dtype, jnp.complexfloating):
5073-
self.skipTest("NumPy 2.0.0rc1 geomspace is broken for complex dtypes.")
5074-
50755071
rng = jtu.rand_default(self.rng())
50765072
# relax default tolerances slightly
50775073
tol = {dtypes.bfloat16: 2e-2, np.float16: 4e-3, np.float32: 2e-3,
@@ -5101,8 +5097,11 @@ def np_op(start, stop):
51015097
start, stop, num, endpoint=endpoint,
51025098
dtype=dtype if dtype != jnp.bfloat16 else np.float32,
51035099
axis=axis).astype(dtype)
5104-
self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
5105-
check_dtypes=False, tol=tol)
5100+
5101+
# JAX follows NumPy 2.0 semantics for complex geomspace.
5102+
if not (jtu.numpy_version() < (2, 0, 0) and dtypes.issubdtype(dtype, jnp.complexfloating)):
5103+
self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
5104+
check_dtypes=False, tol=tol)
51065105
if dtype in (inexact_dtypes + [None,]):
51075106
self._CompileAndCheck(jnp_op, args_maker,
51085107
check_dtypes=False, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)