Skip to content

Commit 7d217eb

Browse files
lithomas1lucascolley
authored andcommitted
Fixes for fft/special modules (#14)
1 parent 2a26252 commit 7d217eb

File tree

4 files changed

+22
-13
lines changed

4 files changed

+22
-13
lines changed

scipy/_lib/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ py3.install_sources(
199199
'array_api_compat/array_api_compat/numpy/_info.py',
200200
'array_api_compat/array_api_compat/numpy/fft.py',
201201
'array_api_compat/array_api_compat/numpy/linalg.py',
202+
'array_api_compat/array_api_compat/numpy/_info.py',
202203
],
203204
subdir: 'scipy/_lib/array_api_compat/numpy',
204205
)

scipy/fft/tests/test_helper.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -520,23 +520,24 @@ def test_definition(self, xp):
520520

521521
# default dtype varies across backends
522522

523-
y = 9 * fft.fftfreq(9, xp=xp)
523+
wrapped_xp = array_namespace(x)
524+
y = 9 * fft.fftfreq(9, xp=wrapped_xp)
524525
xp_assert_close(y, x, check_dtype=False, check_namespace=True)
525526

526-
y = 9 * xp.pi * fft.fftfreq(9, xp.pi, xp=xp)
527+
y = 9 * xp.pi * fft.fftfreq(9, xp.pi, xp=wrapped_xp)
527528
xp_assert_close(y, x, check_dtype=False)
528529

529-
y = 10 * fft.fftfreq(10, xp=xp)
530+
y = 10 * fft.fftfreq(10, xp=wrapped_xp)
530531
xp_assert_close(y, x2, check_dtype=False)
531532

532-
y = 10 * xp.pi * fft.fftfreq(10, xp.pi, xp=xp)
533+
y = 10 * xp.pi * fft.fftfreq(10, xp.pi, xp=wrapped_xp)
533534
xp_assert_close(y, x2, check_dtype=False)
534535

535536
def test_device(self, xp):
536537
xp_test = array_namespace(xp.empty(0))
537538
devices = get_xp_devices(xp)
538539
for d in devices:
539-
y = fft.fftfreq(9, xp=xp, device=d)
540+
y = fft.fftfreq(9, xp=xp_test, device=d)
540541
x = xp_test.empty(0, device=d)
541542
assert xp_device(y) == xp_device(x)
542543

@@ -552,23 +553,23 @@ def test_definition(self, xp):
552553
x2 = xp.asarray([0, 1, 2, 3, 4, 5], dtype=xp.float64)
553554

554555
# default dtype varies across backends
555-
556-
y = 9 * fft.rfftfreq(9, xp=xp)
556+
wrapped_xp = array_namespace(x)
557+
y = 9 * fft.rfftfreq(9, xp=wrapped_xp)
557558
xp_assert_close(y, x, check_dtype=False, check_namespace=True)
558559

559-
y = 9 * xp.pi * fft.rfftfreq(9, xp.pi, xp=xp)
560+
y = 9 * xp.pi * fft.rfftfreq(9, xp.pi, xp=wrapped_xp)
560561
xp_assert_close(y, x, check_dtype=False)
561562

562-
y = 10 * fft.rfftfreq(10, xp=xp)
563+
y = 10 * fft.rfftfreq(10, xp=wrapped_xp)
563564
xp_assert_close(y, x2, check_dtype=False)
564565

565-
y = 10 * xp.pi * fft.rfftfreq(10, xp.pi, xp=xp)
566+
y = 10 * xp.pi * fft.rfftfreq(10, xp.pi, xp=wrapped_xp)
566567
xp_assert_close(y, x2, check_dtype=False)
567568

568569
def test_device(self, xp):
569570
xp_test = array_namespace(xp.empty(0))
570571
devices = get_xp_devices(xp)
571572
for d in devices:
572-
y = fft.rfftfreq(9, xp=xp, device=d)
573+
y = fft.rfftfreq(9, xp=xp_test, device=d)
573574
x = xp_test.empty(0, device=d)
574575
assert xp_device(y) == xp_device(x)

scipy/special/_logsumexp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,10 @@ def _logsumexp(a, b, axis, return_sign, xp):
202202
a_max, i_max = _elements_and_indices_with_max_real(a, axis=axis, xp=xp)
203203

204204
# for precision, these terms are separated out of the main sum.
205-
a[i_max] = -xp.inf
205+
# TODO: we shouldn't be mutating in-place here unless we make a copy
206+
# dask arrays do not copy before this somehow
207+
#a[i_max] = -xp.inf
208+
a = xp.where(i_max, -xp.asarray(xp.inf, dtype=a.dtype), a)
206209
i_max_dt = xp.astype(i_max, a.dtype)
207210
# This is an inefficient way of getting `m` because it is the sum of a sparse
208211
# array; however, this is the simplest way I can think of to get the right shape.

scipy/special/tests/test_logsumexp.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@ def test_logsumexp(self, xp):
6969
nan = xp.asarray([xp.nan])
7070
xp_assert_equal(logsumexp(inf), inf[0])
7171
xp_assert_equal(logsumexp(-inf), -inf[0])
72-
xp_assert_equal(logsumexp(nan), nan[0])
72+
# catch warnings here for dasks state there's no way to suppress
73+
# warnings just for dask
74+
# https://github.com/dask/dask/issues/3245
75+
with np.errstate(divide='ignore', invalid='ignore'):
76+
xp_assert_equal(logsumexp(nan), nan[0])
7377
xp_assert_equal(logsumexp(xp.asarray([-xp.inf, -xp.inf])), -inf[0])
7478

7579
# Handling an array with different magnitudes on the axes

0 commit comments

Comments
 (0)