diff --git a/docs/source/changelog.md b/docs/source/changelog.md index 6f76cfc..c7fb1fc 100644 --- a/docs/source/changelog.md +++ b/docs/source/changelog.md @@ -1,5 +1,9 @@ # Change Log +## v0.9.1 (2025 June 18) +### Maintenance and fixes +* Ensure coords are preserved in `.rvs` method of RV wrappers {pull}`73` + ## v0.9.0 (2025 May 22) ### New features * Support PreliZ distributions as input to RV wrappers {pull}`70` diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 50f618b..90d868b 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -19,7 +19,7 @@ "EinopsAccessor", ] -__version__ = "0.9.0" +__version__ = "0.9.1" def sort(da, dim, kind=None, stable=None, **kwargs): diff --git a/src/xarray_einstats/stats.py b/src/xarray_einstats/stats.py index 2c58ae8..4b075aa 100644 --- a/src/xarray_einstats/stats.py +++ b/src/xarray_einstats/stats.py @@ -258,15 +258,15 @@ def rvs( output_core_dims=[output_core_dims], **apply_kwargs, ) - if not isinstance(out, xr.DataArray): - for elem in (*dist_args, *args, *dist_kwargs.values(), *kwargs.values()): - if isinstance(elem, xr.DataArray): - for name, values in elem.coords.items(): - if name in coords: - continue - if set(values.dims) < set(output_core_dims): - coords[name] = values + for elem in (*dist_args, *args, *dist_kwargs.values(), *kwargs.values()): + if isinstance(elem, xr.DataArray): + for name, values in elem.coords.items(): + if name in coords: + continue + if set(values.dims) < set(output_core_dims): + coords[name] = values + if not isinstance(out, xr.DataArray): return xr.DataArray(out, dims=output_core_dims, coords=coords) return out.assign_coords(coords) diff --git a/tests/test_stats.py b/tests/test_stats.py index a327531..7aaa547 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -126,6 +126,19 @@ def test_rv_method(self, data, wrapper, dim_names, in_type): assert_dims_in_da(out, dim_names) assert len(out[dim_names[0]]) == 2 assert len(out[dim_names[1]]) == 7 + assert all(np.all(out.coords[name] == coords) for name, coords in data.coords.items()) + + @pytest.mark.parametrize("in_type", ("scalar", "dataarray")) + def test_rv_method_coords(self, data, wrapper, in_type): + dist = get_dist_and_clean_method(wrapper, data, in_type=in_type) + coords = {"new_dim": ["true", "false"], "extra_dim": list("abcdefg")} + out = dist.rvs(coords=coords) + assert_dims_in_da(out, list(coords)) + assert len(out["new_dim"]) == 2 + assert len(out["extra_dim"]) == 7 + assert np.all(out.coords["new_dim"] == np.array(coords["new_dim"])) + assert np.all(out.coords["extra_dim"] == np.array(coords["extra_dim"])) + assert all(np.all(out.coords[name] == coord) for name, coord in data.coords.items()) @pytest.mark.parametrize("size", (1, 10)) @pytest.mark.parametrize("dims", (None, "name", ["name"]))