Skip to content

Commit cf31403

Browse files
committed
Merge branch 'main' into dask-new
[skip cirrus] [skip circle]
2 parents b43387c + 564da79 commit cf31403

28 files changed

+232
-335
lines changed

scipy/_lib/_array_api.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,6 @@ def _strict_check(actual, desired, xp, *,
253253
xp = _default_xp_ctxvar.get()
254254
except LookupError:
255255
xp = array_namespace(desired)
256-
else:
257-
# Wrap namespace if needed
258-
xp = array_namespace(xp.asarray(0))
259256

260257
if check_namespace:
261258
_assert_matching_namespace(actual, desired, xp)

scipy/_lib/tests/test__util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,8 @@ def f2(*args):
615615

616616
# Ensure arrays are at least 1d to follow sane type promotion rules.
617617
# This can be removed when minimum supported NumPy is 2.0
618-
if xp == np:
619-
cond, fillvalue, *arrays = np.atleast_1d(cond, fillvalue, *arrays)
618+
if is_numpy(xp):
619+
cond, fillvalue, *arrays = xp.atleast_1d(cond, fillvalue, *arrays)
620620

621621
ref1 = xp.where(cond, f(*arrays), fillvalue)
622622
ref2 = xp.where(cond, f(*arrays), f2(*arrays))
@@ -625,7 +625,7 @@ def f2(*args):
625625
# Python scalar. When it does, test can be run with array_api_strict, too.
626626
ref3 = xp.where(cond, f(*arrays), float_fillvalue)
627627

628-
if xp == np: # because we ensured arrays are at least 1d
628+
if is_numpy(xp): # because we ensured arrays are at least 1d
629629
ref1 = ref1.reshape(result_shape)
630630
ref2 = ref2.reshape(result_shape)
631631
ref3 = ref3.reshape(result_shape)

scipy/_lib/tests/test_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_strict_checks(self, xp, dtype, shape):
8484

8585
kwarg_names = ["check_namespace", "check_dtype", "check_shape", "check_0d"]
8686
options = dict(zip(kwarg_names, [True, False, False, False]))
87-
if xp == np:
87+
if is_numpy(xp):
8888
xp_assert_equal(x, y, **options)
8989
else:
9090
with pytest.raises(

scipy/cluster/tests/test_vq.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from scipy._lib import array_api_extra as xpx
1919
from scipy._lib._array_api import (
20-
SCIPY_ARRAY_API, array_namespace, xp_copy, xp_assert_close, xp_assert_equal
20+
SCIPY_ARRAY_API, xp_copy, xp_assert_close, xp_assert_equal
2121
)
2222

2323
skip_xp_backends = pytest.mark.skip_xp_backends
@@ -359,13 +359,12 @@ def test_krandinit(self, xp, krand_lock):
359359
datas = [xp.reshape(data, (200, 2)),
360360
xp.reshape(data, (20, 20))[:10, :]]
361361
k = int(1e6)
362-
xp_test = array_namespace(data)
363362
with krand_lock:
364363
for data in datas:
365364
rng = np.random.default_rng(1234)
366-
init = _krandinit(data, k, rng, xp_test)
367-
orig_cov = xpx.cov(data.T, xp=xp_test)
368-
init_cov = xpx.cov(init.T, xp=xp_test)
365+
init = _krandinit(data, k, rng, xp)
366+
orig_cov = xpx.cov(data.T, xp=xp)
367+
init_cov = xpx.cov(init.T, xp=xp)
369368
xp_assert_close(orig_cov, init_cov, atol=1.1e-2)
370369

371370
def test_kmeans2_empty(self, xp):

scipy/conftest.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
from scipy._lib._fpumode import get_fpu_mode
1515
from scipy._lib._testutils import FPUModeChangeWarning
16-
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE, xp_device
16+
from scipy._lib._array_api import (
17+
SCIPY_ARRAY_API, SCIPY_DEVICE, array_namespace, xp_device
18+
)
1719
from scipy._lib import _pep440
1820

1921
try:
@@ -233,17 +235,21 @@ def xp(request):
233235
# if any, and raise pytest.xfail() if the current xp is in the list.
234236
skip_or_xfail_xp_backends(request, "xfail")
235237

238+
xp = request.param
239+
# Potentially wrap namespace with array_api_compat
240+
xp = array_namespace(xp.empty(0))
241+
236242
if SCIPY_ARRAY_API:
237243
from scipy._lib._array_api import default_xp
238244

239245
# Throughout all calls to assert_almost_equal, assert_array_almost_equal, and
240246
# xp_assert_* functions, test that the array namespace is xp in both the
241247
# expected and actual arrays. This is to detect the case where both arrays are
242248
# erroneously just plain numpy while xp is something else.
243-
with default_xp(request.param):
244-
yield request.param
249+
with default_xp(xp):
250+
yield xp
245251
else:
246-
yield request.param
252+
yield xp
247253

248254

249255
skip_xp_invalid_arg = pytest.mark.skipif(SCIPY_ARRAY_API,

scipy/differentiate/tests/test_differentiate.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import scipy._lib._elementwise_iterative_method as eim
77
from scipy._lib._array_api_no_0d import xp_assert_close, xp_assert_equal, xp_assert_less
8-
from scipy._lib._array_api import is_numpy, is_torch, array_namespace
8+
from scipy._lib._array_api import is_numpy, is_torch, xp_size
99

1010
from scipy import stats, optimize, special
1111
from scipy.differentiate import derivative, jacobian, hessian
@@ -377,8 +377,7 @@ def test_special_cases(self, xp):
377377
# Test that integers are not passed to `f`
378378
# (otherwise this would overflow)
379379
def f(x):
380-
xp_test = array_namespace(x) # needs `isdtype`
381-
assert xp_test.isdtype(x.dtype, 'real floating')
380+
assert xp.isdtype(x.dtype, 'real floating')
382381
return x ** 99 - 1
383382

384383
if not is_torch(xp): # torch defaults to float32
@@ -584,9 +583,10 @@ def df1_1xy(x, y):
584583
return xp.sin(2*x) * y**2
585584

586585
res = jacobian(df1, z, initial_step=10)
587-
if is_numpy(xp):
588-
assert len(np.unique(res.nit)) == 4
589-
assert len(np.unique(res.nfev)) == 4
586+
# FIXME https://github.com/scipy/scipy/pull/22320#discussion_r1914898175
587+
if not is_torch(xp):
588+
assert xp_size(xp.unique_values(res.nit)) == 4
589+
assert xp_size(xp.unique_values(res.nfev)) == 4
590590

591591
res00 = jacobian(lambda x: df1_0xy(x, z[1]), z[0:1], initial_step=10)
592592
res01 = jacobian(lambda y: df1_0xy(z[0], y), z[1:2], initial_step=10)
@@ -596,7 +596,10 @@ def df1_1xy(x, y):
596596
for attr in ['success', 'status', 'df', 'nit', 'nfev']:
597597
ref_attr = xp.asarray([[getattr(res00, attr), getattr(res01, attr)],
598598
[getattr(res10, attr), getattr(res11, attr)]])
599-
ref[attr] = xp.squeeze(ref_attr)
599+
ref[attr] = xp.squeeze(
600+
ref_attr,
601+
axis=tuple(ax for ax, size in enumerate(ref_attr.shape) if size == 1)
602+
)
600603
rtol = 1.5e-5 if res[attr].dtype == xp.float32 else 1.5e-14
601604
xp_assert_close(res[attr], ref[attr], rtol=rtol)
602605

@@ -662,10 +665,9 @@ def test_float32(self, xp):
662665

663666
def test_nfev(self, xp):
664667
z = xp.asarray([0.5, 0.25])
665-
xp_test = array_namespace(z)
666668

667669
def f1(z):
668-
x, y = xp_test.broadcast_arrays(*z)
670+
x, y = xp.broadcast_arrays(*z)
669671
f1.nfev = f1.nfev + (math.prod(x.shape[2:]) if x.ndim > 2 else 1)
670672
return xp.sin(x) * y ** 3
671673
f1.nfev = 0

scipy/fft/tests/test_basic.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pytest import raises as assert_raises
99
import scipy.fft as fft
1010
from scipy._lib._array_api import (
11-
array_namespace, xp_size, xp_assert_close, xp_assert_equal
11+
is_numpy, xp_size, xp_assert_close, xp_assert_equal
1212
)
1313

1414
skip_xp_backends = pytest.mark.skip_xp_backends
@@ -228,10 +228,10 @@ def _check_axes(self, op, xp):
228228
dtype = get_expected_input_dtype(op, xp)
229229
x = xp.asarray(random((30, 20, 10)), dtype=dtype)
230230
axes = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
231-
xp_test = array_namespace(x)
231+
232232
for a in axes:
233-
op_tr = op(xp_test.permute_dims(x, axes=a))
234-
tr_op = xp_test.permute_dims(op(x, axes=a), axes=a)
233+
op_tr = op(xp.permute_dims(x, axes=a))
234+
tr_op = xp.permute_dims(op(x, axes=a), axes=a)
235235
xp_assert_close(op_tr, tr_op)
236236

237237
@pytest.mark.parametrize("op", [fft.fftn, fft.ifftn, fft.rfftn, fft.irfftn])
@@ -248,15 +248,15 @@ def test_axes_subset_with_shape_standard(self, op, xp):
248248
dtype = get_expected_input_dtype(op, xp)
249249
x = xp.asarray(random((16, 8, 4)), dtype=dtype)
250250
axes = [(0, 1, 2), (0, 2, 1), (1, 2, 0)]
251-
xp_test = array_namespace(x)
251+
252252
for a in axes:
253253
# different shape on the first two axes
254254
shape = tuple([2*x.shape[ax] if ax in a[:2] else x.shape[ax]
255255
for ax in range(x.ndim)])
256256
# transform only the first two axes
257-
op_tr = op(xp_test.permute_dims(x, axes=a),
257+
op_tr = op(xp.permute_dims(x, axes=a),
258258
s=shape[:2], axes=(0, 1))
259-
tr_op = xp_test.permute_dims(op(x, s=shape[:2], axes=a[:2]),
259+
tr_op = xp.permute_dims(op(x, s=shape[:2], axes=a[:2]),
260260
axes=a)
261261
xp_assert_close(op_tr, tr_op)
262262

@@ -268,21 +268,21 @@ def test_axes_subset_with_shape_non_standard(self, op, xp):
268268
dtype = get_expected_input_dtype(op, xp)
269269
x = xp.asarray(random((16, 8, 4)), dtype=dtype)
270270
axes = [(0, 1, 2), (0, 2, 1), (1, 2, 0)]
271-
xp_test = array_namespace(x)
271+
272272
for a in axes:
273273
# different shape on the first two axes
274274
shape = tuple([2*x.shape[ax] if ax in a[:2] else x.shape[ax]
275275
for ax in range(x.ndim)])
276276
# transform only the first two axes
277-
op_tr = op(xp_test.permute_dims(x, axes=a), s=shape[:2], axes=(0, 1))
278-
tr_op = xp_test.permute_dims(op(x, s=shape[:2], axes=a[:2]), axes=a)
277+
op_tr = op(xp.permute_dims(x, axes=a), s=shape[:2], axes=(0, 1))
278+
tr_op = xp.permute_dims(op(x, s=shape[:2], axes=a[:2]), axes=a)
279279
xp_assert_close(op_tr, tr_op)
280280

281281
def test_all_1d_norm_preserving(self, xp):
282282
# verify that round-trip transforms are norm-preserving
283283
x = xp.asarray(random(30), dtype=xp.float64)
284-
xp_test = array_namespace(x)
285-
x_norm = xp_test.linalg.vector_norm(x)
284+
285+
x_norm = xp.linalg.vector_norm(x)
286286
n = xp_size(x) * 2
287287
func_pairs = [(fft.rfft, fft.irfft),
288288
# hfft: order so the first function takes x.size samples
@@ -294,12 +294,12 @@ def test_all_1d_norm_preserving(self, xp):
294294
for forw, back in func_pairs:
295295
if forw == fft.fft:
296296
x = xp.asarray(x, dtype=xp.complex128)
297-
x_norm = xp_test.linalg.vector_norm(x)
297+
x_norm = xp.linalg.vector_norm(x)
298298
for n in [xp_size(x), 2*xp_size(x)]:
299299
for norm in ['backward', 'ortho', 'forward']:
300300
tmp = forw(x, n=n, norm=norm)
301301
tmp = back(tmp, n=n, norm=norm)
302-
xp_assert_close(xp_test.linalg.vector_norm(tmp), x_norm)
302+
xp_assert_close(xp.linalg.vector_norm(tmp), x_norm)
303303

304304
@skip_xp_backends(np_only=True)
305305
@pytest.mark.parametrize("dtype", [np.float16, np.longdouble])
@@ -482,13 +482,17 @@ def test_non_standard_params(func, xp):
482482
else:
483483
dtype = xp.complex128
484484

485-
if xp.__name__ != 'numpy':
486-
x = xp.asarray([1, 2, 3], dtype=dtype)
487-
# func(x) should not raise an exception
488-
func(x)
485+
x = xp.asarray([1, 2, 3], dtype=dtype)
486+
# func(x) should not raise an exception
487+
func(x)
488+
489+
if is_numpy(xp):
490+
func(x, workers=2)
491+
else:
489492
assert_raises(ValueError, func, x, workers=2)
490-
# `plan` param is not tested since SciPy does not use it currently
491-
# but should be tested if it comes into use
493+
494+
# `plan` param is not tested since SciPy does not use it currently
495+
# but should be tested if it comes into use
492496

493497

494498
@pytest.mark.parametrize("dtype", ['float32', 'float64'])

scipy/fft/tests/test_fftlog.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from scipy.fft._fftlog import fht, ifht, fhtoffset
88
from scipy.special import poch
99

10-
from scipy._lib._array_api import xp_assert_close, xp_assert_less, array_namespace
10+
from scipy._lib._array_api import xp_assert_close, xp_assert_less
1111

1212
skip_xp_backends = pytest.mark.skip_xp_backends
1313

@@ -186,13 +186,12 @@ def test_array_like(xp, op):
186186
@pytest.mark.parametrize('n', [128, 129])
187187
def test_gh_21661(xp, n):
188188
one = xp.asarray(1.0)
189-
xp_test = array_namespace(one)
190189
mu = 0.0
191190
r = np.logspace(-7, 1, n)
192191
dln = math.log(r[1] / r[0])
193192
offset = fhtoffset(dln, initial=-6 * np.log(10), mu=mu)
194193
r = xp.asarray(r, dtype=one.dtype)
195-
k = math.exp(offset) / xp_test.flip(r, axis=-1)
194+
k = math.exp(offset) / xp.flip(r, axis=-1)
196195

197196
def f(x, mu):
198197
return x**(mu + 1)*xp.exp(-x**2/2)

scipy/fft/tests/test_helper.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
import sys
1313
from scipy._lib._array_api import (
14-
xp_assert_close, get_xp_devices, xp_device, array_namespace
14+
xp_assert_close, get_xp_devices, xp_device
1515
)
1616
from scipy import fft
1717

@@ -518,26 +518,23 @@ def test_definition(self, xp):
518518
x2 = xp.asarray([0, 1, 2, 3, 4, -5, -4, -3, -2, -1], dtype=xp.float64)
519519

520520
# default dtype varies across backends
521-
522-
wrapped_xp = array_namespace(x)
523-
y = 9 * fft.fftfreq(9, xp=wrapped_xp)
521+
y = 9 * fft.fftfreq(9, xp=xp)
524522
xp_assert_close(y, x, check_dtype=False, check_namespace=True)
525523

526-
y = 9 * xp.pi * fft.fftfreq(9, xp.pi, xp=wrapped_xp)
524+
y = 9 * xp.pi * fft.fftfreq(9, xp.pi, xp=xp)
527525
xp_assert_close(y, x, check_dtype=False)
528526

529-
y = 10 * fft.fftfreq(10, xp=wrapped_xp)
527+
y = 10 * fft.fftfreq(10, xp=xp)
530528
xp_assert_close(y, x2, check_dtype=False)
531529

532-
y = 10 * xp.pi * fft.fftfreq(10, xp.pi, xp=wrapped_xp)
530+
y = 10 * xp.pi * fft.fftfreq(10, xp.pi, xp=xp)
533531
xp_assert_close(y, x2, check_dtype=False)
534532

535533
def test_device(self, xp):
536-
xp_test = array_namespace(xp.empty(0))
537534
devices = get_xp_devices(xp)
538535
for d in devices:
539-
y = fft.fftfreq(9, xp=xp_test, device=d)
540-
x = xp_test.empty(0, device=d)
536+
y = fft.fftfreq(9, xp=xp, device=d)
537+
x = xp.empty(0, device=d)
541538
assert xp_device(y) == xp_device(x)
542539

543540

@@ -552,23 +549,21 @@ def test_definition(self, xp):
552549
x2 = xp.asarray([0, 1, 2, 3, 4, 5], dtype=xp.float64)
553550

554551
# default dtype varies across backends
555-
wrapped_xp = array_namespace(x)
556-
y = 9 * fft.rfftfreq(9, xp=wrapped_xp)
552+
y = 9 * fft.rfftfreq(9, xp=xp)
557553
xp_assert_close(y, x, check_dtype=False, check_namespace=True)
558554

559-
y = 9 * xp.pi * fft.rfftfreq(9, xp.pi, xp=wrapped_xp)
555+
y = 9 * xp.pi * fft.rfftfreq(9, xp.pi, xp=xp)
560556
xp_assert_close(y, x, check_dtype=False)
561557

562-
y = 10 * fft.rfftfreq(10, xp=wrapped_xp)
558+
y = 10 * fft.rfftfreq(10, xp=xp)
563559
xp_assert_close(y, x2, check_dtype=False)
564560

565-
y = 10 * xp.pi * fft.rfftfreq(10, xp.pi, xp=wrapped_xp)
561+
y = 10 * xp.pi * fft.rfftfreq(10, xp.pi, xp=xp)
566562
xp_assert_close(y, x2, check_dtype=False)
567563

568564
def test_device(self, xp):
569-
xp_test = array_namespace(xp.empty(0))
570565
devices = get_xp_devices(xp)
571566
for d in devices:
572-
y = fft.rfftfreq(9, xp=xp_test, device=d)
573-
x = xp_test.empty(0, device=d)
567+
y = fft.rfftfreq(9, xp=xp, device=d)
568+
x = xp.empty(0, device=d)
574569
assert xp_device(y) == xp_device(x)

0 commit comments

Comments
 (0)