Skip to content

Commit 21d8645

Browse files
headtr1ck and dcherian
authored
Support complex arrays in xr.corr (#7392)
* complex cov * fix mypy * update whatsa-new * Update xarray/core/computation.py * slight improvements to tests * bugfix in corr_cov for multiple dims * fix whats-new * allow refreshing of backends * Revert "allow refreshing of backends" This reverts commit 576692b. --------- Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent cd90184 commit 21d8645

File tree

3 files changed

+82
-60
lines changed

3 files changed

+82
-60
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ v2023.03.0 (unreleased)
2323
New Features
2424
~~~~~~~~~~~~
2525

26+
- :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex-valued arrays (:issue:`7340`, :pull:`7392`).
27+
By `Michael Niklas <https://github.com/headtr1ck>`_.
2628

2729
Breaking changes
2830
~~~~~~~~~~~~~~~~

xarray/core/computation.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,16 @@
99
import warnings
1010
from collections import Counter
1111
from collections.abc import Hashable, Iterable, Mapping, Sequence
12-
from typing import TYPE_CHECKING, AbstractSet, Any, Callable, TypeVar, Union, overload
12+
from typing import (
13+
TYPE_CHECKING,
14+
AbstractSet,
15+
Any,
16+
Callable,
17+
Literal,
18+
TypeVar,
19+
Union,
20+
overload,
21+
)
1322

1423
import numpy as np
1524

@@ -21,7 +30,7 @@
2130
from xarray.core.merge import merge_attrs, merge_coordinates_without_align
2231
from xarray.core.options import OPTIONS, _get_keep_attrs
2332
from xarray.core.pycompat import is_duck_dask_array
24-
from xarray.core.types import T_DataArray
33+
from xarray.core.types import Dims, T_DataArray
2534
from xarray.core.utils import is_dict_like, is_scalar
2635
from xarray.core.variable import Variable
2736

@@ -1209,7 +1218,9 @@ def apply_ufunc(
12091218
return apply_array_ufunc(func, *args, dask=dask)
12101219

12111220

1212-
def cov(da_a, da_b, dim=None, ddof=1):
1221+
def cov(
1222+
da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None, ddof: int = 1
1223+
) -> T_DataArray:
12131224
"""
12141225
Compute covariance between two DataArray objects along a shared dimension.
12151226
@@ -1219,9 +1230,9 @@ def cov(da_a, da_b, dim=None, ddof=1):
12191230
Array to compute.
12201231
da_b : DataArray
12211232
Array to compute.
1222-
dim : str, optional
1233+
dim : str, iterable of hashable, "..." or None, optional
12231234
The dimension along which the covariance will be computed
1224-
ddof : int, optional
1235+
ddof : int, default: 1
12251236
If ddof=1, covariance is normalized by N-1, giving an unbiased estimate,
12261237
else normalization is by N.
12271238
@@ -1289,7 +1300,7 @@ def cov(da_a, da_b, dim=None, ddof=1):
12891300
return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov")
12901301

12911302

1292-
def corr(da_a, da_b, dim=None):
1303+
def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
12931304
"""
12941305
Compute the Pearson correlation coefficient between
12951306
two DataArray objects along a shared dimension.
@@ -1300,7 +1311,7 @@ def corr(da_a, da_b, dim=None):
13001311
Array to compute.
13011312
da_b : DataArray
13021313
Array to compute.
1303-
dim : str, optional
1314+
dim : str, iterable of hashable, "..." or None, optional
13041315
The dimension along which the correlation will be computed
13051316
13061317
Returns
@@ -1368,7 +1379,11 @@ def corr(da_a, da_b, dim=None):
13681379

13691380

13701381
def _cov_corr(
1371-
da_a: T_DataArray, da_b: T_DataArray, dim=None, ddof=0, method=None
1382+
da_a: T_DataArray,
1383+
da_b: T_DataArray,
1384+
dim: Dims = None,
1385+
ddof: int = 0,
1386+
method: Literal["cov", "corr", None] = None,
13721387
) -> T_DataArray:
13731388
"""
13741389
Internal method for xr.cov() and xr.corr() so only have to
@@ -1388,22 +1403,21 @@ def _cov_corr(
13881403
demeaned_da_b = da_b - da_b.mean(dim=dim)
13891404

13901405
# 4. Compute covariance along the given dim
1391-
#
13921406
# N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
13931407
# Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
1394-
cov = (demeaned_da_a * demeaned_da_b).sum(dim=dim, skipna=True, min_count=1) / (
1395-
valid_count
1396-
)
1408+
cov = (demeaned_da_a.conj() * demeaned_da_b).sum(
1409+
dim=dim, skipna=True, min_count=1
1410+
) / (valid_count)
13971411

13981412
if method == "cov":
1399-
return cov
1413+
return cov # type: ignore[return-value]
14001414

14011415
else:
14021416
# compute std + corr
14031417
da_a_std = da_a.std(dim=dim)
14041418
da_b_std = da_b.std(dim=dim)
14051419
corr = cov / (da_a_std * da_b_std)
1406-
return corr
1420+
return corr # type: ignore[return-value]
14071421

14081422

14091423
def cross(
@@ -1616,7 +1630,7 @@ def cross(
16161630

16171631
def dot(
16181632
*arrays,
1619-
dims: str | Iterable[Hashable] | ellipsis | None = None,
1633+
dims: Dims = None,
16201634
**kwargs: Any,
16211635
):
16221636
"""Generalized dot product for xarray objects. Like np.einsum, but
@@ -1626,7 +1640,7 @@ def dot(
16261640
----------
16271641
*arrays : DataArray or Variable
16281642
Arrays to compute.
1629-
dims : ..., str or tuple of str, optional
1643+
dims : str, iterable of hashable, "..." or None, optional
16301644
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
16311645
If not specified, then all the common dimensions are summed over.
16321646
**kwargs : dict

xarray/tests/test_computation.py

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,25 +1387,29 @@ def test_vectorize_exclude_dims_dask() -> None:
13871387

13881388
def test_corr_only_dataarray() -> None:
13891389
with pytest.raises(TypeError, match="Only xr.DataArray is supported"):
1390-
xr.corr(xr.Dataset(), xr.Dataset())
1390+
xr.corr(xr.Dataset(), xr.Dataset()) # type: ignore[type-var]
13911391

13921392

1393-
def arrays_w_tuples():
1393+
@pytest.fixture(scope="module")
1394+
def arrays():
13941395
da = xr.DataArray(
13951396
np.random.random((3, 21, 4)),
13961397
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
13971398
dims=("a", "time", "x"),
13981399
)
13991400

1400-
arrays = [
1401+
return [
14011402
da.isel(time=range(0, 18)),
14021403
da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(),
14031404
xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
14041405
xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
14051406
xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]),
14061407
]
14071408

1408-
array_tuples = [
1409+
1410+
@pytest.fixture(scope="module")
1411+
def array_tuples(arrays):
1412+
return [
14091413
(arrays[0], arrays[0]),
14101414
(arrays[0], arrays[1]),
14111415
(arrays[1], arrays[1]),
@@ -1417,27 +1421,19 @@ def arrays_w_tuples():
14171421
(arrays[4], arrays[4]),
14181422
]
14191423

1420-
return arrays, array_tuples
1421-
14221424

14231425
@pytest.mark.parametrize("ddof", [0, 1])
1424-
@pytest.mark.parametrize(
1425-
"da_a, da_b",
1426-
[
1427-
arrays_w_tuples()[1][3],
1428-
arrays_w_tuples()[1][4],
1429-
arrays_w_tuples()[1][5],
1430-
arrays_w_tuples()[1][6],
1431-
arrays_w_tuples()[1][7],
1432-
arrays_w_tuples()[1][8],
1433-
],
1434-
)
1426+
@pytest.mark.parametrize("n", [3, 4, 5, 6, 7, 8])
14351427
@pytest.mark.parametrize("dim", [None, "x", "time"])
14361428
@requires_dask
1437-
def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
1429+
def test_lazy_corrcov(
1430+
n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
1431+
) -> None:
14381432
# GH 5284
14391433
from dask import is_dask_collection
14401434

1435+
da_a, da_b = array_tuples[n]
1436+
14411437
with raise_if_dask_computes():
14421438
cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof)
14431439
assert is_dask_collection(cov)
@@ -1447,12 +1443,13 @@ def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
14471443

14481444

14491445
@pytest.mark.parametrize("ddof", [0, 1])
1450-
@pytest.mark.parametrize(
1451-
"da_a, da_b",
1452-
[arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
1453-
)
1446+
@pytest.mark.parametrize("n", [0, 1, 2])
14541447
@pytest.mark.parametrize("dim", [None, "time"])
1455-
def test_cov(da_a, da_b, dim, ddof) -> None:
1448+
def test_cov(
1449+
n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
1450+
) -> None:
1451+
da_a, da_b = array_tuples[n]
1452+
14561453
if dim is not None:
14571454

14581455
def np_cov_ind(ts1, ts2, a, x):
@@ -1499,12 +1496,13 @@ def np_cov(ts1, ts2):
14991496
assert_allclose(actual, expected)
15001497

15011498

1502-
@pytest.mark.parametrize(
1503-
"da_a, da_b",
1504-
[arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
1505-
)
1499+
@pytest.mark.parametrize("n", [0, 1, 2])
15061500
@pytest.mark.parametrize("dim", [None, "time"])
1507-
def test_corr(da_a, da_b, dim) -> None:
1501+
def test_corr(
1502+
n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
1503+
) -> None:
1504+
da_a, da_b = array_tuples[n]
1505+
15081506
if dim is not None:
15091507

15101508
def np_corr_ind(ts1, ts2, a, x):
@@ -1547,12 +1545,12 @@ def np_corr(ts1, ts2):
15471545
assert_allclose(actual, expected)
15481546

15491547

1550-
@pytest.mark.parametrize(
1551-
"da_a, da_b",
1552-
arrays_w_tuples()[1],
1553-
)
1548+
@pytest.mark.parametrize("n", range(9))
15541549
@pytest.mark.parametrize("dim", [None, "time", "x"])
1555-
def test_covcorr_consistency(da_a, da_b, dim) -> None:
1550+
def test_covcorr_consistency(
1551+
n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
1552+
) -> None:
1553+
da_a, da_b = array_tuples[n]
15561554
# Testing that xr.corr and xr.cov are consistent with each other
15571555
# 1. Broadcast the two arrays
15581556
da_a, da_b = broadcast(da_a, da_b)
@@ -1569,10 +1567,13 @@ def test_covcorr_consistency(da_a, da_b, dim) -> None:
15691567

15701568

15711569
@requires_dask
1572-
@pytest.mark.parametrize("da_a, da_b", arrays_w_tuples()[1])
1570+
@pytest.mark.parametrize("n", range(9))
15731571
@pytest.mark.parametrize("dim", [None, "time", "x"])
15741572
@pytest.mark.filterwarnings("ignore:invalid value encountered in .*divide")
1575-
def test_corr_lazycorr_consistency(da_a, da_b, dim) -> None:
1573+
def test_corr_lazycorr_consistency(
1574+
n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
1575+
) -> None:
1576+
da_a, da_b = array_tuples[n]
15761577
da_al = da_a.chunk()
15771578
da_bl = da_b.chunk()
15781579
c_abl = xr.corr(da_al, da_bl, dim=dim)
@@ -1591,22 +1592,27 @@ def test_corr_dtype_error():
15911592
xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a, da_b.chunk()))
15921593

15931594

1594-
@pytest.mark.parametrize(
1595-
"da_a",
1596-
arrays_w_tuples()[0],
1597-
)
1595+
@pytest.mark.parametrize("n", range(5))
15981596
@pytest.mark.parametrize("dim", [None, "time", "x", ["time", "x"]])
1599-
def test_autocov(da_a, dim) -> None:
1597+
def test_autocov(n: int, dim: str | None, arrays) -> None:
1598+
da = arrays[n]
1599+
16001600
# Testing that the autocovariance*(N-1) is ~=~ to the variance matrix
16011601
# 1. Ignore the nans
1602-
valid_values = da_a.notnull()
1602+
valid_values = da.notnull()
16031603
# Because we're using ddof=1, this requires > 1 value in each sample
1604-
da_a = da_a.where(valid_values.sum(dim=dim) > 1)
1605-
expected = ((da_a - da_a.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
1606-
actual = xr.cov(da_a, da_a, dim=dim) * (valid_values.sum(dim) - 1)
1604+
da = da.where(valid_values.sum(dim=dim) > 1)
1605+
expected = ((da - da.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
1606+
actual = xr.cov(da, da, dim=dim) * (valid_values.sum(dim) - 1)
16071607
assert_allclose(actual, expected)
16081608

16091609

1610+
def test_complex_cov() -> None:
1611+
da = xr.DataArray([1j, -1j])
1612+
actual = xr.cov(da, da)
1613+
assert abs(actual.item()) == 2
1614+
1615+
16101616
@requires_dask
16111617
def test_vectorize_dask_new_output_dims() -> None:
16121618
# regression test for GH3574

0 commit comments

Comments
 (0)