Skip to content

Commit 24581ed

Browse files
mdhabertylerjereddy
authored andcommitted
MAINT: stats.zmap: restore support for complex data (scipy#22405)
* MAINT: stats.zmap: restore support for complex data * MAINT: stats.zmap: fix masked array support * Apply suggestions from code review
1 parent a29e8de commit 24581ed

File tree

4 files changed

+43
-26
lines changed

4 files changed

+43
-26
lines changed

scipy/_lib/_array_api.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def xp_take_along_axis(arr: Array,
529529
def xp_broadcast_promote(*args, ensure_writeable=False, force_floating=False, xp=None):
530530
xp = array_namespace(*args) if xp is None else xp
531531

532-
args = [(xp.asarray(arg) if arg is not None else arg) for arg in args]
532+
args = [(_asarray(arg, subok=True) if arg is not None else arg) for arg in args]
533533
args_not_none = [arg for arg in args if arg is not None]
534534

535535
# determine minimum dtype
@@ -553,7 +553,12 @@ def xp_broadcast_promote(*args, ensure_writeable=False, force_floating=False, xp
553553

554554
# determine result shape
555555
shapes = {arg.shape for arg in args_not_none}
556-
shape = np.broadcast_shapes(*shapes) if len(shapes) != 1 else args_not_none[0].shape
556+
try:
557+
shape = (np.broadcast_shapes(*shapes) if len(shapes) != 1
558+
else args_not_none[0].shape)
559+
except ValueError as e:
560+
message = "Array shapes are incompatible for broadcasting."
561+
raise ValueError(message) from e
557562

558563
out = []
559564
for arg in args:
@@ -565,7 +570,8 @@ def xp_broadcast_promote(*args, ensure_writeable=False, force_floating=False, xp
565570
# Even if two arguments need broadcasting, this is faster than
566571
# `broadcast_arrays`, especially since we've already determined `shape`
567572
if arg.shape != shape:
568-
arg = xp.broadcast_to(arg, shape)
573+
kwargs = {'subok': True} if is_numpy(xp) else {}
574+
arg = xp.broadcast_to(arg, shape, **kwargs)
569575

570576
# convert dtype/copy only if needed
571577
if (arg.dtype != dtype) or ensure_writeable:

scipy/_lib/_util.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -939,13 +939,14 @@ def _contains_nan(a, nan_policy='propagate', policies=None, *,
939939
if nan_policy not in policies:
940940
raise ValueError(f"nan_policy must be one of {set(policies)}.")
941941

942-
inexact = (xp.isdtype(a.dtype, "real floating")
943-
or xp.isdtype(a.dtype, "complex floating"))
944-
if xp_size(a) == 0:
945-
contains_nan = False
946-
elif inexact:
947-
# Faster and less memory-intensive than xp.any(xp.isnan(a))
942+
if xp.isdtype(a.dtype, "real floating"):
943+
# Faster and less memory-intensive than xp.any(xp.isnan(a)), and unlike other
944+
# reductions, `max`/`min` won't return NaN unless there is a NaN in the data.
948945
contains_nan = xp.isnan(xp.max(a))
946+
elif xp.isdtype(a.dtype, "complex floating"):
947+
# Typically `real` and `imag` produce views; otherwise, `xp.any(xp.isnan(a))`
948+
# would be more efficient.
949+
contains_nan = xp.isnan(xp.max(xp.real(a))) | xp.isnan(xp.max(xp.imag(a)))
949950
elif is_numpy(xp) and np.issubdtype(a.dtype, object):
950951
contains_nan = False
951952
for el in a.ravel():

scipy/stats/_stats_py.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from ._resampling import (MonteCarloMethod, PermutationMethod, BootstrapMethod,
6262
monte_carlo_test, permutation_test, bootstrap,
6363
_batch_generator)
64-
from ._axis_nan_policy import (_axis_nan_policy_factory, _broadcast_arrays,
64+
from ._axis_nan_policy import (_axis_nan_policy_factory,
6565
_broadcast_concatenate, _broadcast_shapes,
6666
_broadcast_array_shapes_remove_axis, SmallSampleWarning,
6767
too_small_1d_not_omit, too_small_1d_omit,
@@ -79,6 +79,7 @@
7979
xp_moveaxis_to_end,
8080
xp_sign,
8181
xp_vector_norm,
82+
xp_broadcast_promote,
8283
)
8384
from scipy._lib import array_api_extra as xpx
8485
from scipy._lib.deprecation import _deprecated
@@ -10843,21 +10844,7 @@ def _xp_mean(x, /, *, axis=None, weights=None, keepdims=False, nan_policy='propa
1084310844
or (weights is not None and xp_size(weights) == 0)):
1084410845
return gmean(x, weights=weights, axis=axis, keepdims=keepdims)
1084510846

10846-
# handle non-broadcastable inputs
10847-
if weights is not None and x.shape != weights.shape:
10848-
try:
10849-
x, weights = _broadcast_arrays((x, weights), xp=xp)
10850-
except (ValueError, RuntimeError) as e:
10851-
message = "Array shapes are incompatible for broadcasting."
10852-
raise ValueError(message) from e
10853-
10854-
# convert integers to the default float of the array library
10855-
if not xp.isdtype(x.dtype, 'real floating'):
10856-
dtype = xp.asarray(1.).dtype
10857-
x = xp.asarray(x, dtype=dtype)
10858-
if weights is not None and not xp.isdtype(weights.dtype, 'real floating'):
10859-
dtype = xp.asarray(1.).dtype
10860-
weights = xp.asarray(weights, dtype=dtype)
10847+
x, weights = xp_broadcast_promote(x, weights, force_floating=True)
1086110848

1086210849
# handle the special case of zero-sized arrays
1086310850
message = (too_small_1d_not_omit if (x.ndim == 1 or axis is None)
@@ -10932,7 +10919,9 @@ def _xp_var(x, /, *, axis=None, correction=0, keepdims=False, nan_policy='propag
1093210919
mean = _xp_mean(x, keepdims=True, **kwargs)
1093310920
x = _asarray(x, dtype=mean.dtype, subok=True)
1093410921
x_mean = _demean(x, mean, axis, xp=xp)
10935-
var = _xp_mean(x_mean**2, keepdims=keepdims, **kwargs)
10922+
x_mean_conj = (xp.conj(x_mean) if xp.isdtype(x_mean.dtype, 'complex floating')
10923+
else x_mean) # crossref data-apis/array-api#824
10924+
var = _xp_mean(x_mean * x_mean_conj, keepdims=keepdims, **kwargs)
1093610925

1093710926
if correction != 0:
1093810927
if axis is None:

scipy/stats/tests/test_stats.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3167,6 +3167,12 @@ def test_degenerate_input(self, xp):
31673167
res = stats.zmap(scores, compare)
31683168
xp_assert_equal(res, ref)
31693169

3170+
@pytest.mark.skip_xp_backends('array_api_strict', reason='needs array-api#850')
3171+
def test_complex_gh22404(self, xp):
3172+
res = stats.zmap(xp.asarray([1, 2, 3, 4]), xp.asarray([1, 1j, -1, -1j]))
3173+
ref = xp.asarray([1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j])
3174+
xp_assert_close(res, ref)
3175+
31703176

31713177
class TestMedianAbsDeviation:
31723178
def setup_class(self):
@@ -9681,6 +9687,13 @@ def test_integer(self, xp):
96819687
xp_assert_equal(_xp_mean(x), _xp_mean(y))
96829688
xp_assert_equal(_xp_mean(y, weights=x), _xp_mean(y, weights=y))
96839689

9690+
def test_complex_gh22404(self, xp):
9691+
rng = np.random.default_rng(90359458245906)
9692+
x, y, wx, wy = rng.random((4, 20))
9693+
res = _xp_mean(xp.asarray(x + y*1j), weights=xp.asarray(wx + wy*1j))
9694+
ref = np.average(x + y*1j, weights=wx + wy*1j)
9695+
xp_assert_close(res, xp.asarray(ref))
9696+
96849697

96859698
@array_api_compatible
96869699
@pytest.mark.usefixtures("skip_xp_backends")
@@ -9786,6 +9799,14 @@ def test_integer(self, xp):
97869799
y = xp.arange(10.)
97879800
xp_assert_equal(_xp_var(x), _xp_var(y))
97889801

9802+
@pytest.mark.skip_xp_backends('array_api_strict', reason='needs array-api#850')
9803+
def test_complex_gh22404(self, xp):
9804+
rng = np.random.default_rng(90359458245906)
9805+
x, y = rng.random((2, 20))
9806+
res = _xp_var(xp.asarray(x + y*1j))
9807+
ref = np.var(x + y*1j)
9808+
xp_assert_close(res, xp.asarray(ref), check_dtype=False)
9809+
97899810

97909811
@array_api_compatible
97919812
def test_chk_asarray(xp):

0 commit comments

Comments
 (0)