Skip to content

Commit a545cf6

Browse files
authored
ensure coords are always kept in rvs (#73)
* ensure coords are always kept in rvs * improve tests * update version and changelog
1 parent cf51293 commit a545cf6

File tree

4 files changed

+26
-9
lines changed

4 files changed

+26
-9
lines changed

docs/source/changelog.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Change Log
22

3+
## v0.9.1 (2025 June 18)
4+
### Maintenance and fixes
5+
* Ensure coords are preserved in `.rvs` method of RV wrappers {pull}`73`
6+
37
## v0.9.0 (2025 May 22)
48
### New features
59
* Support PreliZ distributions as input to RV wrappers {pull}`70`

src/xarray_einstats/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"EinopsAccessor",
2020
]
2121

22-
__version__ = "0.9.0"
22+
__version__ = "0.9.1"
2323

2424

2525
def sort(da, dim, kind=None, stable=None, **kwargs):

src/xarray_einstats/stats.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -258,15 +258,15 @@ def rvs(
258258
output_core_dims=[output_core_dims],
259259
**apply_kwargs,
260260
)
261-
if not isinstance(out, xr.DataArray):
262-
for elem in (*dist_args, *args, *dist_kwargs.values(), *kwargs.values()):
263-
if isinstance(elem, xr.DataArray):
264-
for name, values in elem.coords.items():
265-
if name in coords:
266-
continue
267-
if set(values.dims) < set(output_core_dims):
268-
coords[name] = values
261+
for elem in (*dist_args, *args, *dist_kwargs.values(), *kwargs.values()):
262+
if isinstance(elem, xr.DataArray):
263+
for name, values in elem.coords.items():
264+
if name in coords:
265+
continue
266+
if set(values.dims) < set(output_core_dims):
267+
coords[name] = values
269268

269+
if not isinstance(out, xr.DataArray):
270270
return xr.DataArray(out, dims=output_core_dims, coords=coords)
271271
return out.assign_coords(coords)
272272

tests/test_stats.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,19 @@ def test_rv_method(self, data, wrapper, dim_names, in_type):
126126
assert_dims_in_da(out, dim_names)
127127
assert len(out[dim_names[0]]) == 2
128128
assert len(out[dim_names[1]]) == 7
129+
assert all(np.all(out.coords[name] == coords) for name, coords in data.coords.items())
130+
131+
@pytest.mark.parametrize("in_type", ("scalar", "dataarray"))
132+
def test_rv_method_coords(self, data, wrapper, in_type):
133+
dist = get_dist_and_clean_method(wrapper, data, in_type=in_type)
134+
coords = {"new_dim": ["true", "false"], "extra_dim": list("abcdefg")}
135+
out = dist.rvs(coords=coords)
136+
assert_dims_in_da(out, list(coords))
137+
assert len(out["new_dim"]) == 2
138+
assert len(out["extra_dim"]) == 7
139+
assert np.all(out.coords["new_dim"] == np.array(coords["new_dim"]))
140+
assert np.all(out.coords["extra_dim"] == np.array(coords["extra_dim"]))
141+
assert all(np.all(out.coords[name] == coord) for name, coord in data.coords.items())
129142

130143
@pytest.mark.parametrize("size", (1, 10))
131144
@pytest.mark.parametrize("dims", (None, "name", ["name"]))

0 commit comments

Comments
 (0)