Skip to content

Commit 098927d

Browse files
mpvginde and dnerini authored
Bugfix reproducibility & ensemble member order with dask (#347)
* Bugfix: fix random placement of ensemble members in numpy array due to dask multi-threading (#337)
* Bugfix: make STEPS (blending) nowcast reproducible when the seed argument is given (#346)
* Bugfix: make STEPS (blending) nowcast reproducible, independent of number of workers (#346)
* Formatting with black

---------

Co-authored-by: ned <daniele.nerini@meteoswiss.ch>
1 parent 3167a11 commit 098927d

File tree

3 files changed

+29
-27
lines changed

3 files changed

+29
-27
lines changed

pysteps/blending/steps.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ def forecast(
514514
)
515515

516516
# 2. Initialize the noise method
517+
np.random.seed(seed)
517518
pp, generate_noise, noise_std_coeffs = _init_noise(
518519
precip,
519520
precip_thr,
@@ -526,6 +527,7 @@ def forecast(
526527
noise_stddev_adj,
527528
measure_time,
528529
num_workers,
530+
seed,
529531
)
530532

531533
# 3. Perform the cascade decomposition for the input precip fields and
@@ -1662,6 +1664,7 @@ def _init_noise(
16621664
noise_stddev_adj,
16631665
measure_time,
16641666
num_workers,
1667+
seed,
16651668
):
16661669
"""Initialize the noise method."""
16671670
if noise_method is None:
@@ -1690,6 +1693,7 @@ def _init_noise(
16901693
20,
16911694
conditional=True,
16921695
num_workers=num_workers,
1696+
seed=seed,
16931697
)
16941698

16951699
if measure_time:
@@ -1944,7 +1948,6 @@ def _init_random_generators(
19441948
if noise_method is not None:
19451949
randgen_prec = []
19461950
randgen_motion = []
1947-
np.random.seed(seed)
19481951
for j in range(n_ens_members):
19491952
rs = np.random.RandomState(seed)
19501953
randgen_prec.append(rs)

pysteps/noise/utils.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -96,39 +96,37 @@ def compute_noise_stddev_adjs(
9696

9797
if dask_imported and num_workers > 1:
9898
res = []
99-
else:
100-
N_stds = []
10199

100+
N_stds = [None] * num_iter
102101
randstates = []
103-
seed = None
102+
104103
for k in range(num_iter):
105104
randstates.append(np.random.RandomState(seed=seed))
106105
seed = np.random.randint(0, high=1e9)
107106

108-
for k in range(num_iter):
109-
110-
def worker():
111-
# generate Gaussian white noise field, filter it using the chosen
112-
# method, multiply it with the standard deviation of the observed
113-
# field and apply the precipitation mask
114-
N = noise_generator(noise_filter, randstate=randstates[k], seed=seed)
115-
N = N / np.std(N) * sigma + mu
116-
N[~MASK] = R_thr_2
107+
def worker(k):
108+
# generate Gaussian white noise field, filter it using the chosen
109+
# method, multiply it with the standard deviation of the observed
110+
# field and apply the precipitation mask
111+
N = noise_generator(noise_filter, randstate=randstates[k])
112+
N = N / np.std(N) * sigma + mu
113+
N[~MASK] = R_thr_2
117114

118-
# subtract the mean and decompose the masked noise field into a
119-
# cascade
120-
N -= mu
121-
decomp_N = decomp_method(N, F, mask=MASK_)
115+
# subtract the mean and decompose the masked noise field into a
116+
# cascade
117+
N -= mu
118+
decomp_N = decomp_method(N, F, mask=MASK_)
122119

123-
return decomp_N["stds"]
124-
125-
if dask_imported and num_workers > 1:
126-
res.append(dask.delayed(worker)())
127-
else:
128-
N_stds.append(worker())
120+
N_stds[k] = decomp_N["stds"]
129121

130122
if dask_imported and num_workers > 1:
131-
N_stds = dask.compute(*res, num_workers=num_workers)
123+
for k in range(num_iter):
124+
res.append(dask.delayed(worker)(k))
125+
dask.compute(*res, num_workers=num_workers)
126+
127+
else:
128+
for k in range(num_iter):
129+
worker(k)
132130

133131
# for each cascade level, compare the standard deviations between the
134132
# observed field and the masked noise field, which gives the correction

pysteps/nowcasts/steps.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ def f(precip, i):
443443
precip[i, ~np.isfinite(precip[i, :])] = np.nanmin(precip[i, :])
444444

445445
if noise_method is not None:
446+
np.random.seed(seed)
446447
# get methods for perturbations
447448
init_noise, generate_noise = noise.get_method(noise_method)
448449

@@ -466,6 +467,7 @@ def f(precip, i):
466467
20,
467468
conditional=True,
468469
num_workers=num_workers,
470+
seed=seed,
469471
)
470472

471473
if measure_time:
@@ -543,7 +545,6 @@ def f(precip, i):
543545
if noise_method is not None:
544546
randgen_prec = []
545547
randgen_motion = []
546-
np.random.seed(seed)
547548
for _ in range(n_ens_members):
548549
rs = np.random.RandomState(seed)
549550
randgen_prec.append(rs)
@@ -706,7 +707,7 @@ def _check_inputs(precip, velocity, timesteps, ar_order):
706707

707708

708709
def _update(state, params):
709-
precip_forecast_out = []
710+
precip_forecast_out = [None] * params["n_ens_members"]
710711

711712
if params["noise_method"] is None or params["mask_method"] == "sprog":
712713
for i in range(params["n_cascade_levels"]):
@@ -828,7 +829,7 @@ def worker(j):
828829

829830
precip_forecast[params["domain_mask"]] = np.nan
830831

831-
precip_forecast_out.append(precip_forecast)
832+
precip_forecast_out[j] = precip_forecast
832833

833834
if (
834835
DASK_IMPORTED

0 commit comments

Comments
 (0)