Skip to content

Commit 131e6cb

Browse files
authored
use smallest int mask in rasterize (#96)
1 parent ab3f22d commit 131e6cb

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

xvec/zonal.py

Lines changed: 23 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import pandas as pd
99
import shapely
1010
import xarray as xr
11+
from xarray.groupers import UniqueGrouper
1112

1213

1314
def _agg_rasterize(groups, stats, **kwargs):
@@ -50,19 +51,32 @@ def _zonal_stats_rasterize(
5051
crs = None
5152

5253
transform = acc._obj.rio.transform()
54+
length = len(geometry)
55+
dtype = np.int16 if length < np.iinfo(np.int16).max else np.int32
5356

5457
labels = features.rasterize(
55-
zip(geometry, range(len(geometry)), strict=False),
58+
zip(geometry, range(length), strict=False),
5659
out_shape=(
5760
acc._obj[y_coords].shape[0],
5861
acc._obj[x_coords].shape[0],
5962
),
6063
transform=transform,
61-
fill=np.nan, # type: ignore
64+
fill=length, # type: ignore
6265
all_touched=all_touched,
63-
dtype=np.float32,
66+
dtype=dtype,
6467
)
65-
groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords)))
68+
69+
unique = np.unique(labels).tolist()
70+
unique.remove(length)
71+
72+
obj = acc._obj.copy()
73+
if isinstance(obj, xr.Dataset):
74+
obj = obj.assign_coords(
75+
__labels__=xr.DataArray(labels, dims=(y_coords, x_coords))
76+
)
77+
else:
78+
obj["__labels__"] = xr.DataArray(labels, dims=(y_coords, x_coords))
79+
groups = obj.groupby({"__labels__": UniqueGrouper(labels=unique)})
6680

6781
if pd.api.types.is_list_like(stats):
6882
agg = {}
@@ -89,10 +103,11 @@ def _zonal_stats_rasterize(
89103
raise ValueError(f"{stats} is not a valid aggregation.")
90104

91105
vec_cube = (
92-
agg_array.reindex(group=range(len(geometry)))
93-
.assign_coords(group=geometry)
94-
.rename(group=name)
95-
).xvec.set_geom_indexes(name, crs=crs)
106+
agg_array.reindex(__labels__=range(length))
107+
.assign_coords(__labels__=geometry)
108+
.rename(__labels__=name)
109+
.xvec.set_geom_indexes(name, crs=crs)
110+
)
96111

97112
del groups
98113
gc.collect()

0 commit comments

Comments
 (0)