Skip to content

Commit 0496cb4

Browse files
Avoid use of random numbers in test_weighted.test_weighted_operations_nonequal_coords (#6961)
* avoid use of random numbers in test weighted * corrected PR number in whats-new.rst * explicit variable names for types * removed unnecessary comment * changes from review * typo in docstring * Update xarray/tests/test_weighted.py Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com> Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
1 parent 647ac4b commit 0496cb4

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ Bug fixes
5555
By `Fabian Hofmann <https://github.com/FabianHofmann>`_.
5656
- Fix step plots with ``hue`` arg. (:pull:`6944`)
5757
By `András Gunyhó <https://github.com/mgunyho>`_.
58+
- Avoid use of random numbers in `test_weighted.test_weighted_operations_nonequal_coords` (:issue:`6504`, :pull:`6961`).
59+
By `Luke Conibear <https://github.com/lukeconibear>`_.
5860

5961
Documentation
6062
~~~~~~~~~~~~~

xarray/tests/test_weighted.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Any, Iterable
4+
35
import numpy as np
46
import pytest
57

@@ -654,23 +656,49 @@ def test_weighted_quantile_3D(dim, q, add_nans, skipna):
654656
assert_allclose(expected, result2.data)
655657

656658

657-
def test_weighted_operations_nonequal_coords():
658-
# There are no weights for a == 4, so that data point is ignored.
659-
weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3]))
660-
data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4]))
661-
check_weighted_operations(data, weights, dim="a", skipna=None)
659+
@pytest.mark.parametrize(
660+
"coords_weights, coords_data, expected_value_at_weighted_quantile",
661+
[
662+
([0, 1, 2, 3], [1, 2, 3, 4], 2.5), # no weights for coord a == 4
663+
([0, 1, 2, 3], [2, 3, 4, 5], 1.8), # no weights for coord a == 4 or 5
664+
([2, 3, 4, 5], [0, 1, 2, 3], 3.8), # no weights for coord a == 0 or 1
665+
],
666+
)
667+
def test_weighted_operations_nonequal_coords(
668+
coords_weights: Iterable[Any],
669+
coords_data: Iterable[Any],
670+
expected_value_at_weighted_quantile: float,
671+
) -> None:
672+
"""Check that weighted operations work with unequal coords.
673+
674+
675+
Parameters
676+
----------
677+
coords_weights : Iterable[Any]
678+
The coords for the weights.
679+
coords_data : Iterable[Any]
680+
The coords for the data.
681+
expected_value_at_weighted_quantile : float
682+
The expected value for the quantile of the weighted data.
683+
"""
684+
da_weights = DataArray(
685+
[0.5, 1.0, 1.0, 2.0], dims=("a",), coords=dict(a=coords_weights)
686+
)
687+
da_data = DataArray([1, 2, 3, 4], dims=("a",), coords=dict(a=coords_data))
688+
check_weighted_operations(da_data, da_weights, dim="a", skipna=None)
662689

663-
q = 0.5
664-
result = data.weighted(weights).quantile(q, dim="a")
665-
# Expected value computed using code from https://aakinshin.net/posts/weighted-quantiles/ with values at a=1,2,3
666-
expected = DataArray([0.9308707], coords={"quantile": [q]}).squeeze()
667-
assert_allclose(result, expected)
690+
quantile = 0.5
691+
da_actual = da_data.weighted(da_weights).quantile(quantile, dim="a")
692+
da_expected = DataArray(
693+
[expected_value_at_weighted_quantile], coords={"quantile": [quantile]}
694+
).squeeze()
695+
assert_allclose(da_actual, da_expected)
668696

669-
data = data.to_dataset(name="data")
670-
check_weighted_operations(data, weights, dim="a", skipna=None)
697+
ds_data = da_data.to_dataset(name="data")
698+
check_weighted_operations(ds_data, da_weights, dim="a", skipna=None)
671699

672-
result = data.weighted(weights).quantile(q, dim="a")
673-
assert_allclose(result, expected.to_dataset(name="data"))
700+
ds_actual = ds_data.weighted(da_weights).quantile(quantile, dim="a")
701+
assert_allclose(ds_actual, da_expected.to_dataset(name="data"))
674702

675703

676704
@pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4)))

0 commit comments

Comments
 (0)