Skip to content

Commit a7dae54

Browse files
authored
Use seed for all rng in blending to make a test run completely deterministic (#450)
* Use seed for all rng in blending to make a test run completely deterministic
* Fix coverage
* Actually add a test that runs the previously uncovered lines
* Add randgen to docstring and add default value
1 parent e332585 commit a7dae54

File tree

5 files changed

+82
-53
lines changed

5 files changed

+82
-53
lines changed

pysteps/blending/steps.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -734,16 +734,19 @@ def forecast(
734734
)
735735

736736
# 6. Initialize all the random generators and prepare for the forecast loop
737-
randgen_prec, vps, generate_vel_noise = _init_random_generators(
738-
velocity,
739-
noise_method,
740-
vel_pert_method,
741-
vp_par,
742-
vp_perp,
743-
seed,
744-
n_ens_members,
745-
kmperpixel,
746-
timestep,
737+
randgen_prec, vps, generate_vel_noise, randgen_probmatching = (
738+
_init_random_generators(
739+
velocity,
740+
noise_method,
741+
probmatching_method,
742+
vel_pert_method,
743+
vp_par,
744+
vp_perp,
745+
seed,
746+
n_ens_members,
747+
kmperpixel,
748+
timestep,
749+
)
747750
)
748751
D, D_Yn, D_pb, R_f, R_m, mask_rim, struct, fft_objs = _prepare_forecast_loop(
749752
precip_cascade,
@@ -1621,6 +1624,7 @@ def worker(j):
16211624
first_array=arr1,
16221625
second_array=arr2,
16231626
probability_first_array=weights_pm_normalized[0],
1627+
randgen=randgen_probmatching[j],
16241628
)
16251629
else:
16261630
R_pm_resampled = R_pm_blended.copy()
@@ -2290,6 +2294,7 @@ def _find_nwp_combination(
22902294
def _init_random_generators(
22912295
velocity,
22922296
noise_method,
2297+
probmatching_method,
22932298
vel_pert_method,
22942299
vp_par,
22952300
vp_perp,
@@ -2299,18 +2304,28 @@ def _init_random_generators(
22992304
timestep,
23002305
):
23012306
"""Initialize all the random generators."""
2307+
randgen_prec = None
23022308
if noise_method is not None:
23032309
randgen_prec = []
2304-
randgen_motion = []
23052310
for j in range(n_ens_members):
23062311
rs = np.random.RandomState(seed)
23072312
randgen_prec.append(rs)
23082313
seed = rs.randint(0, high=1e9)
2314+
2315+
randgen_probmatching = None
2316+
if probmatching_method is not None:
2317+
randgen_probmatching = []
2318+
for j in range(n_ens_members):
23092319
rs = np.random.RandomState(seed)
2310-
randgen_motion.append(rs)
2320+
randgen_probmatching.append(rs)
23112321
seed = rs.randint(0, high=1e9)
23122322

23132323
if vel_pert_method is not None:
2324+
randgen_motion = []
2325+
for j in range(n_ens_members):
2326+
rs = np.random.RandomState(seed)
2327+
randgen_motion.append(rs)
2328+
seed = rs.randint(0, high=1e9)
23142329
init_vel_noise, generate_vel_noise = noise.get_method(vel_pert_method)
23152330

23162331
# initialize the perturbation generators for the motion field
@@ -2326,7 +2341,7 @@ def _init_random_generators(
23262341
else:
23272342
vps, generate_vel_noise = None, None
23282343

2329-
return randgen_prec, vps, generate_vel_noise
2344+
return randgen_prec, vps, generate_vel_noise, randgen_probmatching
23302345

23312346

23322347
def _prepare_forecast_loop(

pysteps/noise/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,9 @@ def compute_noise_stddev_adjs(
101101
randstates = []
102102

103103
for k in range(num_iter):
104-
randstates.append(np.random.RandomState(seed=seed))
105-
seed = np.random.randint(0, high=1e9)
104+
rs = np.random.RandomState(seed=seed)
105+
randstates.append(rs)
106+
seed = rs.randint(0, high=1e9)
106107

107108
def worker(k):
108109
# generate Gaussian white noise field, filter it using the chosen

pysteps/postprocessing/probmatching.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,9 @@ def _get_error(scale):
274274
return shift, scale, R.reshape(shape)
275275

276276

277-
def resample_distributions(first_array, second_array, probability_first_array):
277+
def resample_distributions(
278+
first_array, second_array, probability_first_array, randgen=np.random
279+
):
278280
"""
279281
Merges two distributions (e.g., from the extrapolation nowcast and NWP in the blending module)
280282
to effectively combine two distributions for probability matching without losing extremes.
@@ -291,6 +293,9 @@ def resample_distributions(first_array, second_array, probability_first_array):
291293
probability_first_array: float
292294
The weight that `first_array` should get (a value between 0 and 1). This determines the
293295
likelihood of selecting elements from `first_array` over `second_array`.
296+
randgen: numpy.random or numpy.random.RandomState
297+
The random number generator to be used for the binomial distribution. You can pass a seeded
298+
random state here for reproducibility. Default is numpy.random.
294299
295300
Returns
296301
-------
@@ -324,7 +329,7 @@ def resample_distributions(first_array, second_array, probability_first_array):
324329
n = asort.shape[0]
325330

326331
# Resample the distributions
327-
idxsamples = np.random.binomial(1, probability_first_array, n).astype(bool)
332+
idxsamples = randgen.binomial(1, probability_first_array, n).astype(bool)
328333
csort = np.where(idxsamples, asort, bsort)
329334
csort = np.sort(csort)[::-1]
330335

pysteps/tests/test_blending_steps.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,45 +8,48 @@
88
import pysteps
99
from pysteps import blending, cascade
1010

11+
# fmt:off
1112
steps_arg_values = [
12-
(1, 3, 4, 8, None, None, False, "spn", True, 4, False, False, 0, False),
13-
(1, 3, 4, 8, "obs", None, False, "spn", True, 4, False, False, 0, False),
14-
(1, 3, 4, 8, "incremental", None, False, "spn", True, 4, False, False, 0, False),
15-
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, False),
16-
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, True),
17-
(1, 3, 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False),
18-
(1, [1, 2, 3], 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False),
19-
(1, 3, 4, 8, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False),
20-
(1, 3, 4, 6, "incremental", "cdf", False, "bps", True, 4, False, False, 0, False),
21-
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, False),
22-
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, True),
23-
(1, 3, 4, 9, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False),
24-
(2, 3, 10, 8, "incremental", "cdf", False, "spn", True, 10, False, False, 0, False),
25-
(5, 3, 5, 8, "incremental", "cdf", False, "spn", True, 5, False, False, 0, False),
26-
(1, 10, 1, 8, "incremental", "cdf", False, "spn", True, 1, False, False, 0, False),
27-
(2, 3, 2, 8, "incremental", "cdf", True, "spn", True, 2, False, False, 0, False),
28-
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 0, False),
13+
(1, 3, 4, 8, None, None, False, "spn", True, 4, False, False, 0, False, None),
14+
(1, 3, 4, 8, "obs", None, False, "spn", True, 4, False, False, 0, False, None),
15+
(1, 3, 4, 8, "incremental", None, False, "spn", True, 4, False, False, 0, False, None),
16+
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, False, None),
17+
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, True, None),
18+
(1, 3, 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False, None),
19+
(1, [1, 2, 3], 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False, None),
20+
(1, 3, 4, 8, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False, None),
21+
(1, 3, 4, 6, "incremental", "cdf", False, "bps", True, 4, False, False, 0, False, None),
22+
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, False, None),
23+
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, True, None),
24+
(1, 3, 4, 9, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False, None),
25+
(2, 3, 10, 8, "incremental", "cdf", False, "spn", True, 10, False, False, 0, False, None),
26+
(5, 3, 5, 8, "incremental", "cdf", False, "spn", True, 5, False, False, 0, False, None),
27+
(1, 10, 1, 8, "incremental", "cdf", False, "spn", True, 1, False, False, 0, False, None),
28+
(2, 3, 2, 8, "incremental", "cdf", True, "spn", True, 2, False, False, 0, False, None),
29+
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 0, False, None),
30+
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 0, False, "bps"),
2931
# Test the case where the radar image contains no rain.
30-
(1, 3, 6, 8, None, None, False, "spn", True, 6, True, False, 0, False),
31-
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 0, False),
32-
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 0, True),
32+
(1, 3, 6, 8, None, None, False, "spn", True, 6, True, False, 0, False, None),
33+
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 0, False, None),
34+
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 0, True, None),
3335
# Test the case where the NWP fields contain no rain.
34-
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 0, False),
35-
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, False, True, 0, True),
36+
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 0, False, None),
37+
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, False, True, 0, True, None),
3638
# Test the case where both the radar image and the NWP fields contain no rain.
37-
(1, 3, 6, 8, None, None, False, "spn", True, 6, True, True, 0, False),
38-
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, True, 0, False),
39-
(5, 3, 5, 6, "obs", "mean", True, "spn", True, 5, True, True, 0, False),
39+
(1, 3, 6, 8, None, None, False, "spn", True, 6, True, True, 0, False, None),
40+
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, True, 0, False, None),
41+
(5, 3, 5, 6, "obs", "mean", True, "spn", True, 5, True, True, 0, False, None),
4042
# Test for smooth radar mask
41-
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 80, False),
42-
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, False, False, 80, False),
43-
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, False, False, 80, False),
44-
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 80, False),
45-
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 80, True),
46-
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
47-
(5, [1, 2, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
48-
(5, [1, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
43+
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, False, 80, False, None),
44+
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, False, False, 80, False, None),
45+
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, False, False, 80, False, None),
46+
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 80, False, None),
47+
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 80, True, None),
48+
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False, None),
49+
(5, [1, 2, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False, None),
50+
(5, [1, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False, None),
4951
]
52+
# fmt:on
5053

5154
steps_arg_names = (
5255
"n_models",
@@ -63,6 +66,7 @@
6366
"zero_nwp",
6467
"smooth_radar_mask_range",
6568
"resample_distribution",
69+
"vel_pert_method",
6670
)
6771

6872

@@ -82,6 +86,7 @@ def test_steps_blending(
8286
zero_nwp,
8387
smooth_radar_mask_range,
8488
resample_distribution,
89+
vel_pert_method,
8590
):
8691
pytest.importorskip("cv2")
8792

@@ -275,7 +280,7 @@ def test_steps_blending(
275280
noise_method="nonparametric",
276281
noise_stddev_adj="auto",
277282
ar_order=2,
278-
vel_pert_method=None,
283+
vel_pert_method=vel_pert_method,
279284
weights_method=weights_method,
280285
conditional=False,
281286
probmatching_method=probmatching_method,

pysteps/tests/test_postprocessing_probmatching.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import numpy as np
22
import pytest
3-
from pysteps.postprocessing.probmatching import resample_distributions
4-
from pysteps.postprocessing.probmatching import nonparam_match_empirical_cdf
3+
4+
from pysteps.postprocessing.probmatching import (
5+
nonparam_match_empirical_cdf,
6+
resample_distributions,
7+
)
58

69

710
class TestResampleDistributions:

0 commit comments

Comments
 (0)