Skip to content

Commit 41fa67c

Browse files
Jake VanderPlasjax authors
authored andcommitted
Finalize deprecation of zero-dimensional inputs to jnp.nonzero
PiperOrigin-RevId: 626299531
1 parent 837f0bb commit 41fa67c

File tree

3 files changed

+21
-38
lines changed

3 files changed

+21
-38
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ Remember to align the itemized text with the first line of an item within a list
4646
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
4747
passing complex-valued inputs to it. This will raise an error when the
4848
deprecation is completed.
49+
* Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
50+
related functions now raise an error, following a similar change in NumPy.
4951

5052
## jaxlib 0.4.27
5153

jax/_src/numpy/lax_numpy.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,10 +1454,8 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
14541454
arr = asarray(a)
14551455
del a
14561456
if ndim(arr) == 0:
1457-
# Added 2023 Dec 6
1458-
warnings.warn("Calling nonzero on 0d arrays is deprecated. Use `atleast_1d(arr).nonzero()",
1459-
DeprecationWarning, stacklevel=2)
1460-
arr = atleast_1d(arr)
1457+
raise ValueError("Calling nonzero on 0d arrays is not allowed. "
1458+
"Use jnp.atleast_1d(scalar).nonzero() instead.")
14611459
mask = arr if arr.dtype == bool else (arr != 0)
14621460
calculated_size = mask.sum() if size is None else size
14631461
calculated_size = core.concrete_dim_or_error(calculated_size,

tests/lax_numpy_test.py

Lines changed: 17 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -323,17 +323,15 @@ def testCountNonzero(self, shape, dtype, axis):
323323
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
324324
self._CompileAndCheck(jnp_fun, args_maker)
325325

326-
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
326+
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
327327
def testNonzero(self, shape, dtype):
328328
rng = jtu.rand_some_zero(self.rng())
329329
args_maker = lambda: [rng(shape, dtype)]
330-
with jtu.ignore_warning(category=DeprecationWarning,
331-
message="Calling nonzero on 0d arrays.*"):
332-
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
330+
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
333331

334332
@jtu.sample_product(
335333
[dict(shape=shape, fill_value=fill_value)
336-
for shape in nonempty_array_shapes
334+
for shape in nonempty_nonscalar_array_shapes
337335
for fill_value in [None, -1, shape or (1,)]
338336
],
339337
dtype=all_dtypes,
@@ -351,17 +349,13 @@ def np_fun(x):
351349
return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
352350
for fval, arg in safe_zip(fillvals, result))
353351
jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value)
354-
with jtu.ignore_warning(category=DeprecationWarning,
355-
message="Calling nonzero on 0d arrays.*"):
356-
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
357-
self._CompileAndCheck(jnp_fun, args_maker)
352+
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
353+
self._CompileAndCheck(jnp_fun, args_maker)
358354

359-
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
355+
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
360356
def testFlatNonzero(self, shape, dtype):
361357
rng = jtu.rand_some_zero(self.rng())
362-
np_fun = jtu.ignore_warning(
363-
category=DeprecationWarning,
364-
message="Calling nonzero on 0d arrays.*")(np.flatnonzero)
358+
np_fun = np.flatnonzero
365359
jnp_fun = jnp.flatnonzero
366360
args_maker = lambda: [rng(shape, dtype)]
367361
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
@@ -371,15 +365,14 @@ def testFlatNonzero(self, shape, dtype):
371365
self._CompileAndCheck(jnp_fun, args_maker)
372366

373367
@jtu.sample_product(
374-
shape=nonempty_array_shapes,
368+
shape=nonempty_nonscalar_array_shapes,
375369
dtype=all_dtypes,
376370
fill_value=[None, -1, 10, (-1,), (10,)],
377371
size=[1, 5, 10],
378372
)
379373
def testFlatNonzeroSize(self, shape, dtype, size, fill_value):
380374
rng = jtu.rand_some_zero(self.rng())
381375
args_maker = lambda: [rng(shape, dtype)]
382-
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
383376
def np_fun(x):
384377
result = np.flatnonzero(x)
385378
if size <= len(result):
@@ -391,24 +384,20 @@ def np_fun(x):
391384
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
392385
self._CompileAndCheck(jnp_fun, args_maker)
393386

394-
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
387+
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
395388
def testArgWhere(self, shape, dtype):
396389
rng = jtu.rand_some_zero(self.rng())
397390
args_maker = lambda: [rng(shape, dtype)]
398-
with jtu.ignore_warning(category=DeprecationWarning,
399-
message="Calling nonzero on 0d arrays.*"):
400-
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
391+
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
401392

402393
# JIT compilation requires specifying a size statically. Full test of this
403394
# behavior is in testNonzeroSize().
404395
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
405-
with jtu.ignore_warning(category=DeprecationWarning,
406-
message="Calling nonzero on 0d arrays.*"):
407-
self._CompileAndCheck(jnp_fun, args_maker)
396+
self._CompileAndCheck(jnp_fun, args_maker)
408397

409398
@jtu.sample_product(
410399
[dict(shape=shape, fill_value=fill_value)
411-
for shape in nonempty_array_shapes
400+
for shape in nonempty_nonscalar_array_shapes
412401
for fill_value in [None, -1, shape or (1,)]
413402
],
414403
dtype=all_dtypes,
@@ -427,10 +416,8 @@ def np_fun(x):
427416
for fval, arg in safe_zip(fillvals, result.T)]).T
428417
jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value)
429418

430-
with jtu.ignore_warning(category=DeprecationWarning,
431-
message="Calling nonzero on 0d arrays.*"):
432-
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
433-
self._CompileAndCheck(jnp_fun, args_maker)
419+
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
420+
self._CompileAndCheck(jnp_fun, args_maker)
434421

435422
@jtu.sample_product(
436423
[dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
@@ -4490,24 +4477,20 @@ def args_maker(): return []
44904477
self._CompileAndCheck(jnp_fun, args_maker)
44914478

44924479
@jtu.sample_product(
4493-
shape=all_shapes,
4480+
shape=nonzerodim_shapes,
44944481
dtype=all_dtypes,
44954482
)
44964483
def testWhereOneArgument(self, shape, dtype):
44974484
rng = jtu.rand_some_zero(self.rng())
44984485
args_maker = lambda: [rng(shape, dtype)]
44994486

4500-
with jtu.ignore_warning(category=DeprecationWarning,
4501-
message="Calling nonzero on 0d arrays.*"):
4502-
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
4487+
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
45034488

45044489
# JIT compilation requires specifying a size statically. Full test of
45054490
# this behavior is in testNonzeroSize().
45064491
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2)
45074492

4508-
with jtu.ignore_warning(category=DeprecationWarning,
4509-
message="Calling nonzero on 0d arrays.*"):
4510-
self._CompileAndCheck(jnp_fun, args_maker)
4493+
self._CompileAndCheck(jnp_fun, args_maker)
45114494

45124495
@jtu.sample_product(
45134496
shapes=filter(_shapes_are_broadcast_compatible,

0 commit comments

Comments
 (0)