1
- # -*- coding: utf-8 -*-
2
1
"""
3
2
pysteps.cascade.bandpass_filters
4
3
================================
@@ -64,10 +63,14 @@ def filter_uniform(shape, n):
64
63
n: int
65
64
Not used. Needed for compatibility with the filter interface.
66
65
66
+ Returns
67
+ -------
68
+ out: dict
69
+ A dictionary containing the filter.
67
70
"""
68
71
del n # Unused
69
72
70
- result = {}
73
+ out = {}
71
74
72
75
try :
73
76
height , width = shape
@@ -76,17 +79,23 @@ def filter_uniform(shape, n):
76
79
77
80
r_max = int (max (width , height ) / 2 ) + 1
78
81
79
- result ["weights_1d" ] = np .ones ((1 , r_max ))
80
- result ["weights_2d" ] = np .ones ((1 , height , int (width / 2 ) + 1 ))
81
- result ["central_freqs" ] = None
82
- result ["central_wavenumbers" ] = None
83
- result ["shape" ] = shape
82
+ out ["weights_1d" ] = np .ones ((1 , r_max ))
83
+ out ["weights_2d" ] = np .ones ((1 , height , int (width / 2 ) + 1 ))
84
+ out ["central_freqs" ] = None
85
+ out ["central_wavenumbers" ] = None
86
+ out ["shape" ] = shape
84
87
85
- return result
88
+ return out
86
89
87
90
88
91
def filter_gaussian (
89
- shape , n , l_0 = 3 , gauss_scale = 0.5 , gauss_scale_0 = 0.5 , d = 1.0 , normalize = True
92
+ shape ,
93
+ n ,
94
+ gauss_scale = 0.5 ,
95
+ d = 1.0 ,
96
+ normalize = True ,
97
+ return_weight_funcs = False ,
98
+ include_mean = True ,
90
99
):
91
100
"""
92
101
Implements a set of Gaussian bandpass filters in logarithmic frequency
@@ -99,20 +108,20 @@ def filter_gaussian(
99
108
the domain is assumed to have square shape.
100
109
n: int
101
110
The number of frequency bands to use. Must be greater than 2.
102
- l_0: int
103
- Central frequency of the second band (the first band is always centered
104
- at zero).
105
111
gauss_scale: float
106
- Optional scaling prameter . Proportional to the standard deviation of
112
+ Optional scaling parameter . Proportional to the standard deviation of
107
113
the Gaussian weight functions.
108
- gauss_scale_0: float
109
- Optional scaling parameter for the Gaussian function corresponding to
110
- the first frequency band.
111
114
d: scalar, optional
112
115
Sample spacing (inverse of the sampling rate). Defaults to 1.
113
116
normalize: bool
114
117
If True, normalize the weights so that for any given wavenumber
115
118
they sum to one.
119
+ return_weight_funcs: bool
120
+ If True, add callable weight functions to the output dictionary with
121
+ the key 'weight_funcs'.
122
+ include_mean: bool
123
+ If True, include the first Fourier wavenumber (corresponding to the
124
+ field mean) to the first filter.
116
125
117
126
Returns
118
127
-------
@@ -133,6 +142,8 @@ def filter_gaussian(
133
142
except TypeError :
134
143
height , width = (shape , shape )
135
144
145
+ max_length = max (width , height )
146
+
136
147
rx = np .s_ [: int (width / 2 ) + 1 ]
137
148
138
149
if (height % 2 ) == 1 :
@@ -145,13 +156,13 @@ def filter_gaussian(
145
156
146
157
r_2d = np .roll (np .sqrt (x_grid * x_grid + y_grid * y_grid ), dy , axis = 0 )
147
158
148
- max_length = max (width , height )
149
-
150
159
r_max = int (max_length / 2 ) + 1
151
160
r_1d = np .arange (r_max )
152
161
153
162
wfs , central_wavenumbers = _gaussweights_1d (
154
- max_length , n , l_0 = l_0 , gauss_scale = gauss_scale , gauss_scale_0 = gauss_scale_0
163
+ max_length ,
164
+ n ,
165
+ gauss_scale = gauss_scale ,
155
166
)
156
167
157
168
weights_1d = np .empty ((n , r_max ))
@@ -168,36 +179,48 @@ def filter_gaussian(
168
179
weights_1d [k , :] /= weights_1d_sum
169
180
weights_2d [k , :, :] /= weights_2d_sum
170
181
171
- result = {"weights_1d" : weights_1d , "weights_2d" : weights_2d }
172
- result ["shape" ] = shape
182
+ for i in range (len (wfs )):
183
+ if i == 0 and include_mean :
184
+ weights_1d [i , 0 ] = 1.0
185
+ weights_2d [i , 0 , 0 ] = 1.0
186
+ else :
187
+ weights_1d [i , 0 ] = 0.0
188
+ weights_2d [i , 0 , 0 ] = 0.0
189
+
190
+ out = {"weights_1d" : weights_1d , "weights_2d" : weights_2d }
191
+ out ["shape" ] = shape
173
192
174
193
central_wavenumbers = np .array (central_wavenumbers )
175
- result ["central_wavenumbers" ] = central_wavenumbers
194
+ out ["central_wavenumbers" ] = central_wavenumbers
176
195
177
196
# Compute frequencies
178
197
central_freqs = 1.0 * central_wavenumbers / max_length
179
198
central_freqs [0 ] = 1.0 / max_length
180
199
central_freqs [- 1 ] = 0.5 # Nyquist freq
181
200
central_freqs = 1.0 * d * central_freqs
182
- result ["central_freqs" ] = central_freqs
201
+ out ["central_freqs" ] = central_freqs
202
+
203
+ if return_weight_funcs :
204
+ out ["weight_funcs" ] = wfs
183
205
184
- return result
206
+ return out
185
207
186
208
187
- def _gaussweights_1d (l , n , l_0 = 3 , gauss_scale = 0.5 , gauss_scale_0 = 0.5 ):
188
- e = pow (0.5 * l / l_0 , 1.0 / (n - 2 ))
189
- r = [(l_0 * pow (e , k - 1 ), l_0 * pow (e , k )) for k in range (1 , n - 1 )]
209
+ def _gaussweights_1d (l , n , gauss_scale = 0.5 ):
210
+ q = pow (0.5 * l , 1.0 / n )
211
+ r = [(pow (q , k - 1 ), pow (q , k )) for k in range (1 , n + 1 )]
212
+ r = [0.5 * (r_ [0 ] + r_ [1 ]) for r_ in r ]
190
213
191
214
def log_e (x ):
192
215
if len (np .shape (x )) > 0 :
193
216
res = np .empty (x .shape )
194
217
res [x == 0 ] = 0.0
195
- res [x > 0 ] = np .log (x [x > 0 ]) / np .log (e )
218
+ res [x > 0 ] = np .log (x [x > 0 ]) / np .log (q )
196
219
else :
197
220
if x == 0.0 :
198
221
res = 0.0
199
222
else :
200
- res = np .log (x ) / np .log (e )
223
+ res = np .log (x ) / np .log (q )
201
224
202
225
return res
203
226
@@ -211,25 +234,11 @@ def __call__(self, x):
211
234
return np .exp (- (x ** 2.0 ) / (2.0 * self .s ** 2.0 ))
212
235
213
236
weight_funcs = []
214
- central_wavenumbers = [0.0 ]
215
-
216
- weight_funcs .append (GaussFunc (0.0 , gauss_scale_0 ))
237
+ central_wavenumbers = []
217
238
218
239
for i , ri in enumerate (r ):
219
- rc = log_e (ri [ 0 ] )
240
+ rc = log_e (ri )
220
241
weight_funcs .append (GaussFunc (rc , gauss_scale ))
221
- central_wavenumbers .append (ri [0 ])
222
-
223
- gf = GaussFunc (log_e (l / 2 ), gauss_scale )
224
-
225
- def g (x ):
226
- res = np .ones (x .shape )
227
- mask = x <= l / 2
228
- res [mask ] = gf (x [mask ])
229
-
230
- return res
231
-
232
- weight_funcs .append (g )
233
- central_wavenumbers .append (l / 2 )
242
+ central_wavenumbers .append (ri )
234
243
235
244
return weight_funcs , central_wavenumbers
0 commit comments