Skip to content

Commit 41e2c3c

Browse files
authored
exactextract improvements (#27)
1 parent fe7dae0 commit 41e2c3c

File tree

6 files changed

+312
-87
lines changed

6 files changed

+312
-87
lines changed

.github/workflows/test.yml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,34 @@ jobs:
3737
run: |
3838
python -m pip install --upgrade pip
3939
pip install hatch
40+
41+
# https://github.com/actions/cache/blob/main/tips-and-workarounds.md#update-a-cache
42+
- name: Restore cached hypothesis directory
43+
id: restore-hypothesis-cache
44+
uses: actions/cache/restore@v4
45+
with:
46+
path: .hypothesis/
47+
key: cache-hypothesis-${{ runner.os }}-${{ github.run_id }}
48+
restore-keys: |
49+
cache-hypothesis-
50+
4051
- name: Set Up Hatch Env
4152
run: |
4253
hatch env create test.py${{ matrix.python-version }}
4354
hatch env run -e test.py${{ matrix.python-version }} list-env
4455
- name: Run Tests
4556
run: |
4657
hatch env run --env test.py${{ matrix.python-version }} run-coverage
58+
59+
# explicitly save the cache so it gets updated, also do this even if it fails.
60+
- name: Save cached hypothesis directory
61+
id: save-hypothesis-cache
62+
if: always()
63+
uses: actions/cache/save@v4
64+
with:
65+
path: .hypothesis/
66+
key: cache-hypothesis-${{ runner.os }}-${{ github.run_id }}
67+
4768
# - name: Upload coverage
4869
# uses: codecov/codecov-action@v5
4970
# with:

pyproject.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,18 @@ rasterize = [
4040
"rasterio",
4141
"rioxarray",
4242
]
43-
exactextract = ["exactextract"]
43+
exactextract = ["exactextract", "sparse"]
4444
test = [
4545
"geodatasets",
46+
"pooch",
4647
"dask-geopandas",
4748
"odc-geo",
4849
"rasterio",
4950
"rioxarray",
5051
"exactextract",
51-
"netCDF4"
52+
"sparse",
53+
"netCDF4",
54+
"hypothesis",
5255
]
5356

5457
[tool.hatch]
@@ -71,15 +74,16 @@ dependencies = [
7174
"coverage",
7275
"pytest",
7376
"pytest-cov",
77+
"pytest-xdist"
7478
]
7579
features = ["test"]
7680

7781
[[tool.hatch.envs.test.matrix]]
7882
python = ["3.10", "3.13"]
7983

8084
[tool.hatch.envs.test.scripts]
81-
run-coverage = "pytest --cov-config=pyproject.toml --cov=pkg --cov-report xml --cov=src --junitxml=junit.xml -o junit_family=legacy"
82-
run-coverage-html = "pytest --cov-config=pyproject.toml --cov=pkg --cov-report html --cov=src"
85+
run-coverage = "pytest -nauto --cov-config=pyproject.toml --cov=pkg --cov-report xml --cov=src --junitxml=junit.xml -o junit_family=legacy"
86+
run-coverage-html = "pytest -nauto --cov-config=pyproject.toml --cov=pkg --cov-report html --cov=src"
8387
run-pytest = "run-coverage --no-cov"
8488
run-verbose = "run-coverage --verbose"
8589
run-mypy = "mypy src"

src/rasterix/rasterize/exact.py

Lines changed: 136 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import geopandas as gpd
77
import numpy as np
8+
import sparse
89
import xarray as xr
910
from exactextract import exact_extract
1011
from exactextract.raster import NumPyRasterSource
@@ -16,39 +17,23 @@
1617
import dask_geopandas
1718

1819
MIN_CHUNK_SIZE = 2 # exactextract cannot handle arrays of size 1.
20+
GEOM_AXIS = 0
21+
Y_AXIS = 1
22+
X_AXIS = 2
23+
24+
DEFAULT_STRATEGY = "feature-sequential"
25+
Strategy = Literal["feature-sequential", "raster-sequential", "raster-parallel"]
26+
CoverageWeights = Literal["area_spherical_m2", "area_cartesian", "area_spherical_km2", "fraction", "none"]
1927

2028
__all__ = [
2129
"coverage",
2230
]
2331

2432

25-
def get_dtype(coverage_weight, geometries):
26-
if coverage_weight.lower() == "fraction":
27-
dtype = "float64"
28-
elif coverage_weight.lower() == "none":
29-
dtype = np.min_scalar_type(len(geometries))
30-
else:
31-
raise NotImplementedError
32-
return dtype
33-
34-
35-
def np_coverage(
36-
x: np.ndarray,
37-
y: np.ndarray,
38-
*,
39-
geometries: gpd.GeoDataFrame,
40-
coverage_weight: Literal["fraction", "none"] = "fraction",
41-
) -> np.ndarray[Any, Any]:
42-
"""
43-
Parameters
44-
----------
45-
46-
"""
33+
def xy_to_raster_source(x: np.ndarray, y: np.ndarray, *, srs_wkt: str | None) -> NumPyRasterSource:
4734
assert x.ndim == 1
4835
assert y.ndim == 1
4936

50-
dtype = get_dtype(coverage_weight, geometries)
51-
5237
xsize = x.size
5338
ysize = y.size
5439

@@ -67,32 +52,90 @@ def np_coverage(
6752
xmax=x.max() + dx1,
6853
ymin=y.min() - dy0,
6954
ymax=y.max() + dy1,
70-
srs_wkt=geometries.crs.to_wkt(),
55+
srs_wkt=srs_wkt,
7156
)
57+
58+
return raster
59+
60+
61+
def get_dtype(coverage_weight: CoverageWeights, geometries):
62+
if coverage_weight.lower() == "none":
63+
dtype = np.uint8
64+
else:
65+
dtype = np.float64
66+
return dtype
67+
68+
69+
def np_coverage(
70+
x: np.ndarray,
71+
y: np.ndarray,
72+
*,
73+
geometries: gpd.GeoDataFrame,
74+
strategy: Strategy = DEFAULT_STRATEGY,
75+
coverage_weight: CoverageWeights = "fraction",
76+
) -> np.ndarray[Any, Any]:
77+
"""
78+
Parameters
79+
----------
80+
81+
"""
82+
dtype = get_dtype(coverage_weight, geometries)
83+
84+
if len(geometries.columns) > 1:
85+
raise ValueError("Require a single geometries column or a GeoSeries.")
86+
87+
shape = (y.size, x.size)
88+
raster = xy_to_raster_source(x, y, srs_wkt=geometries.crs.to_wkt())
7289
result = exact_extract(
7390
rast=raster,
7491
vec=geometries,
7592
ops=["cell_id", f"coverage(coverage_weight={coverage_weight})"],
7693
output="pandas",
7794
# max_cells_in_memory=2*x.size * y.size
7895
)
79-
out = np.zeros((len(geometries), *shape), dtype=dtype)
80-
# TODO: vectorized assignment?
96+
97+
lens = np.vectorize(len)(result.cell_id.values)
98+
nnz = np.sum(lens)
99+
100+
# Notes on GCXS vs COO, For N data points in 263 geoms by 4000 x by 4000 y
101+
# 1. GCXS cannot compress _all_ axes. This is relevant here.
102+
# 2. GCXS: indptr is 4000*4000 + 1, N per indices & N per data
103+
# 3. COO: 4*N
104+
# It is not obvious that there is much improvement to GCXS at least currently
105+
geom_idxs = np.empty((nnz,), dtype=np.int64)
106+
xy_idxs = np.empty((nnz,), dtype=np.int64)
107+
data = np.empty((nnz,), dtype=dtype)
108+
109+
off = 0
81110
for i in range(len(geometries)):
82-
res = result.loc[i]
83-
# indices = np.unravel_index(res.cell_id, shape=shape)
84-
# out[(i, *indices)] = offset + i + 1 # 0 is the fill value
85-
out[i, ...].flat[res.cell_id] = res.coverage
86-
return out
111+
cell_id = result.cell_id.values[i]
112+
if cell_id.size == 0:
113+
continue
114+
geom_idxs[off : off + cell_id.size] = i
115+
xy_idxs[off : off + cell_id.size] = cell_id
116+
data[off : off + cell_id.size] = result.coverage.values[i]
117+
off += cell_id.size
118+
return sparse.COO(
119+
(geom_idxs, *np.unravel_index(xy_idxs, shape=shape)),
120+
data=data,
121+
sorted=True,
122+
fill_value=0,
123+
shape=(len(geometries), *shape),
124+
)
87125

88126

89127
def coverage_np_dask_wrapper(
90-
x: np.ndarray, y: np.ndarray, geom_array: np.ndarray, coverage_weight, crs
128+
geom_array: np.ndarray,
129+
x: np.ndarray,
130+
y: np.ndarray,
131+
coverage_weight: CoverageWeights,
132+
strategy: Strategy,
133+
crs,
91134
) -> np.ndarray:
92135
return np_coverage(
93-
x=x,
94-
y=y,
95-
geometries=gpd.GeoDataFrame(geometry=geom_array, crs=crs),
136+
x=x.squeeze(axis=(GEOM_AXIS, Y_AXIS)),
137+
y=y.squeeze(axis=(GEOM_AXIS, X_AXIS)),
138+
geometries=gpd.GeoDataFrame(geometry=geom_array.squeeze(axis=(X_AXIS, Y_AXIS)), crs=crs),
96139
coverage_weight=coverage_weight,
97140
)
98141

@@ -102,27 +145,29 @@ def dask_coverage(
102145
y: dask.array.Array,
103146
*,
104147
geom_array: dask.array.Array,
105-
coverage_weight: Literal["fraction", "none"] = "fraction",
148+
coverage_weight: CoverageWeights = "fraction",
149+
strategy: Strategy = DEFAULT_STRATEGY,
106150
crs: Any,
107151
) -> dask.array.Array:
108152
import dask.array
109153

110-
if any(c == 1 for c in x.chunks) or any(c == 1 for c in y.chunks):
154+
if any(c == 1 for c in x.chunks[0]) or any(c == 1 for c in y.chunks[0]):
111155
raise ValueError("exactextract does not support a chunksize of 1. Please rechunk to avoid this")
112156

113-
return dask.array.blockwise(
157+
out = dask.array.map_blocks(
114158
coverage_np_dask_wrapper,
115-
"gji",
116-
x,
117-
"i",
118-
y,
119-
"j",
120-
geom_array,
121-
"g",
159+
geom_array[:, np.newaxis, np.newaxis],
160+
x[np.newaxis, np.newaxis, :],
161+
y[np.newaxis, :, np.newaxis],
122162
crs=crs,
123163
coverage_weight=coverage_weight,
124-
dtype=get_dtype(coverage_weight, geom_array),
164+
strategy=strategy,
165+
chunks=(*geom_array.chunks, *y.chunks, *x.chunks),
166+
meta=sparse.COO(
167+
[], data=np.array([], dtype=get_dtype(coverage_weight, geom_array)), shape=(0, 0, 0), fill_value=0
168+
),
125169
)
170+
return out
126171

127172

128173
def coverage(
@@ -131,7 +176,8 @@ def coverage(
131176
*,
132177
xdim="x",
133178
ydim="y",
134-
coverage_weight="fraction",
179+
strategy: Strategy = "feature-sequential",
180+
coverage_weight: CoverageWeights = "fraction",
135181
) -> xr.DataArray:
136182
"""
137183
Returns "coverage" fractions for each pixel for each geometry calculated using exactextract.
@@ -163,35 +209,62 @@ def coverage(
163209
y=obj[ydim].data,
164210
geometries=geometries,
165211
coverage_weight=coverage_weight,
212+
strategy=strategy,
166213
)
167214
geom_array = geometries.to_numpy().squeeze(axis=1)
168215
else:
169-
from dask.array import from_array
216+
from dask.array import Array, from_array
170217

171218
geom_dask_array = geometries_as_dask_array(geometries)
219+
if not isinstance(obj[xdim].data, Array):
220+
dask_x = from_array(obj[xdim].data, chunks=obj.chunksizes.get(xdim, -1))
221+
else:
222+
dask_x = obj[xdim].data
223+
224+
if not isinstance(obj[ydim].data, Array):
225+
dask_y = from_array(obj[ydim].data, chunks=obj.chunksizes.get(ydim, -1))
226+
else:
227+
dask_y = obj[ydim].data
228+
172229
out = dask_coverage(
173-
x=from_array(obj[xdim].data, chunks=obj.chunksizes.get(xdim, -1)),
174-
y=from_array(obj[ydim].data, chunks=obj.chunksizes.get(ydim, -1)),
230+
x=dask_x,
231+
y=dask_y,
175232
geom_array=geom_dask_array,
176233
crs=geometries.crs,
177234
coverage_weight=coverage_weight,
235+
strategy=strategy,
178236
)
179237
if isinstance(geometries, gpd.GeoDataFrame):
180238
geom_array = geometries.to_numpy().squeeze(axis=1)
181239
else:
182240
geom_array = geom_dask_array
183241

184-
coverage = xr.DataArray(
185-
dims=("geometry", ydim, xdim),
186-
data=out,
187-
coords=xr.Coordinates(
188-
coords={
189-
xdim: obj.coords[xdim],
190-
ydim: obj.coords[ydim],
191-
"spatial_ref": obj.spatial_ref,
192-
"geometry": geom_array,
193-
},
194-
indexes={xdim: obj.xindexes[xdim], ydim: obj.xindexes[ydim]},
195-
),
242+
name = "coverage"
243+
attrs = {}
244+
if "area" in coverage_weight:
245+
name = "area"
246+
if "_m2" in coverage_weight or coverage_weight == "area_cartesian":
247+
attrs["long_name"] = coverage_weight.removesuffix("_m2")
248+
attrs["units"] = "m2"
249+
elif "_km2" in coverage_weight:
250+
attrs["long_name"] = coverage_weight.removesuffix("_km2")
251+
attrs["units"] = "km2"
252+
253+
xy_coords = [
254+
xr.Coordinates.from_xindex(obj.xindexes.get(dim))
255+
for dim in (xdim, ydim)
256+
if obj.xindexes.get(dim) is not None
257+
]
258+
coords = xr.Coordinates(
259+
coords={
260+
"spatial_ref": obj.spatial_ref,
261+
"geometry": geom_array,
262+
},
263+
indexes={},
196264
)
265+
if xy_coords:
266+
for c in xy_coords:
267+
coords = coords.merge(c)
268+
coords = coords.coords
269+
coverage = xr.DataArray(dims=("geometry", ydim, xdim), data=out, coords=coords, attrs=attrs, name=name)
197270
return coverage

src/rasterix/rasterize/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import uuid
34
from typing import TYPE_CHECKING
45

56
import geopandas as gpd
@@ -33,7 +34,13 @@ def geometries_as_dask_array(
3334
from dask.array import from_array
3435

3536
if isinstance(geometries, gpd.GeoDataFrame):
36-
return from_array(geometries.geometry.to_numpy(), chunks=-1)
37+
return from_array(
38+
geometries.geometry.to_numpy(),
39+
chunks=-1,
40+
# This is what dask-geopandas does
41+
# It avoids pickling geometries, which can be expensive (calls to_wkb)
42+
name=uuid.uuid4().hex,
43+
)
3744
else:
3845
divisions = geometries.divisions
3946
if any(d is None for d in divisions):

0 commit comments

Comments
 (0)