Skip to content

Commit 49cc4f5

Browse files
pulkkinsdnerini
andauthored
Cascade scale fix (#283)
* Remove unnecessary encoding lines * Use default l_0 value that gives constant ratio between scales * Take l_0 from max(width,height) instead of min * Fix typo * Docstring fix * Use more consistent output variable names * Add missing output variable description to docstring * Add option to return callable functions for bandpass filter weights * Implement alternative strategy for choosing filter weights * Set the weights corresponding to the first Fourier frequency to zero * Add option to subtract field mean before decomposition (defaults to True) * Fix nan values due to division by zero * Fix typo * Add the first Fourier wavenumber/field mean to the first filter by default * Fix typo * Ensure that weights sum to one * Set subtract_mean to False by default * Slightly adjust CRPS thresholds for the new bandpass filter configuration Co-authored-by: Daniele Nerini <daniele.nerini@gmail.com>
1 parent fd2e665 commit 49cc4f5

File tree

4 files changed

+75
-56
lines changed

4 files changed

+75
-56
lines changed

pysteps/cascade/bandpass_filters.py

Lines changed: 55 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
"""
32
pysteps.cascade.bandpass_filters
43
================================
@@ -64,10 +63,14 @@ def filter_uniform(shape, n):
6463
n: int
6564
Not used. Needed for compatibility with the filter interface.
6665
66+
Returns
67+
-------
68+
out: dict
69+
A dictionary containing the filter.
6770
"""
6871
del n # Unused
6972

70-
result = {}
73+
out = {}
7174

7275
try:
7376
height, width = shape
@@ -76,17 +79,23 @@ def filter_uniform(shape, n):
7679

7780
r_max = int(max(width, height) / 2) + 1
7881

79-
result["weights_1d"] = np.ones((1, r_max))
80-
result["weights_2d"] = np.ones((1, height, int(width / 2) + 1))
81-
result["central_freqs"] = None
82-
result["central_wavenumbers"] = None
83-
result["shape"] = shape
82+
out["weights_1d"] = np.ones((1, r_max))
83+
out["weights_2d"] = np.ones((1, height, int(width / 2) + 1))
84+
out["central_freqs"] = None
85+
out["central_wavenumbers"] = None
86+
out["shape"] = shape
8487

85-
return result
88+
return out
8689

8790

8891
def filter_gaussian(
89-
shape, n, l_0=3, gauss_scale=0.5, gauss_scale_0=0.5, d=1.0, normalize=True
92+
shape,
93+
n,
94+
gauss_scale=0.5,
95+
d=1.0,
96+
normalize=True,
97+
return_weight_funcs=False,
98+
include_mean=True,
9099
):
91100
"""
92101
Implements a set of Gaussian bandpass filters in logarithmic frequency
@@ -99,20 +108,20 @@ def filter_gaussian(
99108
the domain is assumed to have square shape.
100109
n: int
101110
The number of frequency bands to use. Must be greater than 2.
102-
l_0: int
103-
Central frequency of the second band (the first band is always centered
104-
at zero).
105111
gauss_scale: float
106-
Optional scaling prameter. Proportional to the standard deviation of
112+
Optional scaling parameter. Proportional to the standard deviation of
107113
the Gaussian weight functions.
108-
gauss_scale_0: float
109-
Optional scaling parameter for the Gaussian function corresponding to
110-
the first frequency band.
111114
d: scalar, optional
112115
Sample spacing (inverse of the sampling rate). Defaults to 1.
113116
normalize: bool
114117
If True, normalize the weights so that for any given wavenumber
115118
they sum to one.
119+
return_weight_funcs: bool
120+
If True, add callable weight functions to the output dictionary with
121+
the key 'weight_funcs'.
122+
include_mean: bool
123+
If True, include the first Fourier wavenumber (corresponding to the
124+
field mean) to the first filter.
116125
117126
Returns
118127
-------
@@ -133,6 +142,8 @@ def filter_gaussian(
133142
except TypeError:
134143
height, width = (shape, shape)
135144

145+
max_length = max(width, height)
146+
136147
rx = np.s_[: int(width / 2) + 1]
137148

138149
if (height % 2) == 1:
@@ -145,13 +156,13 @@ def filter_gaussian(
145156

146157
r_2d = np.roll(np.sqrt(x_grid * x_grid + y_grid * y_grid), dy, axis=0)
147158

148-
max_length = max(width, height)
149-
150159
r_max = int(max_length / 2) + 1
151160
r_1d = np.arange(r_max)
152161

153162
wfs, central_wavenumbers = _gaussweights_1d(
154-
max_length, n, l_0=l_0, gauss_scale=gauss_scale, gauss_scale_0=gauss_scale_0
163+
max_length,
164+
n,
165+
gauss_scale=gauss_scale,
155166
)
156167

157168
weights_1d = np.empty((n, r_max))
@@ -168,36 +179,48 @@ def filter_gaussian(
168179
weights_1d[k, :] /= weights_1d_sum
169180
weights_2d[k, :, :] /= weights_2d_sum
170181

171-
result = {"weights_1d": weights_1d, "weights_2d": weights_2d}
172-
result["shape"] = shape
182+
for i in range(len(wfs)):
183+
if i == 0 and include_mean:
184+
weights_1d[i, 0] = 1.0
185+
weights_2d[i, 0, 0] = 1.0
186+
else:
187+
weights_1d[i, 0] = 0.0
188+
weights_2d[i, 0, 0] = 0.0
189+
190+
out = {"weights_1d": weights_1d, "weights_2d": weights_2d}
191+
out["shape"] = shape
173192

174193
central_wavenumbers = np.array(central_wavenumbers)
175-
result["central_wavenumbers"] = central_wavenumbers
194+
out["central_wavenumbers"] = central_wavenumbers
176195

177196
# Compute frequencies
178197
central_freqs = 1.0 * central_wavenumbers / max_length
179198
central_freqs[0] = 1.0 / max_length
180199
central_freqs[-1] = 0.5 # Nyquist freq
181200
central_freqs = 1.0 * d * central_freqs
182-
result["central_freqs"] = central_freqs
201+
out["central_freqs"] = central_freqs
202+
203+
if return_weight_funcs:
204+
out["weight_funcs"] = wfs
183205

184-
return result
206+
return out
185207

186208

187-
def _gaussweights_1d(l, n, l_0=3, gauss_scale=0.5, gauss_scale_0=0.5):
188-
e = pow(0.5 * l / l_0, 1.0 / (n - 2))
189-
r = [(l_0 * pow(e, k - 1), l_0 * pow(e, k)) for k in range(1, n - 1)]
209+
def _gaussweights_1d(l, n, gauss_scale=0.5):
210+
q = pow(0.5 * l, 1.0 / n)
211+
r = [(pow(q, k - 1), pow(q, k)) for k in range(1, n + 1)]
212+
r = [0.5 * (r_[0] + r_[1]) for r_ in r]
190213

191214
def log_e(x):
192215
if len(np.shape(x)) > 0:
193216
res = np.empty(x.shape)
194217
res[x == 0] = 0.0
195-
res[x > 0] = np.log(x[x > 0]) / np.log(e)
218+
res[x > 0] = np.log(x[x > 0]) / np.log(q)
196219
else:
197220
if x == 0.0:
198221
res = 0.0
199222
else:
200-
res = np.log(x) / np.log(e)
223+
res = np.log(x) / np.log(q)
201224

202225
return res
203226

@@ -211,25 +234,11 @@ def __call__(self, x):
211234
return np.exp(-(x**2.0) / (2.0 * self.s**2.0))
212235

213236
weight_funcs = []
214-
central_wavenumbers = [0.0]
215-
216-
weight_funcs.append(GaussFunc(0.0, gauss_scale_0))
237+
central_wavenumbers = []
217238

218239
for i, ri in enumerate(r):
219-
rc = log_e(ri[0])
240+
rc = log_e(ri)
220241
weight_funcs.append(GaussFunc(rc, gauss_scale))
221-
central_wavenumbers.append(ri[0])
222-
223-
gf = GaussFunc(log_e(l / 2), gauss_scale)
224-
225-
def g(x):
226-
res = np.ones(x.shape)
227-
mask = x <= l / 2
228-
res[mask] = gf(x[mask])
229-
230-
return res
231-
232-
weight_funcs.append(g)
233-
central_wavenumbers.append(l / 2)
242+
central_wavenumbers.append(ri)
234243

235244
return weight_funcs, central_wavenumbers

pysteps/cascade/decomposition.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
"""
32
pysteps.cascade.decomposition
43
=============================
@@ -14,9 +13,8 @@
1413
where field is the input field and bp_filter is a dictionary returned by a
1514
filter method implemented in :py:mod:`pysteps.cascade.bandpass_filters`. The
1615
decomp argument is a decomposition obtained by calling decomposition_xxx.
17-
Optional parameters can be passed in
18-
the keyword arguments. The output of each method is a dictionary with the
19-
following key-value pairs:
16+
Optional parameters can be passed in the keyword arguments. The output of each
17+
method is a dictionary with the following key-value pairs:
2018
2119
+-------------------+----------------------------------------------------------+
2220
| Key | Value |
@@ -120,6 +118,10 @@ def decomposition_fft(field, bp_filter, **kwargs):
120118
Applicable if output_domain is "spectral". If set to True, only the
121119
parts of the Fourier spectrum with non-negligible filter weights are
122120
stored. Defaults to False.
121+
subtract_mean: bool
122+
If set to True, subtract the mean value before the decomposition and
123+
store it to the output dictionary. Applicable if input_domain is
124+
"spatial". Defaults to False.
123125
124126
Returns
125127
-------
@@ -138,6 +140,7 @@ def decomposition_fft(field, bp_filter, **kwargs):
138140
output_domain = kwargs.get("output_domain", "spatial")
139141
compute_stats = kwargs.get("compute_stats", True)
140142
compact_output = kwargs.get("compact_output", False)
143+
subtract_mean = kwargs.get("subtract_mean", False)
141144

142145
if normalize and not compute_stats:
143146
compute_stats = True
@@ -194,6 +197,11 @@ def decomposition_fft(field, bp_filter, **kwargs):
194197
means = []
195198
stds = []
196199

200+
if subtract_mean and input_domain == "spatial":
201+
field_mean = np.mean(field)
202+
field = field - field_mean
203+
result["field_mean"] = field_mean
204+
197205
if input_domain == "spatial":
198206
field_fft = fft.rfft2(field)
199207
else:
@@ -276,7 +284,7 @@ def recompose_fft(decomp, **kwargs):
276284
if not decomp["normalized"] and not (
277285
decomp["domain"] == "spectral" and decomp["compact_output"]
278286
):
279-
return np.sum(levels, axis=0)
287+
result = np.sum(levels, axis=0)
280288
else:
281289
if decomp["compact_output"]:
282290
weight_masks = decomp["weight_masks"]
@@ -291,4 +299,7 @@ def recompose_fft(decomp, **kwargs):
291299
result = [levels[i] * sigma[i] + mu[i] for i in range(len(levels))]
292300
result = np.sum(np.stack(result), axis=0)
293301

294-
return result
302+
if "field_mean" in decomp:
303+
result += decomp["field_mean"]
304+
305+
return result

pysteps/cascade/interface.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
"""
32
pysteps.cascade.interface
43
=========================

pysteps/tests/test_nowcasts_steps.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
steps_arg_values = [
2525
(5, 6, 2, None, None, "spatial", 3, 1.30),
2626
(5, 6, 2, None, None, "spatial", [3], 1.30),
27-
(5, 6, 2, "incremental", None, "spatial", 3, 7.25),
28-
(5, 6, 2, "sprog", None, "spatial", 3, 8.35),
29-
(5, 6, 2, "obs", None, "spatial", 3, 8.30),
27+
(5, 6, 2, "incremental", None, "spatial", 3, 7.31),
28+
(5, 6, 2, "sprog", None, "spatial", 3, 8.4),
29+
(5, 6, 2, "obs", None, "spatial", 3, 8.37),
3030
(5, 6, 2, None, "cdf", "spatial", 3, 0.60),
3131
(5, 6, 2, None, "mean", "spatial", 3, 1.35),
3232
(5, 6, 2, "incremental", "cdf", "spectral", 3, 0.60),

0 commit comments

Comments
 (0)