@@ -1387,25 +1387,29 @@ def test_vectorize_exclude_dims_dask() -> None:
 
 
 def test_corr_only_dataarray() -> None:
     with pytest.raises(TypeError, match="Only xr.DataArray is supported"):
-        xr.corr(xr.Dataset(), xr.Dataset())
+        xr.corr(xr.Dataset(), xr.Dataset())  # type: ignore[type-var]
 
 
-def arrays_w_tuples():
+@pytest.fixture(scope="module")
+def arrays():
     da = xr.DataArray(
         np.random.random((3, 21, 4)),
         coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
         dims=("a", "time", "x"),
     )
 
-    arrays = [
+    return [
         da.isel(time=range(0, 18)),
         da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(),
         xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
         xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
         xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]),
     ]
 
-    array_tuples = [
+
+@pytest.fixture(scope="module")
+def array_tuples(arrays):
+    return [
         (arrays[0], arrays[0]),
         (arrays[0], arrays[1]),
         (arrays[1], arrays[1]),
@@ -1417,27 +1421,19 @@ def arrays_w_tuples():
         (arrays[4], arrays[4]),
     ]
 
-    return arrays, array_tuples
-
 
 @pytest.mark.parametrize("ddof", [0, 1])
-@pytest.mark.parametrize(
-    "da_a, da_b",
-    [
-        arrays_w_tuples()[1][3],
-        arrays_w_tuples()[1][4],
-        arrays_w_tuples()[1][5],
-        arrays_w_tuples()[1][6],
-        arrays_w_tuples()[1][7],
-        arrays_w_tuples()[1][8],
-    ],
-)
+@pytest.mark.parametrize("n", [3, 4, 5, 6, 7, 8])
 @pytest.mark.parametrize("dim", [None, "x", "time"])
 @requires_dask
-def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
+def test_lazy_corrcov(
+    n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
     # GH 5284
     from dask import is_dask_collection
 
+    da_a, da_b = array_tuples[n]
+
     with raise_if_dask_computes():
         cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof)
         assert is_dask_collection(cov)
@@ -1447,12 +1443,13 @@ def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
 
 
 @pytest.mark.parametrize("ddof", [0, 1])
-@pytest.mark.parametrize(
-    "da_a, da_b",
-    [arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
-)
+@pytest.mark.parametrize("n", [0, 1, 2])
 @pytest.mark.parametrize("dim", [None, "time"])
-def test_cov(da_a, da_b, dim, ddof) -> None:
+def test_cov(
+    n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
+    da_a, da_b = array_tuples[n]
+
     if dim is not None:
 
         def np_cov_ind(ts1, ts2, a, x):
@@ -1499,12 +1496,13 @@ def np_cov(ts1, ts2):
     assert_allclose(actual, expected)
 
 
-@pytest.mark.parametrize(
-    "da_a, da_b",
-    [arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
-)
+@pytest.mark.parametrize("n", [0, 1, 2])
 @pytest.mark.parametrize("dim", [None, "time"])
-def test_corr(da_a, da_b, dim) -> None:
+def test_corr(
+    n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
+    da_a, da_b = array_tuples[n]
+
     if dim is not None:
 
         def np_corr_ind(ts1, ts2, a, x):
@@ -1547,12 +1545,12 @@ def np_corr(ts1, ts2):
     assert_allclose(actual, expected)
 
 
-@pytest.mark.parametrize(
-    "da_a, da_b",
-    arrays_w_tuples()[1],
-)
+@pytest.mark.parametrize("n", range(9))
 @pytest.mark.parametrize("dim", [None, "time", "x"])
-def test_covcorr_consistency(da_a, da_b, dim) -> None:
+def test_covcorr_consistency(
+    n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
+    da_a, da_b = array_tuples[n]
     # Testing that xr.corr and xr.cov are consistent with each other
     # 1. Broadcast the two arrays
     da_a, da_b = broadcast(da_a, da_b)
@@ -1569,10 +1567,13 @@ def test_covcorr_consistency(da_a, da_b, dim) -> None:
 
 
 @requires_dask
-@pytest.mark.parametrize("da_a, da_b", arrays_w_tuples()[1])
+@pytest.mark.parametrize("n", range(9))
 @pytest.mark.parametrize("dim", [None, "time", "x"])
 @pytest.mark.filterwarnings("ignore:invalid value encountered in .*divide")
-def test_corr_lazycorr_consistency(da_a, da_b, dim) -> None:
+def test_corr_lazycorr_consistency(
+    n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
+) -> None:
+    da_a, da_b = array_tuples[n]
     da_al = da_a.chunk()
     da_bl = da_b.chunk()
     c_abl = xr.corr(da_al, da_bl, dim=dim)
@@ -1591,22 +1592,27 @@ def test_corr_dtype_error():
     xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a, da_b.chunk()))
 
 
-@pytest.mark.parametrize(
-    "da_a",
-    arrays_w_tuples()[0],
-)
+@pytest.mark.parametrize("n", range(5))
 @pytest.mark.parametrize("dim", [None, "time", "x", ["time", "x"]])
-def test_autocov(da_a, dim) -> None:
+def test_autocov(n: int, dim: str | None, arrays) -> None:
+    da = arrays[n]
+
     # Testing that the autocovariance*(N-1) is ~=~ to the variance matrix
     # 1. Ignore the nans
-    valid_values = da_a.notnull()
+    valid_values = da.notnull()
     # Because we're using ddof=1, this requires > 1 value in each sample
-    da_a = da_a.where(valid_values.sum(dim=dim) > 1)
-    expected = ((da_a - da_a.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
-    actual = xr.cov(da_a, da_a, dim=dim) * (valid_values.sum(dim) - 1)
+    da = da.where(valid_values.sum(dim=dim) > 1)
+    expected = ((da - da.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
+    actual = xr.cov(da, da, dim=dim) * (valid_values.sum(dim) - 1)
 
     assert_allclose(actual, expected)
 
 
+def test_complex_cov() -> None:
+    da = xr.DataArray([1j, -1j])
+    actual = xr.cov(da, da)
+    assert abs(actual.item()) == 2
+
+
 @requires_dask
 def test_vectorize_dask_new_output_dims() -> None:
     # regression test for GH3574
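
For reference, a minimal standalone sketch of the fixture pattern this diff introduces: a module-scoped `arrays` fixture builds the test DataArrays once, a dependent `array_tuples` fixture pairs them, and each test selects a pair by integer index via `@pytest.mark.parametrize("n", ...)`. The sketch is shortened to two arrays and two pairs, and the test name and assertions are illustrative only, not part of the PR.

```python
import numpy as np
import pandas as pd
import pytest
import xarray as xr


@pytest.fixture(scope="module")
def arrays():
    # Built once per module instead of once per parametrized test case.
    da = xr.DataArray(
        np.random.random((3, 21, 4)),
        coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
        dims=("a", "time", "x"),
    )
    return [
        da.isel(time=range(0, 18)),
        da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(),
    ]


@pytest.fixture(scope="module")
def array_tuples(arrays):
    # Pairs of DataArrays; tests index into this list by "n".
    return [
        (arrays[0], arrays[0]),
        (arrays[0], arrays[1]),
    ]


@pytest.mark.parametrize("n", range(2))
@pytest.mark.parametrize("dim", [None, "time"])
def test_cov_corr_return_dataarrays(n, dim, array_tuples) -> None:
    # Illustrative test only: each case indexes the shared fixture rather than
    # calling a helper like arrays_w_tuples() inside the parametrize decorator.
    da_a, da_b = array_tuples[n]
    assert isinstance(xr.cov(da_a, da_b, dim=dim), xr.DataArray)
    assert isinstance(xr.corr(da_a, da_b, dim=dim), xr.DataArray)
```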