Skip to content

Commit 69061dd

Browse files
dnerini, aperezhortal, loforest, pulkkins
authored
Refactor utils.interpolate module (#210)
* Rename to idwinterp2d; adapt docstrings, variable names, and interface
* Fix interface
* Initial wrapper to scipy's method
* Aesthetics
* Use mode=N-D to do multivariate interpolation
* Fix black check
  By upgrading black to latest version
* Support parameters of the original function
* Refactor idw method
* Aesthetics
* Implement checks as a decorator
* Add tests
* Aesthetics
* Remove unnecessary test
* Remove unnecessary test
* Test for wrong number of dimensions
* Chunking of the dst grid is done in the preamble
* Fix chunking
* Simplify idw
* Add LRU caching of interpolator classes
* Small adjustments
* Change default interp method to idw
* Adapt test
* Adapt examples
* Silence some warnings
  And other super minor changes
* Fix black check
* Fix tests
* Aesthetics
* Fix docstrings
  Or at least try...
* Fix univariate interpolation
  And include test
* Fix black check
* Super minor changes
* Update precommit config
* Define RBF in docstring
* Add function to append extra kwrds to docstrings
* Rename decorator
* Test idw with k=1 or k=None
* Test that all outputs are finite numbers
* Update pysteps/utils/interpolate.py
  Co-authored-by: Loris Foresti <39999237+loforest@users.noreply.github.com>
* Improve variable naming
* Workaround for TypeError issue with Pillow 8.3
* Add abbreviation for inverse distance weighting
* Fix typo
* Docstring polishing in idwinterp2d
* Small refactoring
* Docstring polishing
* Improve docstrings
* Adapt variable naming
* Inputs must be ndarrays
* Fix typo
* Make the offset constant a user-selectable parameter

Co-authored-by: Andres Perez Hortal <16256571+aperezhortal@users.noreply.github.com>
Co-authored-by: Loris Foresti <39999237+loforest@users.noreply.github.com>
Co-authored-by: Seppo Pulkkinen <pulkkins@gmail.com>
1 parent 76324d8 commit 69061dd

15 files changed

+484
-221
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: 20.8b1
3+
rev: 21.6b0
44
hooks:
55
- id: black
66
language_version: python3

examples/LK_buffer_mask.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,9 @@
169169
# By comparing the velocity of the motion fields, we can easily notice
170170
# the negative bias that is introduced by the erroneous interpretation of
171171
# velocities near the maximum range of the radars.
172-
# Please note that we are setting a small shape parameter ``epsilon`` for the
173-
# interpolation routine. This will produce a smoother motion field.
174172

175-
interp_kwargs = {"epsilon": 5} # use a small shape parameter for interpolation
176-
UV1 = dense_lucaskanade(
177-
R, dense=True, fd_kwargs=fd_kwargs1, interp_kwargs=interp_kwargs
178-
)
179-
UV2 = dense_lucaskanade(
180-
R, dense=True, fd_kwargs=fd_kwargs2, interp_kwargs=interp_kwargs
181-
)
173+
UV1 = dense_lucaskanade(R, dense=True, fd_kwargs=fd_kwargs1)
174+
UV2 = dense_lucaskanade(R, dense=True, fd_kwargs=fd_kwargs2)
182175

183176
V1 = np.sqrt(UV1[0] ** 2 + UV1[1] ** 2)
184177
V2 = np.sqrt(UV2[0] ** 2 + UV2[1] ** 2)

examples/optical_flow_methods_convergence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def plot_optflow_method_convergence(input_precip, optflow_method_name, motion_ty
290290
axis=0
291291
)
292292

293-
cmap = get_cmap("jet")
293+
cmap = get_cmap("jet").copy()
294294
cmap.set_under("grey", alpha=0.25)
295295
cmap.set_over("none")
296296

pysteps/decorators.py

Lines changed: 156 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,34 @@
1111
1212
postprocess_import
1313
check_input_frames
14+
prepare_interpolator
15+
memoize
1416
"""
1517
import inspect
18+
import uuid
1619
from collections import defaultdict
1720
from functools import wraps
1821

1922
import numpy as np
2023

2124

25+
def _add_extra_kwrds_to_docstrings(target_func, extra_kwargs_doc_text):
26+
"""
27+
Update the function's docstring by replacing the `{extra_kwargs_doc}` occurrences in
28+
the docstring by the `extra_kwargs_doc_text` value.
29+
"""
30+
# Clean up indentation from docstrings for the
31+
# docstrings to be merged correctly.
32+
extra_kwargs_doc = inspect.cleandoc(extra_kwargs_doc_text)
33+
target_func.__doc__ = inspect.cleandoc(target_func.__doc__)
34+
35+
# Add extra kwargs docstrings
36+
target_func.__doc__ = target_func.__doc__.format_map(
37+
defaultdict(str, extra_kwargs_doc=extra_kwargs_doc)
38+
)
39+
return target_func
40+
41+
2242
def postprocess_import(fillna=np.nan, dtype="double"):
2343
"""
2444
Postprocess the imported precipitation data.
@@ -82,19 +102,7 @@ def _import_with_postprocessing(*args, **kwargs):
82102
By default, np.nan is used.
83103
"""
84104

85-
# Clean up indentation from docstrings for the
86-
# docstrings to be merged correctly.
87-
extra_kwargs_doc = inspect.cleandoc(extra_kwargs_doc)
88-
_import_with_postprocessing.__doc__ = inspect.cleandoc(
89-
_import_with_postprocessing.__doc__
90-
)
91-
92-
# Add extra kwargs docstrings
93-
_import_with_postprocessing.__doc__ = (
94-
_import_with_postprocessing.__doc__.format_map(
95-
defaultdict(str, extra_kwargs_doc=extra_kwargs_doc)
96-
)
97-
)
105+
_add_extra_kwrds_to_docstrings(_import_with_postprocessing, extra_kwargs_doc)
98106

99107
return _import_with_postprocessing
100108

@@ -140,3 +148,138 @@ def new_function(*args, **kwargs):
140148
return new_function
141149

142150
return _check_input_frames
151+
152+
153+
def prepare_interpolator(nchunks=4):
154+
"""
155+
Check that all the inputs have the correct shape, and that all values are
156+
finite. It also splits the destination grid into `nchunks` parts, and processes each
157+
part independently.
158+
"""
159+
160+
def _preamble_interpolation(interpolator):
161+
@wraps(interpolator)
162+
def _interpolator_with_preamble(xy_coord, values, xgrid, ygrid, **kwargs):
163+
nonlocal nchunks # https://stackoverflow.com/questions/5630409/
164+
165+
values = values.copy()
166+
xy_coord = xy_coord.copy()
167+
168+
input_ndims = values.ndim
169+
input_nvars = 1 if input_ndims == 1 else values.shape[1]
170+
input_nsamples = values.shape[0]
171+
172+
coord_ndims = xy_coord.ndim
173+
coord_nsamples = xy_coord.shape[0]
174+
175+
grid_shape = (ygrid.size, xgrid.size)
176+
177+
if np.any(~np.isfinite(values)):
178+
raise ValueError("argument 'values' contains non-finite values")
179+
if np.any(~np.isfinite(xy_coord)):
180+
raise ValueError("argument 'xy_coord' contains non-finite values")
181+
182+
if input_ndims > 2:
183+
raise ValueError(
184+
"argument 'values' must have 1 (n) or 2 dimensions (n, m), "
185+
f"but it has {input_ndims}"
186+
)
187+
if not coord_ndims == 2:
188+
raise ValueError(
189+
"argument 'xy_coord' must have 2 dimensions (n, 2), "
190+
f"but it has {coord_ndims}"
191+
)
192+
193+
if not input_nsamples == coord_nsamples:
194+
raise ValueError(
195+
"the number of samples in argument 'values' does not match the "
196+
f"number of coordinates {input_nsamples}!={coord_nsamples}"
197+
)
198+
199+
# only one sample, return uniform output
200+
if input_nsamples == 1:
201+
output_array = np.ones((input_nvars,) + grid_shape)
202+
for n, v in enumerate(values[0, ...]):
203+
output_array[n, ...] *= v
204+
return output_array.squeeze()
205+
206+
# all equal elements, return uniform output
207+
if values.max() == values.min():
208+
return np.ones((input_nvars,) + grid_shape) * values.ravel()[0]
209+
210+
# split grid in n chunks
211+
nchunks = int(kwargs.get("nchunks", nchunks) ** 0.5)
212+
if nchunks > 1:
213+
subxgrids = np.array_split(xgrid, nchunks)
214+
subxgrids = [x for x in subxgrids if x.size > 0]
215+
subygrids = np.array_split(ygrid, nchunks)
216+
subygrids = [y for y in subygrids if y.size > 0]
217+
218+
# generate a unique identifier to be used for caching
219+
# intermediate results
220+
kwargs["hkey"] = uuid.uuid1().int
221+
else:
222+
subxgrids = [xgrid]
223+
subygrids = [ygrid]
224+
225+
interpolated = np.zeros((input_nvars,) + grid_shape)
226+
indx = 0
227+
for subxgrid in subxgrids:
228+
deltax = subxgrid.size
229+
indy = 0
230+
for subygrid in subygrids:
231+
deltay = subygrid.size
232+
interpolated[
233+
:, indy : (indy + deltay), indx : (indx + deltax)
234+
] = interpolator(xy_coord, values, subxgrid, subygrid, **kwargs)
235+
indy += deltay
236+
indx += deltax
237+
238+
return interpolated.squeeze()
239+
240+
extra_kwargs_doc = """
241+
nchunks: int, optional
242+
Split and process the destination grid in nchunks.
243+
Useful for large grids to limit the memory footprint.
244+
"""
245+
246+
_add_extra_kwrds_to_docstrings(_interpolator_with_preamble, extra_kwargs_doc)
247+
248+
return _interpolator_with_preamble
249+
250+
return _preamble_interpolation
251+
252+
253+
def memoize(maxsize=10):
254+
"""
255+
Add a Least Recently Used (LRU) cache to any function.
256+
Caching is purely based on the optional keyword argument 'hkey', which needs
257+
to be hashable.
258+
259+
Parameters
260+
----------
261+
maxsize: int, optional
262+
The maximum number of elements stored in the LRU cache.
263+
"""
264+
265+
def _memoize(func):
266+
cache = dict()
267+
hkeys = []
268+
269+
@wraps(func)
270+
def _func_with_cache(*args, **kwargs):
271+
hkey = kwargs.pop("hkey", None)
272+
if hkey in cache:
273+
return cache[hkey]
274+
result = func(*args, **kwargs)
275+
if hkey is not None:
276+
cache[hkey] = result
277+
hkeys.append(hkey)
278+
if len(hkeys) > maxsize:
279+
cache.pop(hkeys.pop(0))
280+
281+
return result
282+
283+
return _func_with_cache
284+
285+
return _memoize

pysteps/io/exporters.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -510,14 +510,14 @@ def initialize_forecast_exporter_netcdf(
510510
pr = pyproj.Proj(metadata["projection"])
511511
lon, lat = pr(x_2d.flatten(), y_2d.flatten(), inverse=True)
512512

513-
var_lon = ncf.createVariable("lon", np.float, dimensions=("y", "x"))
513+
var_lon = ncf.createVariable("lon", float, dimensions=("y", "x"))
514514
var_lon[:] = lon.reshape(shape)
515515
var_lon.standard_name = "longitude"
516516
var_lon.long_name = "longitude coordinate"
517517
# TODO(exporters): Don't hard-code the unit.
518518
var_lon.units = "degrees_east"
519519

520-
var_lat = ncf.createVariable("lat", np.float, dimensions=("y", "x"))
520+
var_lat = ncf.createVariable("lat", float, dimensions=("y", "x"))
521521
var_lat[:] = lat.reshape(shape)
522522
var_lat.standard_name = "latitude"
523523
var_lat.long_name = "latitude coordinate"
@@ -533,22 +533,20 @@ def initialize_forecast_exporter_netcdf(
533533
) = _convert_proj4_to_grid_mapping(metadata["projection"])
534534
# skip writing the grid mapping if a matching name was not found
535535
if grid_mapping_var_name is not None:
536-
var_gm = ncf.createVariable(grid_mapping_var_name, np.int, dimensions=())
536+
var_gm = ncf.createVariable(grid_mapping_var_name, int, dimensions=())
537537
var_gm.grid_mapping_name = grid_mapping_name
538538
for i in grid_mapping_params.items():
539539
var_gm.setncattr(i[0], i[1])
540540

541541
if incremental == "member" or n_ens_gt_one:
542-
var_ens_num = ncf.createVariable(
543-
"ens_number", np.int, dimensions=("ens_number",)
544-
)
542+
var_ens_num = ncf.createVariable("ens_number", int, dimensions=("ens_number",))
545543
if incremental != "member":
546544
var_ens_num[:] = list(range(1, n_ens_members + 1))
547545
var_ens_num.long_name = "ensemble member"
548546
var_ens_num.standard_name = "realization"
549547
var_ens_num.units = ""
550548

551-
var_time = ncf.createVariable("time", np.int, dimensions=("time",))
549+
var_time = ncf.createVariable("time", int, dimensions=("time",))
552550
if incremental != "timestep":
553551
var_time[:] = [i * timestep * 60 for i in range(1, n_timesteps + 1)]
554552
var_time.long_name = "forecast time"

pysteps/io/importers.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
| ypixelsize | grid resolution in y-direction |
4040
+------------------+----------------------------------------------------------+
4141
| cartesian_unit | the physical unit of the cartesian x- and y-coordinates: |
42-
| | e.g. 'm' or 'km' |
42+
| | e.g. 'm' or 'km' |
4343
+------------------+----------------------------------------------------------+
4444
| yorigin | a string specifying the location of the first element in |
4545
| | the data raster w.r.t. y-axis: |
@@ -934,12 +934,12 @@ def import_mch_gif(filename, product, unit, accutime, **kwargs):
934934
metadata = geodata
935935

936936
# import gif file
937-
B = Image.open(filename)
937+
img = Image.open(filename)
938938

939939
if product.lower() in ["azc", "rzc", "precip"]:
940940

941941
# convert 8-bit GIF colortable to RGB values
942-
Brgb = B.convert("RGB")
942+
img_rgb = img.convert("RGB")
943943

944944
# load lookup table
945945
if product.lower() == "azc":
@@ -954,12 +954,12 @@ def import_mch_gif(filename, product, unit, accutime, **kwargs):
954954
lut = dict(zip(zip(lut[:, 1], lut[:, 2], lut[:, 3]), lut[:, -1]))
955955

956956
# apply lookup table conversion
957-
precip = np.zeros(len(Brgb.getdata()))
958-
for i, dn in enumerate(Brgb.getdata()):
957+
precip = np.zeros(len(img_rgb.getdata()))
958+
for i, dn in enumerate(img_rgb.getdata()):
959959
precip[i] = lut.get(dn, np.nan)
960960

961961
# convert to original shape
962-
width, height = B.size
962+
width, height = img.size
963963
precip = precip.reshape(height, width)
964964

965965
# set values outside observational range to NaN,
@@ -970,7 +970,7 @@ def import_mch_gif(filename, product, unit, accutime, **kwargs):
970970
elif product.lower() in ["aqc", "cpc", "acquire ", "combiprecip"]:
971971

972972
# convert digital numbers to physical values
973-
B = np.array(B, dtype=int)
973+
img = np.array(img).astype(int)
974974

975975
# build lookup table [mm/5min]
976976
lut = np.zeros(256)
@@ -985,7 +985,7 @@ def import_mch_gif(filename, product, unit, accutime, **kwargs):
985985
lut[i] = (10.0 ** ((i - 71.5) / 20.0) / a) ** (1.0 / b)
986986

987987
# apply lookup table
988-
precip = lut[B]
988+
precip = lut[img]
989989

990990
else:
991991
raise ValueError("unknown product %s" % product)

pysteps/motion/lucaskanade.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def dense_lucaskanade(
4949
lk_kwargs=None,
5050
fd_method="shitomasi",
5151
fd_kwargs=None,
52-
interp_method="rbfinterp2d",
52+
interp_method="idwinterp2d",
5353
interp_kwargs=None,
5454
dense=True,
5555
nr_std_outlier=3,

pysteps/nowcasts/sseps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,9 @@ def forecast(
230230
)
231231

232232
if np.isscalar(win_size):
233-
win_size = (np.int(win_size), np.int(win_size))
233+
win_size = (int(win_size), int(win_size))
234234
else:
235-
win_size = tuple([np.int(win_size[i]) for i in range(2)])
235+
win_size = tuple([int(win_size[i]) for i in range(2)])
236236

237237
timestep = metadata["accutime"]
238238
kmperpixel = metadata["xpixelsize"] / 1000

pysteps/tests/test_decorators.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# -*- coding: utf-8 -*-
2+
import time
3+
4+
from pysteps.decorators import memoize
5+
6+
7+
def test_memoize():
8+
@memoize(maxsize=1)
9+
def _slow_function(x, **kwargs):
10+
time.sleep(1)
11+
return x
12+
13+
for i in range(2):
14+
out = _slow_function(i, hkey=i)
15+
assert out == i
16+
17+
# cached result
18+
t0 = time.monotonic()
19+
out = _slow_function(1, hkey=1)
20+
assert time.monotonic() - t0 < 1
21+
assert out == 1
22+
23+
# maxsize exceeded
24+
t0 = time.monotonic()
25+
out = _slow_function(0, hkey=0)
26+
assert time.monotonic() - t0 >= 1
27+
assert out == 0
28+
29+
# no hash
30+
t0 = time.monotonic()
31+
out = _slow_function(1)
32+
assert time.monotonic() - t0 >= 1
33+
assert out == 1

pysteps/tests/test_nowcasts_steps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
(5, 6, 2, "sprog", None, "spatial", 3, 8.35),
2828
(5, 6, 2, "obs", None, "spatial", 3, 8.30),
2929
(5, 6, 2, None, "cdf", "spatial", 3, 0.60),
30-
(5, 6, 2, None, "mean", "spatial", 3, 1.30),
30+
(5, 6, 2, None, "mean", "spatial", 3, 1.35),
3131
(5, 6, 2, "incremental", "cdf", "spectral", 3, 0.60),
3232
]
3333

0 commit comments

Comments
 (0)