Skip to content

Commit 564da79

Browse files
authored
TST: array types: wrap namespaces centrally (scipy#22320)
1 parent 4cc4d14 commit 564da79

26 files changed

+208
-316
lines changed

scipy/_lib/_array_api.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,6 @@ def _strict_check(actual, desired, xp, *,
252252
xp = _default_xp_ctxvar.get()
253253
except LookupError:
254254
xp = array_namespace(desired)
255-
else:
256-
# Wrap namespace if needed
257-
xp = array_namespace(xp.asarray(0))
258255

259256
if check_namespace:
260257
_assert_matching_namespace(actual, desired, xp)

scipy/_lib/tests/test__util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,8 @@ def f2(*args):
615615

616616
# Ensure arrays are at least 1d to follow sane type promotion rules.
617617
# This can be removed when minimum supported NumPy is 2.0
618-
if xp == np:
619-
cond, fillvalue, *arrays = np.atleast_1d(cond, fillvalue, *arrays)
618+
if is_numpy(xp):
619+
cond, fillvalue, *arrays = xp.atleast_1d(cond, fillvalue, *arrays)
620620

621621
ref1 = xp.where(cond, f(*arrays), fillvalue)
622622
ref2 = xp.where(cond, f(*arrays), f2(*arrays))
@@ -625,7 +625,7 @@ def f2(*args):
625625
# Python scalar. When it does, test can be run with array_api_strict, too.
626626
ref3 = xp.where(cond, f(*arrays), float_fillvalue)
627627

628-
if xp == np: # because we ensured arrays are at least 1d
628+
if is_numpy(xp): # because we ensured arrays are at least 1d
629629
ref1 = ref1.reshape(result_shape)
630630
ref2 = ref2.reshape(result_shape)
631631
ref3 = ref3.reshape(result_shape)

scipy/_lib/tests/test_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_strict_checks(self, xp, dtype, shape):
8484

8585
kwarg_names = ["check_namespace", "check_dtype", "check_shape", "check_0d"]
8686
options = dict(zip(kwarg_names, [True, False, False, False]))
87-
if xp == np:
87+
if is_numpy(xp):
8888
xp_assert_equal(x, y, **options)
8989
else:
9090
with pytest.raises(

scipy/cluster/tests/test_vq.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from scipy._lib import array_api_extra as xpx
1919
from scipy._lib._array_api import (
20-
SCIPY_ARRAY_API, array_namespace, xp_copy, xp_assert_close, xp_assert_equal
20+
SCIPY_ARRAY_API, xp_copy, xp_assert_close, xp_assert_equal
2121
)
2222

2323
skip_xp_backends = pytest.mark.skip_xp_backends
@@ -358,13 +358,12 @@ def test_krandinit(self, xp, krand_lock):
358358
datas = [xp.reshape(data, (200, 2)),
359359
xp.reshape(data, (20, 20))[:10, :]]
360360
k = int(1e6)
361-
xp_test = array_namespace(data)
362361
with krand_lock:
363362
for data in datas:
364363
rng = np.random.default_rng(1234)
365-
init = _krandinit(data, k, rng, xp_test)
366-
orig_cov = xpx.cov(data.T, xp=xp_test)
367-
init_cov = xpx.cov(init.T, xp=xp_test)
364+
init = _krandinit(data, k, rng, xp)
365+
orig_cov = xpx.cov(data.T, xp=xp)
366+
init_cov = xpx.cov(init.T, xp=xp)
368367
xp_assert_close(orig_cov, init_cov, atol=1.1e-2)
369368

370369
def test_kmeans2_empty(self, xp):

scipy/conftest.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from scipy._lib._fpumode import get_fpu_mode
1515
from scipy._lib._testutils import FPUModeChangeWarning
16-
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE
16+
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE, array_namespace
1717
from scipy._lib import _pep440
1818

1919
try:
@@ -227,17 +227,21 @@ def xp(request):
227227
# if any, and raise pytest.xfail() if the current xp is in the list.
228228
skip_or_xfail_xp_backends(request, "xfail")
229229

230+
xp = request.param
231+
# Potentially wrap namespace with array_api_compat
232+
xp = array_namespace(xp.empty(0))
233+
230234
if SCIPY_ARRAY_API:
231235
from scipy._lib._array_api import default_xp
232236

233237
# Throughout all calls to assert_almost_equal, assert_array_almost_equal, and
234238
# xp_assert_* functions, test that the array namespace is xp in both the
235239
# expected and actual arrays. This is to detect the case where both arrays are
236240
# erroneously just plain numpy while xp is something else.
237-
with default_xp(request.param):
238-
yield request.param
241+
with default_xp(xp):
242+
yield xp
239243
else:
240-
yield request.param
244+
yield xp
241245

242246

243247
skip_xp_invalid_arg = pytest.mark.skipif(SCIPY_ARRAY_API,

scipy/differentiate/tests/test_differentiate.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import scipy._lib._elementwise_iterative_method as eim
77
from scipy._lib._array_api_no_0d import xp_assert_close, xp_assert_equal, xp_assert_less
8-
from scipy._lib._array_api import is_numpy, is_torch, array_namespace
8+
from scipy._lib._array_api import is_numpy, is_torch, xp_size
99

1010
from scipy import stats, optimize, special
1111
from scipy.differentiate import derivative, jacobian, hessian
@@ -376,8 +376,7 @@ def test_special_cases(self, xp):
376376
# Test that integers are not passed to `f`
377377
# (otherwise this would overflow)
378378
def f(x):
379-
xp_test = array_namespace(x) # needs `isdtype`
380-
assert xp_test.isdtype(x.dtype, 'real floating')
379+
assert xp.isdtype(x.dtype, 'real floating')
381380
return x ** 99 - 1
382381

383382
if not is_torch(xp): # torch defaults to float32
@@ -582,9 +581,10 @@ def df1_1xy(x, y):
582581
return xp.sin(2*x) * y**2
583582

584583
res = jacobian(df1, z, initial_step=10)
585-
if is_numpy(xp):
586-
assert len(np.unique(res.nit)) == 4
587-
assert len(np.unique(res.nfev)) == 4
584+
# FIXME https://github.com/scipy/scipy/pull/22320#discussion_r1914898175
585+
if not is_torch(xp):
586+
assert xp_size(xp.unique_values(res.nit)) == 4
587+
assert xp_size(xp.unique_values(res.nfev)) == 4
588588

589589
res00 = jacobian(lambda x: df1_0xy(x, z[1]), z[0:1], initial_step=10)
590590
res01 = jacobian(lambda y: df1_0xy(z[0], y), z[1:2], initial_step=10)
@@ -594,7 +594,10 @@ def df1_1xy(x, y):
594594
for attr in ['success', 'status', 'df', 'nit', 'nfev']:
595595
ref_attr = xp.asarray([[getattr(res00, attr), getattr(res01, attr)],
596596
[getattr(res10, attr), getattr(res11, attr)]])
597-
ref[attr] = xp.squeeze(ref_attr)
597+
ref[attr] = xp.squeeze(
598+
ref_attr,
599+
axis=tuple(ax for ax, size in enumerate(ref_attr.shape) if size == 1)
600+
)
598601
rtol = 1.5e-5 if res[attr].dtype == xp.float32 else 1.5e-14
599602
xp_assert_close(res[attr], ref[attr], rtol=rtol)
600603

@@ -659,10 +662,9 @@ def test_float32(self, xp):
659662

660663
def test_nfev(self, xp):
661664
z = xp.asarray([0.5, 0.25])
662-
xp_test = array_namespace(z)
663665

664666
def f1(z):
665-
x, y = xp_test.broadcast_arrays(*z)
667+
x, y = xp.broadcast_arrays(*z)
666668
f1.nfev = f1.nfev + (math.prod(x.shape[2:]) if x.ndim > 2 else 1)
667669
return xp.sin(x) * y ** 3
668670
f1.nfev = 0

scipy/fft/tests/test_basic.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pytest import raises as assert_raises
99
import scipy.fft as fft
1010
from scipy._lib._array_api import (
11-
array_namespace, xp_size, xp_assert_close, xp_assert_equal
11+
is_numpy, xp_size, xp_assert_close, xp_assert_equal
1212
)
1313

1414
skip_xp_backends = pytest.mark.skip_xp_backends
@@ -228,10 +228,10 @@ def _check_axes(self, op, xp):
228228
dtype = get_expected_input_dtype(op, xp)
229229
x = xp.asarray(random((30, 20, 10)), dtype=dtype)
230230
axes = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
231-
xp_test = array_namespace(x)
231+
232232
for a in axes:
233-
op_tr = op(xp_test.permute_dims(x, axes=a))
234-
tr_op = xp_test.permute_dims(op(x, axes=a), axes=a)
233+
op_tr = op(xp.permute_dims(x, axes=a))
234+
tr_op = xp.permute_dims(op(x, axes=a), axes=a)
235235
xp_assert_close(op_tr, tr_op)
236236

237237
@pytest.mark.parametrize("op", [fft.fftn, fft.ifftn, fft.rfftn, fft.irfftn])
@@ -248,15 +248,15 @@ def test_axes_subset_with_shape_standard(self, op, xp):
248248
dtype = get_expected_input_dtype(op, xp)
249249
x = xp.asarray(random((16, 8, 4)), dtype=dtype)
250250
axes = [(0, 1, 2), (0, 2, 1), (1, 2, 0)]
251-
xp_test = array_namespace(x)
251+
252252
for a in axes:
253253
# different shape on the first two axes
254254
shape = tuple([2*x.shape[ax] if ax in a[:2] else x.shape[ax]
255255
for ax in range(x.ndim)])
256256
# transform only the first two axes
257-
op_tr = op(xp_test.permute_dims(x, axes=a),
257+
op_tr = op(xp.permute_dims(x, axes=a),
258258
s=shape[:2], axes=(0, 1))
259-
tr_op = xp_test.permute_dims(op(x, s=shape[:2], axes=a[:2]),
259+
tr_op = xp.permute_dims(op(x, s=shape[:2], axes=a[:2]),
260260
axes=a)
261261
xp_assert_close(op_tr, tr_op)
262262

@@ -268,21 +268,21 @@ def test_axes_subset_with_shape_non_standard(self, op, xp):
268268
dtype = get_expected_input_dtype(op, xp)
269269
x = xp.asarray(random((16, 8, 4)), dtype=dtype)
270270
axes = [(0, 1, 2), (0, 2, 1), (1, 2, 0)]
271-
xp_test = array_namespace(x)
271+
272272
for a in axes:
273273
# different shape on the first two axes
274274
shape = tuple([2*x.shape[ax] if ax in a[:2] else x.shape[ax]
275275
for ax in range(x.ndim)])
276276
# transform only the first two axes
277-
op_tr = op(xp_test.permute_dims(x, axes=a), s=shape[:2], axes=(0, 1))
278-
tr_op = xp_test.permute_dims(op(x, s=shape[:2], axes=a[:2]), axes=a)
277+
op_tr = op(xp.permute_dims(x, axes=a), s=shape[:2], axes=(0, 1))
278+
tr_op = xp.permute_dims(op(x, s=shape[:2], axes=a[:2]), axes=a)
279279
xp_assert_close(op_tr, tr_op)
280280

281281
def test_all_1d_norm_preserving(self, xp):
282282
# verify that round-trip transforms are norm-preserving
283283
x = xp.asarray(random(30), dtype=xp.float64)
284-
xp_test = array_namespace(x)
285-
x_norm = xp_test.linalg.vector_norm(x)
284+
285+
x_norm = xp.linalg.vector_norm(x)
286286
n = xp_size(x) * 2
287287
func_pairs = [(fft.rfft, fft.irfft),
288288
# hfft: order so the first function takes x.size samples
@@ -294,12 +294,12 @@ def test_all_1d_norm_preserving(self, xp):
294294
for forw, back in func_pairs:
295295
if forw == fft.fft:
296296
x = xp.asarray(x, dtype=xp.complex128)
297-
x_norm = xp_test.linalg.vector_norm(x)
297+
x_norm = xp.linalg.vector_norm(x)
298298
for n in [xp_size(x), 2*xp_size(x)]:
299299
for norm in ['backward', 'ortho', 'forward']:
300300
tmp = forw(x, n=n, norm=norm)
301301
tmp = back(tmp, n=n, norm=norm)
302-
xp_assert_close(xp_test.linalg.vector_norm(tmp), x_norm)
302+
xp_assert_close(xp.linalg.vector_norm(tmp), x_norm)
303303

304304
@skip_xp_backends(np_only=True)
305305
@pytest.mark.parametrize("dtype", [np.float16, np.longdouble])
@@ -481,13 +481,17 @@ def test_non_standard_params(func, xp):
481481
else:
482482
dtype = xp.complex128
483483

484-
if xp.__name__ != 'numpy':
485-
x = xp.asarray([1, 2, 3], dtype=dtype)
486-
# func(x) should not raise an exception
487-
func(x)
484+
x = xp.asarray([1, 2, 3], dtype=dtype)
485+
# func(x) should not raise an exception
486+
func(x)
487+
488+
if is_numpy(xp):
489+
func(x, workers=2)
490+
else:
488491
assert_raises(ValueError, func, x, workers=2)
489-
# `plan` param is not tested since SciPy does not use it currently
490-
# but should be tested if it comes into use
492+
493+
# `plan` param is not tested since SciPy does not use it currently
494+
# but should be tested if it comes into use
491495

492496

493497
@pytest.mark.parametrize("dtype", ['float32', 'float64'])

scipy/fft/tests/test_fftlog.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from scipy.fft._fftlog import fht, ifht, fhtoffset
88
from scipy.special import poch
99

10-
from scipy._lib._array_api import xp_assert_close, xp_assert_less, array_namespace
10+
from scipy._lib._array_api import xp_assert_close, xp_assert_less
1111

1212
skip_xp_backends = pytest.mark.skip_xp_backends
1313

@@ -186,13 +186,12 @@ def test_array_like(xp, op):
186186
@pytest.mark.parametrize('n', [128, 129])
187187
def test_gh_21661(xp, n):
188188
one = xp.asarray(1.0)
189-
xp_test = array_namespace(one)
190189
mu = 0.0
191190
r = np.logspace(-7, 1, n)
192191
dln = math.log(r[1] / r[0])
193192
offset = fhtoffset(dln, initial=-6 * np.log(10), mu=mu)
194193
r = xp.asarray(r, dtype=one.dtype)
195-
k = math.exp(offset) / xp_test.flip(r, axis=-1)
194+
k = math.exp(offset) / xp.flip(r, axis=-1)
196195

197196
def f(x, mu):
198197
return x**(mu + 1)*xp.exp(-x**2/2)

scipy/fft/tests/test_helper.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
import sys
1313
from scipy._lib._array_api import (
14-
xp_assert_close, get_xp_devices, xp_device, array_namespace
14+
xp_assert_close, get_xp_devices, xp_device
1515
)
1616
from scipy import fft
1717

@@ -532,11 +532,10 @@ def test_definition(self, xp):
532532
xp_assert_close(y, x2, check_dtype=False)
533533

534534
def test_device(self, xp):
535-
xp_test = array_namespace(xp.empty(0))
536535
devices = get_xp_devices(xp)
537536
for d in devices:
538537
y = fft.fftfreq(9, xp=xp, device=d)
539-
x = xp_test.empty(0, device=d)
538+
x = xp.empty(0, device=d)
540539
assert xp_device(y) == xp_device(x)
541540

542541

@@ -565,9 +564,8 @@ def test_definition(self, xp):
565564
xp_assert_close(y, x2, check_dtype=False)
566565

567566
def test_device(self, xp):
568-
xp_test = array_namespace(xp.empty(0))
569567
devices = get_xp_devices(xp)
570568
for d in devices:
571569
y = fft.rfftfreq(9, xp=xp, device=d)
572-
x = xp_test.empty(0, device=d)
570+
x = xp.empty(0, device=d)
573571
assert xp_device(y) == xp_device(x)

scipy/integrate/tests/test_cubature.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,10 @@ def genz_malik_1980_f_2_exact(a, b, alphas, betas, xp):
115115
a = xp.reshape(a, (*([1]*(len(alphas.shape) - 1)), ndim))
116116
b = xp.reshape(b, (*([1]*(len(alphas.shape) - 1)), ndim))
117117

118-
# `xp` is the unwrapped namespace, so `.atan` won't work for `xp = np` and np<2.
119-
xp_test = array_namespace(a)
120-
121118
return (
122119
(-1)**ndim * 1/xp.prod(alphas, axis=-1)
123120
* xp.prod(
124-
xp_test.atan((a - betas)/alphas) - xp_test.atan((b - betas)/alphas),
121+
xp.atan((a - betas)/alphas) - xp.atan((b - betas)/alphas),
125122
axis=-1,
126123
)
127124
)
@@ -1344,19 +1341,18 @@ def test_infinite_limits_maintains_points(self, a, b, points, xp):
13441341
transformation.
13451342
"""
13461343

1347-
xp_compat = array_namespace(xp.empty(0))
13481344
points = [xp.asarray(p, dtype=xp.float64) for p in points]
13491345

13501346
f_transformed = _InfiniteLimitsTransform(
13511347
# Bind `points` and `xp` argument in f
1352-
lambda x: f_with_problematic_points(x, points, xp_compat),
1353-
xp.asarray(a, dtype=xp_compat.float64),
1354-
xp.asarray(b, dtype=xp_compat.float64),
1355-
xp=xp_compat,
1348+
lambda x: f_with_problematic_points(x, points, xp),
1349+
xp.asarray(a, dtype=xp.float64),
1350+
xp.asarray(b, dtype=xp.float64),
1351+
xp=xp,
13561352
)
13571353

13581354
for point in points:
1359-
transformed_point = f_transformed.inv(xp_compat.reshape(point, (1, -1)))
1355+
transformed_point = f_transformed.inv(xp.reshape(point, (1, -1)))
13601356

13611357
with pytest.raises(Exception, match="called with a problematic point"):
13621358
f_transformed(transformed_point)

0 commit comments

Comments
 (0)