From 6a7832e5453db0a72743467673cf847261b95a41 Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Wed, 18 Jun 2025 17:03:52 +0200 Subject: [PATCH 1/3] ensure coords are always kept in rvs --- src/xarray_einstats/stats.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/xarray_einstats/stats.py b/src/xarray_einstats/stats.py index 2c58ae8..4b075aa 100644 --- a/src/xarray_einstats/stats.py +++ b/src/xarray_einstats/stats.py @@ -258,15 +258,15 @@ def rvs( output_core_dims=[output_core_dims], **apply_kwargs, ) - if not isinstance(out, xr.DataArray): - for elem in (*dist_args, *args, *dist_kwargs.values(), *kwargs.values()): - if isinstance(elem, xr.DataArray): - for name, values in elem.coords.items(): - if name in coords: - continue - if set(values.dims) < set(output_core_dims): - coords[name] = values + for elem in (*dist_args, *args, *dist_kwargs.values(), *kwargs.values()): + if isinstance(elem, xr.DataArray): + for name, values in elem.coords.items(): + if name in coords: + continue + if set(values.dims) < set(output_core_dims): + coords[name] = values + if not isinstance(out, xr.DataArray): return xr.DataArray(out, dims=output_core_dims, coords=coords) return out.assign_coords(coords) From d188d8c7bd61e6ec8184150ee215a8c585b59f68 Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Wed, 18 Jun 2025 17:22:55 +0200 Subject: [PATCH 2/3] improve tests --- tests/test_stats.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_stats.py b/tests/test_stats.py index a327531..7aaa547 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -126,6 +126,19 @@ def test_rv_method(self, data, wrapper, dim_names, in_type): assert_dims_in_da(out, dim_names) assert len(out[dim_names[0]]) == 2 assert len(out[dim_names[1]]) == 7 + assert all(np.all(out.coords[name] == coords) for name, coords in data.coords.items()) + + @pytest.mark.parametrize("in_type", ("scalar", "dataarray")) + def test_rv_method_coords(self, data, wrapper, in_type): + dist = get_dist_and_clean_method(wrapper, data, in_type=in_type) + coords = {"new_dim": ["true", "false"], "extra_dim": list("abcdefg")} + out = dist.rvs(coords=coords) + assert_dims_in_da(out, list(coords)) + assert len(out["new_dim"]) == 2 + assert len(out["extra_dim"]) == 7 + assert np.all(out.coords["new_dim"] == np.array(coords["new_dim"])) + assert np.all(out.coords["extra_dim"] == np.array(coords["extra_dim"])) + assert all(np.all(out.coords[name] == coord) for name, coord in data.coords.items()) @pytest.mark.parametrize("size", (1, 10)) @pytest.mark.parametrize("dims", (None, "name", ["name"])) From d0954b51b091ef3867f1ee7c2ad6f7c91b25188e Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Wed, 18 Jun 2025 17:43:28 +0200 Subject: [PATCH 3/3] update version and changelog --- docs/source/changelog.md | 4 ++++ src/xarray_einstats/__init__.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/changelog.md b/docs/source/changelog.md index 6f76cfc..c7fb1fc 100644 --- a/docs/source/changelog.md +++ b/docs/source/changelog.md @@ -1,5 +1,9 @@ # Change Log +## v0.9.1 (2025 June 18) +### Maintenance and fixes +* Ensure coords are preserved in `.rvs` method of RV wrappers {pull}`73` + ## v0.9.0 (2025 May 22) ### New features * Support PreliZ distributions as input to RV wrappers {pull}`70` diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 50f618b..90d868b 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -19,7 +19,7 @@ "EinopsAccessor", ] -__version__ = "0.9.0" +__version__ = "0.9.1" def sort(da, dim, kind=None, stable=None, **kwargs):