Skip to content

Commit bb076e3

Browse files
gcariaGiacomo Caria
andauthored
Fix rolling mean on bool arrays (#10319)
* set count's dtype to int * add test for the MR, and fix failing test * remove comments --------- Co-authored-by: Giacomo Caria <giacomo@chloris.earth>
1 parent bb495f8 commit bb076e3

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

xarray/computation/rolling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,13 @@ def method(self, keep_attrs=None, **kwargs):
196196

197197
def _mean(self, keep_attrs, **kwargs):
198198
result = self.sum(keep_attrs=False, **kwargs) / duck_array_ops.astype(
199-
self.count(keep_attrs=False), dtype=self.obj.dtype, copy=False
199+
self.count(keep_attrs=False), dtype=int, copy=False
200200
)
201201
if keep_attrs:
202202
result.attrs = self.obj.attrs
203+
204+
if self.obj.dtype.kind not in "bi":
205+
result = result.astype(self.obj.dtype, copy=False)
203206
return result
204207

205208
_mean.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="mean")

xarray/tests/test_rolling.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,20 @@ def test_rolling_dask_dtype(self, dtype) -> None:
433433
chunked_result = data.chunk({"x": 1}).rolling(x=3, min_periods=1).mean()
434434
assert chunked_result.dtype == unchunked_result.dtype
435435

436+
def test_rolling_mean_bool(self) -> None:
437+
bool_raster = DataArray(
438+
data=[0, 1, 1, 0, 1, 0],
439+
dims=("x"),
440+
).astype(bool)
441+
442+
expected = DataArray(
443+
data=[np.nan, 2 / 3, 2 / 3, 2 / 3, 1 / 3, np.nan],
444+
dims=("x"),
445+
)
446+
447+
result = bool_raster.rolling(x=3, center=True).mean()
448+
assert_allclose(result, expected)
449+
436450

437451
@requires_numbagg
438452
class TestDataArrayRollingExp:

0 commit comments

Comments
 (0)