@@ -323,17 +323,15 @@ def testCountNonzero(self, shape, dtype, axis):
323
323
self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker , check_dtypes = False )
324
324
self ._CompileAndCheck (jnp_fun , args_maker )
325
325
326
- @jtu .sample_product (shape = all_shapes , dtype = all_dtypes )
326
+ @jtu .sample_product (shape = nonzerodim_shapes , dtype = all_dtypes )
327
327
def testNonzero (self , shape , dtype ):
328
328
rng = jtu .rand_some_zero (self .rng ())
329
329
args_maker = lambda : [rng (shape , dtype )]
330
- with jtu .ignore_warning (category = DeprecationWarning ,
331
- message = "Calling nonzero on 0d arrays.*" ):
332
- self ._CheckAgainstNumpy (np .nonzero , jnp .nonzero , args_maker , check_dtypes = False )
330
+ self ._CheckAgainstNumpy (np .nonzero , jnp .nonzero , args_maker , check_dtypes = False )
333
331
334
332
@jtu .sample_product (
335
333
[dict (shape = shape , fill_value = fill_value )
336
- for shape in nonempty_array_shapes
334
+ for shape in nonempty_nonscalar_array_shapes
337
335
for fill_value in [None , - 1 , shape or (1 ,)]
338
336
],
339
337
dtype = all_dtypes ,
@@ -351,17 +349,13 @@ def np_fun(x):
351
349
return tuple (np .concatenate ([arg , np .full (size - len (arg ), fval , arg .dtype )])
352
350
for fval , arg in safe_zip (fillvals , result ))
353
351
jnp_fun = lambda x : jnp .nonzero (x , size = size , fill_value = fill_value )
354
- with jtu .ignore_warning (category = DeprecationWarning ,
355
- message = "Calling nonzero on 0d arrays.*" ):
356
- self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker , check_dtypes = False )
357
- self ._CompileAndCheck (jnp_fun , args_maker )
352
+ self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker , check_dtypes = False )
353
+ self ._CompileAndCheck (jnp_fun , args_maker )
358
354
359
- @jtu .sample_product (shape = all_shapes , dtype = all_dtypes )
355
+ @jtu .sample_product (shape = nonzerodim_shapes , dtype = all_dtypes )
360
356
def testFlatNonzero (self , shape , dtype ):
361
357
rng = jtu .rand_some_zero (self .rng ())
362
- np_fun = jtu .ignore_warning (
363
- category = DeprecationWarning ,
364
- message = "Calling nonzero on 0d arrays.*" )(np .flatnonzero )
358
+ np_fun = np .flatnonzero
365
359
jnp_fun = jnp .flatnonzero
366
360
args_maker = lambda : [rng (shape , dtype )]
367
361
self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker , check_dtypes = False )
@@ -371,15 +365,14 @@ def testFlatNonzero(self, shape, dtype):
371
365
self ._CompileAndCheck (jnp_fun , args_maker )
372
366
373
367
@jtu .sample_product (
374
- shape = nonempty_array_shapes ,
368
+ shape = nonempty_nonscalar_array_shapes ,
375
369
dtype = all_dtypes ,
376
370
fill_value = [None , - 1 , 10 , (- 1 ,), (10 ,)],
377
371
size = [1 , 5 , 10 ],
378
372
)
379
373
def testFlatNonzeroSize (self , shape , dtype , size , fill_value ):
380
374
rng = jtu .rand_some_zero (self .rng ())
381
375
args_maker = lambda : [rng (shape , dtype )]
382
- @jtu .ignore_warning (category = DeprecationWarning , message = "Calling nonzero on 0d arrays.*" )
383
376
def np_fun (x ):
384
377
result = np .flatnonzero (x )
385
378
if size <= len (result ):
@@ -391,24 +384,20 @@ def np_fun(x):
391
384
self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker , check_dtypes = False )
392
385
self ._CompileAndCheck (jnp_fun , args_maker )
393
386
394
- @jtu .sample_product (shape = all_shapes , dtype = all_dtypes )
387
+ @jtu .sample_product (shape = nonzerodim_shapes , dtype = all_dtypes )
395
388
def testArgWhere (self , shape , dtype ):
396
389
rng = jtu .rand_some_zero (self .rng ())
397
390
args_maker = lambda : [rng (shape , dtype )]
398
- with jtu .ignore_warning (category = DeprecationWarning ,
399
- message = "Calling nonzero on 0d arrays.*" ):
400
- self ._CheckAgainstNumpy (np .argwhere , jnp .argwhere , args_maker , check_dtypes = False )
391
+ self ._CheckAgainstNumpy (np .argwhere , jnp .argwhere , args_maker , check_dtypes = False )
401
392
402
393
# JIT compilation requires specifying a size statically. Full test of this
403
394
# behavior is in testNonzeroSize().
404
395
jnp_fun = lambda x : jnp .argwhere (x , size = np .size (x ) // 2 )
405
- with jtu .ignore_warning (category = DeprecationWarning ,
406
- message = "Calling nonzero on 0d arrays.*" ):
407
- self ._CompileAndCheck (jnp_fun , args_maker )
396
+ self ._CompileAndCheck (jnp_fun , args_maker )
408
397
409
398
@jtu .sample_product (
410
399
[dict (shape = shape , fill_value = fill_value )
411
- for shape in nonempty_array_shapes
400
+ for shape in nonempty_nonscalar_array_shapes
412
401
for fill_value in [None , - 1 , shape or (1 ,)]
413
402
],
414
403
dtype = all_dtypes ,
@@ -427,10 +416,8 @@ def np_fun(x):
427
416
for fval , arg in safe_zip (fillvals , result .T )]).T
428
417
jnp_fun = lambda x : jnp .argwhere (x , size = size , fill_value = fill_value )
429
418
430
- with jtu .ignore_warning (category = DeprecationWarning ,
431
- message = "Calling nonzero on 0d arrays.*" ):
432
- self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker , check_dtypes = False )
433
- self ._CompileAndCheck (jnp_fun , args_maker )
419
+ self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker , check_dtypes = False )
420
+ self ._CompileAndCheck (jnp_fun , args_maker )
434
421
435
422
@jtu .sample_product (
436
423
[dict (np_op = getattr (np , rec .name ), jnp_op = getattr (jnp , rec .name ),
@@ -4490,24 +4477,20 @@ def args_maker(): return []
4490
4477
self ._CompileAndCheck (jnp_fun , args_maker )
4491
4478
4492
4479
@jtu .sample_product (
4493
- shape = all_shapes ,
4480
+ shape = nonzerodim_shapes ,
4494
4481
dtype = all_dtypes ,
4495
4482
)
4496
4483
def testWhereOneArgument (self , shape , dtype ):
4497
4484
rng = jtu .rand_some_zero (self .rng ())
4498
4485
args_maker = lambda : [rng (shape , dtype )]
4499
4486
4500
- with jtu .ignore_warning (category = DeprecationWarning ,
4501
- message = "Calling nonzero on 0d arrays.*" ):
4502
- self ._CheckAgainstNumpy (np .where , jnp .where , args_maker , check_dtypes = False )
4487
+ self ._CheckAgainstNumpy (np .where , jnp .where , args_maker , check_dtypes = False )
4503
4488
4504
4489
# JIT compilation requires specifying a size statically. Full test of
4505
4490
# this behavior is in testNonzeroSize().
4506
4491
jnp_fun = lambda x : jnp .where (x , size = np .size (x ) // 2 )
4507
4492
4508
- with jtu .ignore_warning (category = DeprecationWarning ,
4509
- message = "Calling nonzero on 0d arrays.*" ):
4510
- self ._CompileAndCheck (jnp_fun , args_maker )
4493
+ self ._CompileAndCheck (jnp_fun , args_maker )
4511
4494
4512
4495
@jtu .sample_product (
4513
4496
shapes = filter (_shapes_are_broadcast_compatible ,
0 commit comments