Skip to content

ENH: allow masking nodata in zonal_stats #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,7 @@ def zonal_stats(
method: str = "rasterize",
all_touched: bool = False,
n_jobs: int = -1,
nodata: Any = None,
**kwargs: dict[str, Any],
) -> xr.DataArray | xr.Dataset:
"""Extract the values from a dataset indexed by a set of geometries
Expand Down Expand Up @@ -1062,6 +1063,9 @@ def zonal_stats(
Number of parallel threads to use. It is recommended to set this to the
number of physical cores of the CPU. ``-1`` uses all available cores.
Applies only if ``method="iterate"``.
nodata : Any
Value representing missing data. Cells equal to ``nodata`` are masked out
before the aggregation. If not specified, no masking is applied and all
values are included in the aggregation.
**kwargs : optional
Keyword arguments to be passed to the aggregation function
(e.g., ``Dataset.quantile(**kwargs)``).
Expand Down Expand Up @@ -1152,6 +1156,7 @@ def zonal_stats(
y_coords=y_coords,
stats=stats,
all_touched=all_touched,
nodata=nodata,
)

if method == "rasterize":
Expand All @@ -1163,6 +1168,7 @@ def zonal_stats(
stats=stats,
name=name,
all_touched=all_touched,
nodata=nodata,
**kwargs,
)
elif method == "iterate":
Expand All @@ -1175,6 +1181,7 @@ def zonal_stats(
name=name,
all_touched=all_touched,
n_jobs=n_jobs,
nodata=nodata,
**kwargs,
)
elif method == "exactextract":
Expand All @@ -1185,6 +1192,7 @@ def zonal_stats(
y_coords=y_coords,
stats=stats,
name=name,
nodata=nodata,
**kwargs,
)
else:
Expand Down
16 changes: 16 additions & 0 deletions xvec/tests/test_zonal_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,19 @@ def test_exactextract_strategy():
method="exactextract",
strategy="invalid_strategy",
)


@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
def test_nodata(method):
    """Check that passing ``nodata`` excludes the sentinel from the statistics.

    Cells that are not above the field mean are overwritten with ``-9999``.
    Aggregating without ``nodata`` keeps those sentinels in the result, so its
    overall mean has to be lower than the one computed with them masked out.
    """
    sentinel = -9999
    dataset = xr.tutorial.open_dataset("eraint_uvz")
    land = gpd.read_file(geodatasets.get_path("naturalearth land"))

    # Keep values above the mean, replace everything else with the sentinel.
    height = dataset.z
    filled = height.where(height > height.mean(), sentinel)

    # Both calls share the same positional arguments; only ``nodata`` differs.
    zonal_args = (land.geometry, "longitude", "latitude")
    with_sentinel = filled.xvec.zonal_stats(*zonal_args, method=method)
    without_sentinel = filled.xvec.zonal_stats(
        *zonal_args, method=method, nodata=sentinel
    )

    assert with_sentinel.mean() < without_sentinel.mean()
49 changes: 47 additions & 2 deletions xvec/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _zonal_stats_rasterize(
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: str = "geometry",
all_touched: bool = False,
nodata: Any = None,
**kwargs,
) -> xr.DataArray | xr.Dataset:
try:
Expand Down Expand Up @@ -70,6 +71,11 @@ def _zonal_stats_rasterize(
unique.remove(length)

obj = acc._obj.copy()

# mask out nodata - note that this casts whole array to float
if nodata is not None:
obj = obj.where(obj != nodata)

if isinstance(obj, xr.Dataset):
obj = obj.assign_coords(
__labels__=xr.DataArray(labels, dims=(y_coords, x_coords))
Expand Down Expand Up @@ -124,6 +130,7 @@ def _zonal_stats_iterative(
name: str = "geometry",
all_touched: bool = False,
n_jobs: int = -1,
nodata: Any = None,
**kwargs: dict[str, Any],
) -> xr.DataArray | xr.Dataset:
"""Extract the values from a dataset indexed by a set of geometries
Expand Down Expand Up @@ -158,6 +165,9 @@ def _zonal_stats_iterative(
n_jobs : int, optional
Number of parallel threads to use.
It is recommended to set this to the number of physical cores in the CPU.
nodata : Any
Value representing missing data. Cells equal to ``nodata`` are masked out
before the aggregation. If not specified, no masking is applied and all
values are included in the aggregation.
**kwargs : optional
Keyword arguments to be passed to the aggregation function
(as ``Dataset.mean(**kwargs)``).
Expand Down Expand Up @@ -198,6 +208,7 @@ def _zonal_stats_iterative(
y_coords,
stats=stats,
all_touched=all_touched,
nodata=nodata,
**kwargs,
)
for geom in geometry
Expand All @@ -224,6 +235,7 @@ def _agg_geom(
y_coords: str | None = None,
stats: str | Callable | Iterable[str | Callable | tuple] = "mean",
all_touched: bool = False,
nodata: Any = None,
**kwargs,
):
"""Aggregate the values from a dataset over a polygon geometry.
Expand All @@ -250,6 +262,9 @@ def _agg_geom(
If True, all pixels touched by geometries will be considered. If False, only
pixels whose center is within the polygon or that are selected by
Bresenham’s line algorithm will be considered.
nodata : Any
Value representing missing data. Cells equal to ``nodata`` are masked out
before the aggregation. If not specified, no masking is applied and all
values are included in the aggregation.

Returns
-------
Expand All @@ -270,6 +285,8 @@ def _agg_geom(
all_touched=all_touched,
)
masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
if nodata is not None:
masked = masked.where(masked != nodata)
if pd.api.types.is_list_like(stats):
agg = {}
for stat in stats: # type: ignore
Expand Down Expand Up @@ -309,6 +326,7 @@ def _zonal_stats_exactextract(
y_coords: Hashable,
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: str = "geometry",
nodata: Any = None,
**kwargs,
) -> xr.DataArray | xr.Dataset:
"""Extract the values from a dataset indexed by a set of geometries
Expand All @@ -334,6 +352,9 @@ def _zonal_stats_exactextract(
``"quantile(q=0.20)"``)
name : str, optional
Name of the dimension that will hold the ``geometry``, by default "geometry"
nodata : Any
Value representing missing data. Cells equal to ``nodata`` are masked out
before the aggregation. If not specified, no masking is applied and all
values are included in the aggregation.

Returns
-------
Expand Down Expand Up @@ -372,6 +393,7 @@ def _zonal_stats_exactextract(
stats,
name,
original_is_ds,
nodata=nodata,
**kwargs,
)
i = 0
Expand Down Expand Up @@ -410,6 +432,7 @@ def _zonal_stats_exactextract(
stats,
name,
original_is_ds,
nodata=nodata,
**kwargs,
)
# Unstack the result
Expand Down Expand Up @@ -447,6 +470,7 @@ def _agg_exactextract(
name: str = "geometry",
original_is_ds: bool = False,
strategy: str = "feature-sequential",
nodata: Any = None,
):
"""Extract the values from a dataset indexed by a set of geometries

Expand Down Expand Up @@ -476,6 +500,9 @@ def _agg_exactextract(
If True, all pixels touched by geometries will be considered. If False, only
pixels whose center is within the polygon or that are selected by
Bresenham’s line algorithm will be considered.
nodata : Any
Value representing missing data. Cells equal to ``nodata`` are masked out
before the aggregation. If not specified, no masking is applied and all
values are included in the aggregation.
strategy : str, optional
The strategy to use for the extraction, by default "feature-sequential"
Use either "feature-sequential" or "raster-sequential".
Expand Down Expand Up @@ -511,6 +538,10 @@ def _agg_exactextract(
# Check the order of dimensions
data = data.transpose("location", y_coords, x_coords)

# mask nodata
if nodata is not None:
data = data.where(data != nodata)

# Aggregation result
gdf = gpd.GeoDataFrame(geometry=geometry, crs=crs)
results = exactextract.exact_extract(
Expand All @@ -537,7 +568,16 @@ def _agg_exactextract(


def _get_mean(
geom_arr, obj, x_coords, y_coords, transform, all_touched, stats, dims, **kwargs
geom_arr,
obj,
x_coords,
y_coords,
transform,
all_touched,
stats,
dims,
nodata,
**kwargs,
):
from rasterio import features

Expand All @@ -552,6 +592,10 @@ def _get_mean(
all_touched=all_touched,
)
masked = obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))

if nodata is not None:
masked = masked.where(masked != nodata)

if pd.api.types.is_list_like(stats):
agg = {}
for stat in stats: # type: ignore
Expand Down Expand Up @@ -589,6 +633,7 @@ def _variable_zonal(
y_coords: Hashable,
stats="mean",
all_touched: bool = False,
nodata: Any = None,
):
transform = acc._obj.rio.transform()
dims = variable_geometry.dims
Expand All @@ -597,7 +642,7 @@ def _variable_zonal(

for x in stacked:
m = _get_mean(
x, acc._obj, x_coords, y_coords, transform, all_touched, stats, dims
x, acc._obj, x_coords, y_coords, transform, all_touched, stats, dims, nodata
)
m.name = "statistics"
r.append(m)
Expand Down