1
1
# rasterio wrappers
2
2
from __future__ import annotations
3
3
4
- from collections .abc import Sequence
4
+ import functools
5
+ from collections .abc import Callable , Sequence
5
6
from functools import partial
6
- from typing import TYPE_CHECKING , Any
7
+ from typing import TYPE_CHECKING , Any , TypeVar
7
8
8
9
import geopandas as gpd
9
10
import numpy as np
10
- import odc .geo .xr # noqa
11
11
import rasterio as rio
12
12
import xarray as xr
13
+ from affine import Affine
13
14
from rasterio .features import MergeAlg , geometry_mask
14
15
from rasterio .features import rasterize as rasterize_rio
15
16
16
- from .utils import is_in_memory , prepare_for_dask
17
+ from .utils import XAXIS , YAXIS , get_affine , is_in_memory , prepare_for_dask
18
+
19
+ F = TypeVar ("F" , bound = Callable [..., Any ])
17
20
18
21
if TYPE_CHECKING :
19
22
import dask_geopandas
20
23
21
24
25
+ def with_rio_env (func : F ) -> F :
26
+ """
27
+ Decorator that handles the 'env' and 'clear_cache' kwargs.
28
+ """
29
+
30
+ @functools .wraps (func )
31
+ def wrapper (* args , ** kwargs ):
32
+ env = kwargs .pop ("env" , None )
33
+ clear_cache = kwargs .pop ("clear_cache" , False )
34
+
35
+ if env is None :
36
+ env = rio .Env ()
37
+
38
+ with env :
39
+ # Remove env and clear_cache from kwargs before calling the wrapped function
40
+ # since the function shouldn't handle the context management
41
+ result = func (* args , ** kwargs )
42
+
43
+ if clear_cache :
44
+ with rio .Env (GDAL_CACHEMAX = 0 ):
45
+ # attempt to force-clear the GDAL cache
46
+ pass
47
+
48
+ return result
49
+
50
+ return wrapper
51
+
52
+
22
53
def dask_rasterize_wrapper (
23
54
geom_array : np .ndarray ,
24
- tile_array : np .ndarray ,
55
+ x_offsets : np .ndarray ,
56
+ y_offsets : np .ndarray ,
57
+ x_sizes : np .ndarray ,
58
+ y_sizes : np .ndarray ,
25
59
offset_array : np .ndarray ,
26
60
* ,
27
61
fill : Any ,
62
+ affine : Affine ,
28
63
all_touched : bool ,
29
64
merge_alg : MergeAlg ,
30
65
dtype_ : np .dtype ,
31
66
env : rio .Env | None = None ,
32
67
) -> np .ndarray :
33
- tile = tile_array .item ()
34
68
offset = offset_array .item ()
35
69
36
70
return rasterize_geometries (
37
71
geom_array [:, 0 , 0 ].tolist (),
38
- tile = tile ,
72
+ affine = affine * affine .translation (x_offsets .item (), y_offsets .item ()),
73
+ shape = (y_sizes .item (), x_sizes .item ()),
39
74
offset = offset ,
40
75
all_touched = all_touched ,
41
76
merge_alg = merge_alg ,
@@ -45,44 +80,25 @@ def dask_rasterize_wrapper(
45
80
)[np .newaxis , :, :]
46
81
47
82
83
+ @with_rio_env
48
84
def rasterize_geometries (
49
85
geometries : Sequence [Any ],
50
86
* ,
51
87
dtype : np .dtype ,
52
- tile ,
53
- offset ,
88
+ shape : tuple [int , int ],
89
+ affine : Affine ,
90
+ offset : int ,
54
91
env : rio .Env | None = None ,
55
92
clear_cache : bool = False ,
56
93
** kwargs ,
57
94
):
58
- # From https://rasterio.readthedocs.io/en/latest/api/rasterio.features.html#rasterio.features.rasterize
59
- # The out array will be copied and additional temporary raster memory equal to 2x the smaller of out data
60
- # or GDAL’s max cache size (controlled by GDAL_CACHEMAX, default is 5% of the computer’s physical memory) is required.
61
- # If GDAL max cache size is smaller than the output data, the array of shapes will be iterated multiple times.
62
- # Performance is thus a linear function of buffer size. For maximum speed, ensure that GDAL_CACHEMAX
63
- # is larger than the size of out or out_shape.
64
- if env is None :
65
- # out_size = dtype.itemsize * math.prod(tile.shape)
66
- # env = rio.Env(GDAL_CACHEMAX=1.2 * out_size)
67
- # FIXME: figure out a good default
68
- env = rio .Env ()
69
- with env :
70
- res = rasterize_rio (
71
- zip (geometries , range (offset , offset + len (geometries )), strict = True ),
72
- out_shape = tile .shape ,
73
- transform = tile .affine ,
74
- ** kwargs ,
75
- )
76
- if clear_cache :
77
- with rio .Env (GDAL_CACHEMAX = 0 ):
78
- try :
79
- from osgeo import gdal
80
-
81
- # attempt to force-clear the GDAL cache
82
- assert gdal .GetCacheMax () == 0
83
- except ImportError :
84
- pass
85
- assert res .shape == tile .shape
95
+ res = rasterize_rio (
96
+ zip (geometries , range (offset , offset + len (geometries )), strict = True ),
97
+ out_shape = shape ,
98
+ transform = affine ,
99
+ ** kwargs ,
100
+ )
101
+ assert res .shape == shape
86
102
return res
87
103
88
104
@@ -129,25 +145,30 @@ def rasterize(
129
145
"""
130
146
if xdim not in obj .dims or ydim not in obj .dims :
131
147
raise ValueError (f"Received { xdim = !r} , { ydim = !r} but obj.dims={ tuple (obj .dims )} " )
132
- box = obj .odc .geobox
133
- rasterize_kwargs = dict (all_touched = all_touched , merge_alg = merge_alg )
148
+
149
+ rasterize_kwargs = dict (
150
+ all_touched = all_touched , merge_alg = merge_alg , affine = get_affine (obj , xdim = xdim , ydim = ydim ), env = env
151
+ )
134
152
# FIXME: box.crs == geometries.crs
135
153
if is_in_memory (obj = obj , geometries = geometries ):
136
154
geom_array = geometries .to_numpy ().squeeze (axis = 1 )
137
155
rasterized = rasterize_geometries (
138
156
geom_array .tolist (),
139
- tile = box ,
157
+ shape = ( obj . sizes [ ydim ], obj . sizes [ xdim ]) ,
140
158
offset = 0 ,
141
159
dtype = np .min_scalar_type (len (geometries )),
142
160
fill = len (geometries ),
143
- env = env ,
144
161
** rasterize_kwargs ,
145
162
)
146
163
else :
147
164
from dask .array import from_array , map_blocks
148
165
149
- chunks , tiles_array , geom_array = prepare_for_dask (
150
- obj , geometries , xdim = xdim , ydim = ydim , geoms_rechunk_size = geoms_rechunk_size
166
+ map_blocks_args , chunks , geom_array = prepare_for_dask (
167
+ obj ,
168
+ geometries ,
169
+ xdim = xdim ,
170
+ ydim = ydim ,
171
+ geoms_rechunk_size = geoms_rechunk_size ,
151
172
)
152
173
# DaskGeoDataFrame.len() computes!
153
174
num_geoms = geom_array .size
@@ -159,10 +180,9 @@ def rasterize(
159
180
160
181
rasterized = map_blocks (
161
182
dask_rasterize_wrapper ,
162
- geom_array [:, np .newaxis , np .newaxis ],
163
- tiles_array [np .newaxis , :, :],
183
+ * map_blocks_args ,
164
184
offsets [:, np .newaxis , np .newaxis ],
165
- chunks = ((1 ,) * geom_array .numblocks [0 ], chunks [0 ], chunks [1 ]),
185
+ chunks = ((1 ,) * geom_array .numblocks [0 ], chunks [YAXIS ], chunks [XAXIS ]),
166
186
meta = np .array ([], dtype = dtype ),
167
187
fill = 0 , # good identity value for both sum & replace.
168
188
** rasterize_kwargs ,
@@ -205,54 +225,39 @@ def replace_values(array: np.ndarray, to, *, from_=0) -> np.ndarray:
205
225
206
226
def dask_mask_wrapper (
207
227
geom_array : np .ndarray ,
208
- tile_array : np .ndarray ,
228
+ x_offsets : np .ndarray ,
229
+ y_offsets : np .ndarray ,
230
+ x_sizes : np .ndarray ,
231
+ y_sizes : np .ndarray ,
209
232
* ,
233
+ affine : Affine ,
210
234
all_touched : bool ,
211
235
invert : bool ,
212
236
env : rio .Env | None = None ,
213
237
) -> np .ndarray [Any , np .dtype [np .bool_ ]]:
214
- tile = tile_array .item ()
215
-
216
238
return np_geometry_mask (
217
239
geom_array [:, 0 , 0 ].tolist (),
218
- tile = tile ,
219
- all_touched = all_touched ,
240
+ shape = ( y_sizes . item (), x_sizes . item ()) ,
241
+ affine = affine * affine . translation ( x_offsets . item (), y_offsets . item ()) ,
220
242
invert = invert ,
221
243
env = env ,
222
244
)[np .newaxis , :, :]
223
245
224
246
247
+ @with_rio_env
225
248
def np_geometry_mask (
226
249
geometries : Sequence [Any ],
227
250
* ,
228
- tile ,
251
+ x_offset : int ,
252
+ y_offset : int ,
253
+ shape : tuple [int , int ],
254
+ affine : Affine ,
229
255
env : rio .Env | None = None ,
230
256
clear_cache : bool = False ,
231
257
** kwargs ,
232
258
) -> np .ndarray [Any , np .dtype [np .bool_ ]]:
233
- # From https://rasterio.readthedocs.io/en/latest/api/rasterio.features.html#rasterio.features.rasterize
234
- # The out array will be copied and additional temporary raster memory equal to 2x the smaller of out data
235
- # or GDAL’s max cache size (controlled by GDAL_CACHEMAX, default is 5% of the computer’s physical memory) is required.
236
- # If GDAL max cache size is smaller than the output data, the array of shapes will be iterated multiple times.
237
- # Performance is thus a linear function of buffer size. For maximum speed, ensure that GDAL_CACHEMAX
238
- # is larger than the size of out or out_shape.
239
- if env is None :
240
- # out_size = np.bool_.itemsize * math.prod(tile.shape)
241
- # env = rio.Env(GDAL_CACHEMAX=1.2 * out_size)
242
- # FIXME: figure out a good default
243
- env = rio .Env ()
244
- with env :
245
- res = geometry_mask (geometries , out_shape = tile .shape , transform = tile .affine , ** kwargs )
246
- if clear_cache :
247
- with rio .Env (GDAL_CACHEMAX = 0 ):
248
- try :
249
- from osgeo import gdal
250
-
251
- # attempt to force-clear the GDAL cache
252
- assert gdal .GetCacheMax () == 0
253
- except ImportError :
254
- pass
255
- assert res .shape == tile .shape
259
+ res = geometry_mask (geometries , out_shape = shape , transform = affine , ** kwargs )
260
+ assert res .shape == shape
256
261
return res
257
262
258
263
@@ -298,23 +303,31 @@ def geometry_clip(
298
303
invert = not invert # rioxarray clip convention -> rasterio geometry_mask convention
299
304
if xdim not in obj .dims or ydim not in obj .dims :
300
305
raise ValueError (f"Received { xdim = !r} , { ydim = !r} but obj.dims={ tuple (obj .dims )} " )
301
- box = obj .odc .geobox
302
- geometry_mask_kwargs = dict (all_touched = all_touched , invert = invert )
306
+ geometry_mask_kwargs = dict (
307
+ all_touched = all_touched , invert = invert , affine = get_affine (obj , xdim = xdim , ydim = ydim ), env = env
308
+ )
303
309
304
310
if is_in_memory (obj = obj , geometries = geometries ):
305
311
geom_array = geometries .to_numpy ().squeeze (axis = 1 )
306
- mask = np_geometry_mask (geom_array .tolist (), tile = box , env = env , ** geometry_mask_kwargs )
312
+ mask = np_geometry_mask (
313
+ geom_array .tolist (),
314
+ shape = (obj .sizes [ydim ], obj .sizes [xdim ]),
315
+ ** geometry_mask_kwargs ,
316
+ )
307
317
else :
308
318
from dask .array import map_blocks
309
319
310
- chunks , tiles_array , geom_array = prepare_for_dask (
311
- obj , geometries , xdim = xdim , ydim = ydim , geoms_rechunk_size = geoms_rechunk_size
320
+ map_blocks_args , chunks , geom_array = prepare_for_dask (
321
+ obj ,
322
+ geometries ,
323
+ xdim = xdim ,
324
+ ydim = ydim ,
325
+ geoms_rechunk_size = geoms_rechunk_size ,
312
326
)
313
327
mask = map_blocks (
314
328
dask_mask_wrapper ,
315
- geom_array [:, np .newaxis , np .newaxis ],
316
- tiles_array [np .newaxis , :, :],
317
- chunks = ((1 ,) * geom_array .numblocks [0 ], chunks [0 ], chunks [1 ]),
329
+ * map_blocks_args ,
330
+ chunks = ((1 ,) * geom_array .numblocks [0 ], chunks [YAXIS ], chunks [XAXIS ]),
318
331
meta = np .array ([], dtype = bool ),
319
332
** geometry_mask_kwargs ,
320
333
)
0 commit comments