|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +from typing import Any, Iterable |
| 4 | + |
3 | 5 | import numpy as np
|
4 | 6 | import pytest
|
5 | 7 |
|
@@ -654,23 +656,49 @@ def test_weighted_quantile_3D(dim, q, add_nans, skipna):
|
654 | 656 | assert_allclose(expected, result2.data)
|
655 | 657 |
|
656 | 658 |
|
657 |
| -def test_weighted_operations_nonequal_coords(): |
658 |
| - # There are no weights for a == 4, so that data point is ignored. |
659 |
| - weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3])) |
660 |
| - data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4])) |
661 |
| - check_weighted_operations(data, weights, dim="a", skipna=None) |
| 659 | +@pytest.mark.parametrize( |
| 660 | + "coords_weights, coords_data, expected_value_at_weighted_quantile", |
| 661 | + [ |
| 662 | + ([0, 1, 2, 3], [1, 2, 3, 4], 2.5), # no weights for coord a == 4 |
| 663 | + ([0, 1, 2, 3], [2, 3, 4, 5], 1.8), # no weights for coord a == 4 or 5 |
| 664 | + ([2, 3, 4, 5], [0, 1, 2, 3], 3.8), # no weights for coord a == 0 or 1 |
| 665 | + ], |
| 666 | +) |
| 667 | +def test_weighted_operations_nonequal_coords( |
| 668 | + coords_weights: Iterable[Any], |
| 669 | + coords_data: Iterable[Any], |
| 670 | + expected_value_at_weighted_quantile: float, |
| 671 | +) -> None: |
| 672 | + """Check that weighted operations work with unequal coords. |
| 673 | +
|
| 674 | +
|
| 675 | + Parameters |
| 676 | + ---------- |
| 677 | + coords_weights : Iterable[Any] |
| 678 | + The coords for the weights. |
| 679 | + coords_data : Iterable[Any] |
| 680 | + The coords for the data. |
| 681 | + expected_value_at_weighted_quantile : float |
| 682 | + The expected value for the quantile of the weighted data. |
| 683 | + """ |
| 684 | + da_weights = DataArray( |
| 685 | + [0.5, 1.0, 1.0, 2.0], dims=("a",), coords=dict(a=coords_weights) |
| 686 | + ) |
| 687 | + da_data = DataArray([1, 2, 3, 4], dims=("a",), coords=dict(a=coords_data)) |
| 688 | + check_weighted_operations(da_data, da_weights, dim="a", skipna=None) |
662 | 689 |
|
663 |
| - q = 0.5 |
664 |
| - result = data.weighted(weights).quantile(q, dim="a") |
665 |
| - # Expected value computed using code from https://aakinshin.net/posts/weighted-quantiles/ with values at a=1,2,3 |
666 |
| - expected = DataArray([0.9308707], coords={"quantile": [q]}).squeeze() |
667 |
| - assert_allclose(result, expected) |
| 690 | + quantile = 0.5 |
| 691 | + da_actual = da_data.weighted(da_weights).quantile(quantile, dim="a") |
| 692 | + da_expected = DataArray( |
| 693 | + [expected_value_at_weighted_quantile], coords={"quantile": [quantile]} |
| 694 | + ).squeeze() |
| 695 | + assert_allclose(da_actual, da_expected) |
668 | 696 |
|
669 |
| - data = data.to_dataset(name="data") |
670 |
| - check_weighted_operations(data, weights, dim="a", skipna=None) |
| 697 | + ds_data = da_data.to_dataset(name="data") |
| 698 | + check_weighted_operations(ds_data, da_weights, dim="a", skipna=None) |
671 | 699 |
|
672 |
| - result = data.weighted(weights).quantile(q, dim="a") |
673 |
| - assert_allclose(result, expected.to_dataset(name="data")) |
| 700 | + ds_actual = ds_data.weighted(da_weights).quantile(quantile, dim="a") |
| 701 | + assert_allclose(ds_actual, da_expected.to_dataset(name="data")) |
674 | 702 |
|
675 | 703 |
|
676 | 704 | @pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4)))
|
|
0 commit comments