Skip to content

Commit 20ed76a

Browse files
authored
Rasterize: Remove odc-geo dependency (#37)
1 parent 8d02884 commit 20ed76a

File tree

3 files changed

+127
-103
lines changed

3 files changed

+127
-103
lines changed

pyproject.toml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,12 @@ dynamic=["version"]
3535

3636
[project.optional-dependencies]
3737
dask = ["dask-geopandas"]
38-
rasterize = [
39-
"odc-geo",
40-
"rasterio",
41-
"rioxarray",
42-
]
38+
rasterize = ["rasterio"]
4339
exactextract = ["exactextract", "sparse"]
4440
test = [
4541
"geodatasets",
4642
"pooch",
4743
"dask-geopandas",
48-
"odc-geo",
4944
"rasterio",
5045
"rioxarray",
5146
"exactextract",

src/rasterix/rasterize/rasterio.py

Lines changed: 96 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,76 @@
11
# rasterio wrappers
22
from __future__ import annotations
33

4-
from collections.abc import Sequence
4+
import functools
5+
from collections.abc import Callable, Sequence
56
from functools import partial
6-
from typing import TYPE_CHECKING, Any
7+
from typing import TYPE_CHECKING, Any, TypeVar
78

89
import geopandas as gpd
910
import numpy as np
10-
import odc.geo.xr # noqa
1111
import rasterio as rio
1212
import xarray as xr
13+
from affine import Affine
1314
from rasterio.features import MergeAlg, geometry_mask
1415
from rasterio.features import rasterize as rasterize_rio
1516

16-
from .utils import is_in_memory, prepare_for_dask
17+
from .utils import XAXIS, YAXIS, get_affine, is_in_memory, prepare_for_dask
18+
19+
F = TypeVar("F", bound=Callable[..., Any])
1720

1821
if TYPE_CHECKING:
1922
import dask_geopandas
2023

2124

25+
def with_rio_env(func: F) -> F:
    """
    Decorator that handles the 'env' and 'clear_cache' kwargs.

    The returned wrapper removes ``env`` and ``clear_cache`` from the keyword
    arguments, so the decorated function never receives them. The decorated
    function is called inside ``env``; when ``env`` is missing or ``None`` a
    fresh ``rio.Env()`` is used instead. If ``clear_cache`` is truthy, a
    ``rio.Env(GDAL_CACHEMAX=0)`` context is entered and left again after the
    call.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Context management lives here, so both options are stripped from
        # kwargs before the wrapped function is invoked.
        rio_env = kwargs.pop("env", None)
        flush_cache = kwargs.pop("clear_cache", False)

        with rio.Env() if rio_env is None else rio_env:
            outcome = func(*args, **kwargs)

        if not flush_cache:
            return outcome

        with rio.Env(GDAL_CACHEMAX=0):
            # attempt to force-clear the GDAL cache
            pass
        return outcome

    return wrapper
51+
52+
2253
def dask_rasterize_wrapper(
2354
geom_array: np.ndarray,
24-
tile_array: np.ndarray,
55+
x_offsets: np.ndarray,
56+
y_offsets: np.ndarray,
57+
x_sizes: np.ndarray,
58+
y_sizes: np.ndarray,
2559
offset_array: np.ndarray,
2660
*,
2761
fill: Any,
62+
affine: Affine,
2863
all_touched: bool,
2964
merge_alg: MergeAlg,
3065
dtype_: np.dtype,
3166
env: rio.Env | None = None,
3267
) -> np.ndarray:
33-
tile = tile_array.item()
3468
offset = offset_array.item()
3569

3670
return rasterize_geometries(
3771
geom_array[:, 0, 0].tolist(),
38-
tile=tile,
72+
affine=affine * affine.translation(x_offsets.item(), y_offsets.item()),
73+
shape=(y_sizes.item(), x_sizes.item()),
3974
offset=offset,
4075
all_touched=all_touched,
4176
merge_alg=merge_alg,
@@ -45,44 +80,25 @@ def dask_rasterize_wrapper(
4580
)[np.newaxis, :, :]
4681

4782

83+
@with_rio_env
def rasterize_geometries(
    geometries: Sequence[Any],
    *,
    dtype: np.dtype,
    shape: tuple[int, int],
    affine: Affine,
    offset: int,
    env: rio.Env | None = None,
    clear_cache: bool = False,
    **kwargs,
):
    """
    Burn ``geometries`` into an array of the given ``shape``.

    Each geometry is paired with an integer label: the first one gets
    ``offset``, the next ``offset + 1``, and so on.

    Parameters
    ----------
    geometries : Sequence
        Geometries to rasterize.
    dtype : np.dtype
        Accepted as a keyword, but not forwarded to rasterio by this function.
        NOTE(review): confirm whether the output dtype is meant to be
        enforced here.
    shape : tuple of int
        Passed to rasterio as ``out_shape``; the result must have this shape.
    affine : Affine
        Passed to rasterio as ``transform``.
    offset : int
        Label assigned to the first geometry.
    env, clear_cache
        Consumed by the ``with_rio_env`` decorator before this body runs.
    **kwargs
        Forwarded unchanged to ``rasterio.features.rasterize``.

    Notes
    -----
    The rasterio documentation for ``rasterize`` states that performance
    depends on ``GDAL_CACHEMAX`` being larger than the output array; a custom
    ``env`` can be used to control that setting.
    """
    labels = range(offset, offset + len(geometries))
    burned = rasterize_rio(
        zip(geometries, labels, strict=True),
        out_shape=shape,
        transform=affine,
        **kwargs,
    )
    # Internal sanity check on what rasterio returned.
    assert burned.shape == shape
    return burned
87103

88104

@@ -129,25 +145,30 @@ def rasterize(
129145
"""
130146
if xdim not in obj.dims or ydim not in obj.dims:
131147
raise ValueError(f"Received {xdim=!r}, {ydim=!r} but obj.dims={tuple(obj.dims)}")
132-
box = obj.odc.geobox
133-
rasterize_kwargs = dict(all_touched=all_touched, merge_alg=merge_alg)
148+
149+
rasterize_kwargs = dict(
150+
all_touched=all_touched, merge_alg=merge_alg, affine=get_affine(obj, xdim=xdim, ydim=ydim), env=env
151+
)
134152
# FIXME: box.crs == geometries.crs
135153
if is_in_memory(obj=obj, geometries=geometries):
136154
geom_array = geometries.to_numpy().squeeze(axis=1)
137155
rasterized = rasterize_geometries(
138156
geom_array.tolist(),
139-
tile=box,
157+
shape=(obj.sizes[ydim], obj.sizes[xdim]),
140158
offset=0,
141159
dtype=np.min_scalar_type(len(geometries)),
142160
fill=len(geometries),
143-
env=env,
144161
**rasterize_kwargs,
145162
)
146163
else:
147164
from dask.array import from_array, map_blocks
148165

149-
chunks, tiles_array, geom_array = prepare_for_dask(
150-
obj, geometries, xdim=xdim, ydim=ydim, geoms_rechunk_size=geoms_rechunk_size
166+
map_blocks_args, chunks, geom_array = prepare_for_dask(
167+
obj,
168+
geometries,
169+
xdim=xdim,
170+
ydim=ydim,
171+
geoms_rechunk_size=geoms_rechunk_size,
151172
)
152173
# DaskGeoDataFrame.len() computes!
153174
num_geoms = geom_array.size
@@ -159,10 +180,9 @@ def rasterize(
159180

160181
rasterized = map_blocks(
161182
dask_rasterize_wrapper,
162-
geom_array[:, np.newaxis, np.newaxis],
163-
tiles_array[np.newaxis, :, :],
183+
*map_blocks_args,
164184
offsets[:, np.newaxis, np.newaxis],
165-
chunks=((1,) * geom_array.numblocks[0], chunks[0], chunks[1]),
185+
chunks=((1,) * geom_array.numblocks[0], chunks[YAXIS], chunks[XAXIS]),
166186
meta=np.array([], dtype=dtype),
167187
fill=0, # good identity value for both sum & replace.
168188
**rasterize_kwargs,
@@ -205,54 +225,39 @@ def replace_values(array: np.ndarray, to, *, from_=0) -> np.ndarray:
205225

206226
def dask_mask_wrapper(
    geom_array: np.ndarray,
    x_offsets: np.ndarray,
    y_offsets: np.ndarray,
    x_sizes: np.ndarray,
    y_sizes: np.ndarray,
    *,
    affine: Affine,
    all_touched: bool,
    invert: bool,
    env: rio.Env | None = None,
) -> np.ndarray[Any, np.dtype[np.bool_]]:
    """
    Compute the geometry mask for a single block.

    Parameters
    ----------
    geom_array : np.ndarray
        Block of geometries; it is indexed as ``geom_array[:, 0, 0]``, i.e. the
        geometries run along the first axis.
    x_offsets, y_offsets : np.ndarray
        Single-element arrays holding this block's offset along x and y.
    x_sizes, y_sizes : np.ndarray
        Single-element arrays holding this block's size along x and y.
    affine : Affine
        Transform of the full (un-chunked) grid.
    all_touched, invert : bool
        Forwarded to ``np_geometry_mask``.
    env : rio.Env, optional
        Forwarded to ``np_geometry_mask``.

    Returns
    -------
    np.ndarray
        Boolean mask with a new leading axis of length one.
    """
    x_offset = x_offsets.item()
    y_offset = y_offsets.item()

    return np_geometry_mask(
        geom_array[:, 0, 0].tolist(),
        # np_geometry_mask declares these keyword-only parameters but does not
        # read them in its body; the block offset is already folded into
        # ``affine`` below.
        x_offset=x_offset,
        y_offset=y_offset,
        shape=(y_sizes.item(), x_sizes.item()),
        # Shift the full-grid transform so it starts at this block's corner.
        affine=affine * Affine.translation(x_offset, y_offset),
        # Previously dropped from this call, so the dask path ignored it.
        all_touched=all_touched,
        invert=invert,
        env=env,
    )[np.newaxis, :, :]
223245

224246

247+
@with_rio_env
def np_geometry_mask(
    geometries: Sequence[Any],
    *,
    x_offset: int = 0,
    y_offset: int = 0,
    shape: tuple[int, int],
    affine: Affine,
    env: rio.Env | None = None,
    clear_cache: bool = False,
    **kwargs,
) -> np.ndarray[Any, np.dtype[np.bool_]]:
    """
    Build a boolean geometry mask of the given ``shape``.

    Parameters
    ----------
    geometries : Sequence
        Geometries passed to ``rasterio.features.geometry_mask``.
    x_offset, y_offset : int, optional
        Accepted for signature compatibility only; they are not used here.
        ``affine`` must already describe the exact window being masked. They
        default to 0 because no caller in this module supplies them, and a
        required keyword would make every call raise ``TypeError``.
    shape : tuple of int
        Passed to rasterio as ``out_shape``; the result must have this shape.
    affine : Affine
        Passed to rasterio as ``transform``.
    env, clear_cache
        Consumed by the ``with_rio_env`` decorator before this body runs.
    **kwargs
        Forwarded unchanged to ``geometry_mask`` (e.g. ``all_touched``,
        ``invert``).

    Returns
    -------
    np.ndarray
        The boolean mask returned by ``geometry_mask``.
    """
    res = geometry_mask(geometries, out_shape=shape, transform=affine, **kwargs)
    # Internal sanity check on what rasterio returned.
    assert res.shape == shape
    return res
257262

258263

@@ -298,23 +303,31 @@ def geometry_clip(
298303
invert = not invert # rioxarray clip convention -> rasterio geometry_mask convention
299304
if xdim not in obj.dims or ydim not in obj.dims:
300305
raise ValueError(f"Received {xdim=!r}, {ydim=!r} but obj.dims={tuple(obj.dims)}")
301-
box = obj.odc.geobox
302-
geometry_mask_kwargs = dict(all_touched=all_touched, invert=invert)
306+
geometry_mask_kwargs = dict(
307+
all_touched=all_touched, invert=invert, affine=get_affine(obj, xdim=xdim, ydim=ydim), env=env
308+
)
303309

304310
if is_in_memory(obj=obj, geometries=geometries):
305311
geom_array = geometries.to_numpy().squeeze(axis=1)
306-
mask = np_geometry_mask(geom_array.tolist(), tile=box, env=env, **geometry_mask_kwargs)
312+
mask = np_geometry_mask(
313+
geom_array.tolist(),
314+
shape=(obj.sizes[ydim], obj.sizes[xdim]),
315+
**geometry_mask_kwargs,
316+
)
307317
else:
308318
from dask.array import map_blocks
309319

310-
chunks, tiles_array, geom_array = prepare_for_dask(
311-
obj, geometries, xdim=xdim, ydim=ydim, geoms_rechunk_size=geoms_rechunk_size
320+
map_blocks_args, chunks, geom_array = prepare_for_dask(
321+
obj,
322+
geometries,
323+
xdim=xdim,
324+
ydim=ydim,
325+
geoms_rechunk_size=geoms_rechunk_size,
312326
)
313327
mask = map_blocks(
314328
dask_mask_wrapper,
315-
geom_array[:, np.newaxis, np.newaxis],
316-
tiles_array[np.newaxis, :, :],
317-
chunks=((1,) * geom_array.numblocks[0], chunks[0], chunks[1]),
329+
*map_blocks_args,
330+
chunks=((1,) * geom_array.numblocks[0], chunks[YAXIS], chunks[XAXIS]),
318331
meta=np.array([], dtype=bool),
319332
**geometry_mask_kwargs,
320333
)

src/rasterix/rasterize/utils.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,29 @@
66
import geopandas as gpd
77
import numpy as np
88
import xarray as xr
9-
from odc.geo.geobox import GeoboxTiles
9+
from affine import Affine
1010

1111
if TYPE_CHECKING:
1212
import dask.array
1313
import dask_geopandas
1414

1515

16-
def tiles_to_array(tiles: GeoboxTiles) -> np.ndarray:
17-
shape = tiles.shape
18-
array = np.empty(shape=(shape.y, shape.x), dtype=object)
19-
for i in range(shape.x):
20-
for j in range(shape.y):
21-
array[j, i] = tiles[j, i]
16+
YAXIS = 0
17+
XAXIS = 1
2218

23-
assert array.shape == tiles.shape
24-
return array
19+
20+
def get_affine(obj: xr.Dataset | xr.DataArray, *, xdim="x", ydim="y") -> Affine:
    """
    Return the affine transform describing the grid of ``obj``.

    If ``obj`` has a ``spatial_ref`` coordinate carrying a ``GeoTransform``
    attribute, that whitespace-separated GDAL geotransform is parsed.
    Otherwise the transform is derived from the first two values of the
    ``xdim`` and ``ydim`` coordinates, which are treated as cell centres with
    constant spacing.

    Parameters
    ----------
    obj : xr.Dataset or xr.DataArray
        Object whose grid is described.
    xdim, ydim : str
        Names of the x and y coordinates used by the fallback.

    Returns
    -------
    Affine
        Transform whose origin is the outer corner of the first cell, i.e.
        the cell at index 0 along both ``ydim`` and ``xdim``.

    Raises
    ------
    IndexError
        In the fallback, if a coordinate has fewer than two values.
    """
    if "spatial_ref" in obj.coords:
        geotransform = obj.coords["spatial_ref"].attrs.get("GeoTransform")
        if geotransform is not None:
            # split() with no argument tolerates repeated whitespace.
            return Affine.from_gdal(*map(float, geotransform.split()))

    x = obj.coords[xdim]
    y = obj.coords[ydim]
    # Work with plain floats so the Affine does not hold 0-d array objects.
    x0, y0 = float(x[0]), float(y[0])
    dx = float(x[1]) - x0
    dy = float(y[1]) - y0
    # The first cell centre sits half a step inside the origin along each
    # axis. This holds for ascending and descending coordinates alike, because
    # the sign of the step is carried by dx / dy themselves.
    return Affine.translation(x0 - dx / 2, y0 - dy / 2) * Affine.scale(dx, dy)
2532

2633

2734
def is_in_memory(*, obj, geometries) -> bool:
@@ -61,15 +68,24 @@ def prepare_for_dask(
6168
):
6269
from dask.array import from_array
6370

64-
box = obj.odc.geobox
65-
6671
chunks = (
6772
obj.chunksizes.get(ydim, obj.sizes[ydim]),
6873
obj.chunksizes.get(xdim, obj.sizes[ydim]),
6974
)
70-
tiles = GeoboxTiles(box, tile_shape=chunks)
71-
tiles_array = from_array(tiles_to_array(tiles), chunks=(1, 1))
7275
geom_array = geometries_as_dask_array(geometries)
7376
if geoms_rechunk_size is not None:
7477
geom_array = geom_array.rechunk({0: geoms_rechunk_size})
75-
return chunks, tiles_array, geom_array
78+
79+
x_sizes = from_array(chunks[XAXIS], chunks=1)
80+
y_sizes = from_array(chunks[YAXIS], chunks=1)
81+
y_offsets = from_array(np.cumulative_sum(chunks[YAXIS][:-1], include_initial=True), chunks=1)
82+
x_offsets = from_array(np.cumulative_sum(chunks[XAXIS][:-1], include_initial=True), chunks=1)
83+
84+
map_blocks_args = (
85+
geom_array[:, np.newaxis, np.newaxis],
86+
x_offsets[np.newaxis, np.newaxis, :],
87+
y_offsets[np.newaxis, :, np.newaxis],
88+
x_sizes[np.newaxis, np.newaxis, :],
89+
y_sizes[np.newaxis, :, np.newaxis],
90+
)
91+
return map_blocks_args, chunks, geom_array

0 commit comments

Comments
 (0)