Skip to content

Commit 8936236

Browse files
authored
Place ensemble member number determination for blending inside forecast loop to prevent out of memory issues (#273)
Determine which member is blended with which NWP member per time step instead of at once to reduce memory usage and requirements.
1 parent b44c945 commit 8936236

File tree

2 files changed

+93
-102
lines changed

2 files changed

+93
-102
lines changed

pysteps/blending/steps.py

Lines changed: 91 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
method to decompose and store the NWP model fields whenever a new NWP model
1717
field is present, is provided in pysteps.blending.utils.decompose_NWP.
1818
#. Estimate AR parameters for the extrapolation nowcast and noise cascade.
19-
#. Before starting the forecast loop, determine which NWP models will be
20-
combined with which nowcast ensemble members. The number of output ensemble
21-
members equals the maximum number of (ensemble) members in the input, which
22-
can be either the defined number of (nowcast) ensemble members or the number
23-
of NWP models/members.
2419
#. Initialize all the random generators.
2520
#. Calculate the initial skill of the NWP model forecasts at t=0.
2621
#. Start the forecasting loop:
22+
#. Determine which NWP models will be combined with which nowcast ensemble
23+
member. The number of output ensemble members equals the maximum number
24+
of (ensemble) members in the input, which can be either the defined
25+
number of (nowcast) ensemble members or the number of NWP models/members.
2726
#. Determine the skill and weights of the forecasting components
2827
(extrapolation, NWP and noise) for that lead time.
2928
#. Regress the extrapolation and noise cascades separately to the subsequent
@@ -149,7 +148,9 @@ def forecast(
149148
Datetime object containing the date and time for which the forecast
150149
is issued.
151150
n_ens_members: int
152-
The number of ensemble members to generate.
151+
The number of ensemble members to generate. This number should always be
152+
equal to or larger than the number of NWP ensemble members / number of
153+
NWP models.
153154
n_cascade_levels: int, optional
154155
The number of cascade levels to use. Default set to 8 due to default
155156
climatological skill values on 8 levels.
@@ -561,34 +562,19 @@ def forecast(
561562
precip_cascade, ar_order, n_cascade_levels, MASK_thr
562563
)
563564

564-
# 5. Before calling the worker for the forecast loop, determine which (NWP)
565-
# models will be combined with which nowcast ensemble members. With the
566-
# way it is implemented at this moment: n_ens_members of the output equals
567-
# the maximum number of (ensemble) members in the input (either the nowcasts or NWP).
568-
(
569-
precip_cascade,
570-
precip_models_cascade,
571-
precip_models_pm,
572-
velocity_models,
573-
mu_models,
574-
sigma_models,
575-
n_ens_members,
576-
n_model_indices,
577-
) = _find_nwp_combination(
578-
precip_cascade,
579-
precip_models_cascade,
580-
precip_models_pm,
581-
velocity_models,
582-
mu_models,
583-
sigma_models,
584-
n_ens_members,
585-
ar_order,
586-
n_cascade_levels,
587-
blend_nwp_members,
588-
)
565+
# 5. Repeat precip_cascade for n ensemble members
566+
# First, discard all except the last p cascades because they are not needed
567+
# for the AR(p) model
568+
precip_cascade = [precip_cascade[i][-ar_order:] for i in range(n_cascade_levels)]
569+
570+
precip_cascade = [
571+
[precip_cascade[j].copy() for j in range(n_cascade_levels)]
572+
for i in range(n_ens_members)
573+
]
574+
precip_cascade = np.stack(precip_cascade)
589575

590576
# Also initialize the cascade of temporally correlated noise, which has the
591-
# same shape as R_c, but starts with value zero.
577+
# same shape as precip_cascade, but starts with value zero.
592578
noise_cascade = np.zeros(precip_cascade.shape)
593579

594580
# 6. Initialize all the random generators and prepare for the forecast loop
@@ -617,17 +603,7 @@ def forecast(
617603

618604
precip = precip[-1, :, :]
619605

620-
# 7. Calculate the initial skill of the (NWP) model forecasts at t=0
621-
rho_nwp_models = _compute_initial_nwp_skill(
622-
precip_cascade,
623-
precip_models_cascade,
624-
domain_mask,
625-
issuetime,
626-
outdir_path_skill,
627-
clim_kwargs,
628-
)
629-
630-
# Also initizalize the current and previous extrapolation forecast scale
606+
# 7. Initialize the current and previous extrapolation forecast scale
631607
# for the nowcasting component
632608
rho_extr_prev = np.repeat(1.0, PHI.shape[0])
633609
rho_extr = PHI[:, 0] / (1.0 - PHI[:, 1]) # phi1 / (1 - phi2), see BPS2004
@@ -681,9 +657,43 @@ def forecast(
681657
if measure_time:
682658
starttime = time.time()
683659

684-
# 8.1.1 Determine the skill of the components for lead time (t0 + t)
685-
# First for the extrapolation component. Only calculate it when t > 0.
660+
# 8.1.1 Before calling the worker for this time step, determine which (NWP)
661+
# models will be combined with which nowcast ensemble members. With the
662+
# way it is implemented at this moment: n_ens_members of the output equals
663+
# the maximum number of (ensemble) members in the input (either the nowcasts or NWP).
664+
(
665+
precip_models_cascade_temp,
666+
precip_models_pm_temp,
667+
velocity_models_temp,
668+
mu_models_temp,
669+
sigma_models_temp,
670+
n_model_indices,
671+
) = _find_nwp_combination(
672+
precip_models_cascade[:, t, :, :, :],
673+
precip_models_pm[:, t, :, :],
674+
velocity_models[:, t, :, :, :],
675+
mu_models[:, t, :],
676+
sigma_models[:, t, :],
677+
n_ens_members,
678+
ar_order,
679+
n_cascade_levels,
680+
blend_nwp_members,
681+
)
682+
683+
if t == 0:
684+
# 8.1.2 Calculate the initial skill of the (NWP) model forecasts at t=0
685+
rho_nwp_models = _compute_initial_nwp_skill(
686+
precip_cascade,
687+
precip_models_cascade_temp,
688+
domain_mask,
689+
issuetime,
690+
outdir_path_skill,
691+
clim_kwargs,
692+
)
693+
686694
if t > 0:
695+
# 8.1.3 Determine the skill of the components for lead time (t0 + t)
696+
# First for the extrapolation component. Only calculate it when t > 0.
687697
(
688698
rho_extr,
689699
rho_extr_prev,
@@ -740,18 +750,22 @@ def worker(j):
740750
# Only the weights of the components without the extrapolation
741751
# cascade will be determined here. The full set of weights are
742752
# determined after the extrapolation step in this method.
743-
if blend_nwp_members and precip_models_cascade.shape[0] > 1:
753+
if blend_nwp_members and precip_models_cascade_temp.shape[0] > 1:
744754
weights_model_only = np.zeros(
745-
(precip_models_cascade.shape[0] + 1, n_cascade_levels)
755+
(precip_models_cascade_temp.shape[0] + 1, n_cascade_levels)
746756
)
747757
for i in range(n_cascade_levels):
748758
# Determine the normalized covariance matrix (containing
748759
# the cross-correlations between the models)
750760
cov = np.corrcoef(
751761
np.stack(
752762
[
753-
precip_models_cascade[n_model, t, i, :, :].flatten()
754-
for n_model in range(precip_models_cascade.shape[0])
763+
precip_models_cascade_temp[
764+
n_model, i, :, :
765+
].flatten()
766+
for n_model in range(
767+
precip_models_cascade_temp.shape[0]
768+
)
755769
]
756770
)
757771
)
@@ -883,12 +897,12 @@ def worker(j):
883897
V_stack = np.concatenate(
884898
(
885899
velocity_pert[None, :, :, :],
886-
velocity_models[:, t, :, :, :],
900+
velocity_models_temp,
887901
),
888902
axis=0,
889903
)
890904
else:
891-
V_model_ = velocity_models[j, t, :, :, :]
905+
V_model_ = velocity_models_temp[j]
892906
V_stack = np.concatenate(
893907
(velocity_pert[None, :, :, :], V_model_[None, :, :, :]),
894908
axis=0,
@@ -983,11 +997,11 @@ def worker(j):
983997
# Stack the perturbed extrapolation and the NWP velocities
984998
if blend_nwp_members:
985999
V_stack = np.concatenate(
986-
(velocity_pert[None, :, :, :], velocity_models[:, t, :, :, :]),
1000+
(velocity_pert[None, :, :, :], velocity_models_temp),
9871001
axis=0,
9881002
)
9891003
else:
990-
V_model_ = velocity_models[j, t, :, :, :]
1004+
V_model_ = velocity_models_temp[j]
9911005
V_stack = np.concatenate(
9921006
(velocity_pert[None, :, :, :], V_model_[None, :, :, :]), axis=0
9931007
)
@@ -1050,32 +1064,32 @@ def worker(j):
10501064
cascades_stacked = np.concatenate(
10511065
(
10521066
R_f_ep_out[None, t_index],
1053-
precip_models_cascade[:, t],
1067+
precip_models_cascade_temp,
10541068
Yn_ep_out[None, t_index],
10551069
),
10561070
axis=0,
10571071
) # [(extr_field, n_model_fields, noise), n_cascade_levels, ...]
10581072
means_stacked = np.concatenate(
1059-
(mu_extrapolation[None, :], mu_models[:, t]), axis=0
1073+
(mu_extrapolation[None, :], mu_models_temp), axis=0
10601074
)
10611075
sigmas_stacked = np.concatenate(
1062-
(sigma_extrapolation[None, :], sigma_models[:, t]),
1076+
(sigma_extrapolation[None, :], sigma_models_temp),
10631077
axis=0,
10641078
)
10651079
else:
10661080
cascades_stacked = np.concatenate(
10671081
(
10681082
R_f_ep_out[None, t_index],
1069-
precip_models_cascade[None, j, t],
1083+
precip_models_cascade_temp[None, j],
10701084
Yn_ep_out[None, t_index],
10711085
),
10721086
axis=0,
10731087
) # [(extr_field, n_model_fields, noise), n_cascade_levels, ...]
10741088
means_stacked = np.concatenate(
1075-
(mu_extrapolation[None, :], mu_models[None, j, t]), axis=0
1089+
(mu_extrapolation[None, :], mu_models_temp[None, j]), axis=0
10761090
)
10771091
sigmas_stacked = np.concatenate(
1078-
(sigma_extrapolation[None, :], sigma_models[None, j, t]),
1092+
(sigma_extrapolation[None, :], sigma_models_temp[None, j]),
10791093
axis=0,
10801094
)
10811095

@@ -1173,15 +1187,15 @@ def worker(j):
11731187
R_pm_stacked = np.concatenate(
11741188
(
11751189
R_pm_ep[None, t_index],
1176-
precip_models_pm[:, t],
1190+
precip_models_pm_temp,
11771191
),
11781192
axis=0,
11791193
)
11801194
else:
11811195
R_pm_stacked = np.concatenate(
11821196
(
11831197
R_pm_ep[None, t_index],
1184-
precip_models_pm[None, j, t],
1198+
precip_models_pm_temp[None, j],
11851199
),
11861200
axis=0,
11871201
)
@@ -1198,11 +1212,11 @@ def worker(j):
11981212
weights_pm_normalized_mod_only.reshape(
11991213
weights_pm_normalized_mod_only.shape[0], 1, 1
12001214
)
1201-
* precip_models_pm[:, t],
1215+
* precip_models_pm_temp,
12021216
axis=0,
12031217
)
12041218
else:
1205-
R_pm_blended_mod_only = precip_models_pm[j, t]
1219+
R_pm_blended_mod_only = precip_models_pm_temp[j]
12061220

12071221
# The extrapolation components are NaN outside the advected
12081222
# radar domain. This results in NaN values in the blended
@@ -1816,7 +1830,6 @@ def _estimate_ar_parameters_radar(R_c, ar_order, n_cascade_levels, MASK_thr):
18161830

18171831

18181832
def _find_nwp_combination(
1819-
R_c,
18201833
precip_models,
18211834
R_models_pm,
18221835
velocity_models,
@@ -1831,26 +1844,22 @@ def _find_nwp_combination(
18311844
With the way it is implemented at this moment: n_ens_members of the output equals
18321845
the maximum number of (ensemble) members in the input (either the nowcasts or NWP).
18331846
"""
1834-
###
1835-
# First, discard all except the p-1 last cascades because they are not needed
1836-
# for the AR(p) model
1837-
R_c = [R_c[i][-ar_order:] for i in range(n_cascade_levels)]
1847+
# Make sure the number of model members is not larger than
1848+
# n_ens_members
1849+
n_model_members = precip_models.shape[0]
1850+
if n_model_members > n_ens_members:
1851+
raise ValueError(
1852+
"The number of NWP model members is larger than the given number of ensemble members. n_model_members <= n_ens_members."
1853+
)
18381854

18391855
# Check if NWP models/members should be used individually, or if all of
18401856
# them are blended together per nowcast ensemble member.
18411857
if blend_nwp_members:
1842-
# stack the extrapolation cascades into a list containing all ensemble members
1843-
R_c = [
1844-
[R_c[j].copy() for j in range(n_cascade_levels)]
1845-
for i in range(n_ens_members)
1846-
]
1847-
R_c = np.stack(R_c)
18481858
n_model_indices = None
18491859

18501860
else:
18511861
# Start with determining the maximum and minimum number of members/models
18521862
# in both input products
1853-
n_model_members = precip_models.shape[0]
18541863
n_ens_members_max = max(n_ens_members, n_model_members)
18551864
n_ens_members_min = min(n_ens_members, n_model_members)
18561865
# Also make a list of the model index numbers. These indices are needed
@@ -1871,20 +1880,12 @@ def _find_nwp_combination(
18711880
# member 5, etc.), until 10 is reached.
18721881
if n_ens_members_min != n_ens_members_max:
18731882
if n_model_members == 1:
1874-
precip_models = np.repeat(
1875-
precip_models[:, :, :, :, :], n_ens_members_max, axis=0
1876-
)
1877-
mu_models = np.repeat(mu_models[:, :, :], n_ens_members_max, axis=0)
1878-
sigma_models = np.repeat(
1879-
sigma_models[:, :, :], n_ens_members_max, axis=0
1880-
)
1881-
velocity_models = np.repeat(
1882-
velocity_models[:, :, :, :], n_ens_members_max, axis=0
1883-
)
1883+
precip_models = np.repeat(precip_models, n_ens_members_max, axis=0)
1884+
mu_models = np.repeat(mu_models, n_ens_members_max, axis=0)
1885+
sigma_models = np.repeat(sigma_models, n_ens_members_max, axis=0)
1886+
velocity_models = np.repeat(velocity_models, n_ens_members_max, axis=0)
18841887
# For the prob. matching
1885-
R_models_pm = np.repeat(
1886-
R_models_pm[:, :, :, :], n_ens_members_max, axis=0
1887-
)
1888+
R_models_pm = np.repeat(R_models_pm, n_ens_members_max, axis=0)
18881889
# Finally, for the model indices
18891890
n_model_indices = np.repeat(n_model_indices, n_ens_members_max, axis=0)
18901891

@@ -1903,22 +1904,12 @@ def _find_nwp_combination(
19031904
# Finally, for the model indices
19041905
n_model_indices = np.repeat(n_model_indices, repeats, axis=0)
19051906

1906-
R_c = [
1907-
[R_c[j].copy() for j in range(n_cascade_levels)]
1908-
for i in range(n_ens_members_max)
1909-
]
1910-
R_c = np.stack(R_c)
1911-
1912-
n_ens_members = n_ens_members_max
1913-
19141907
return (
1915-
R_c,
19161908
precip_models,
19171909
R_models_pm,
19181910
velocity_models,
19191911
mu_models,
19201912
sigma_models,
1921-
n_ens_members,
19221913
n_model_indices,
19231914
)
19241915

@@ -2015,7 +2006,7 @@ def _compute_initial_nwp_skill(
20152006
rho_nwp_models = [
20162007
blending.skill_scores.spatial_correlation(
20172008
obs=R_c[0, :, -1, :, :],
2018-
mod=precip_models[n_model, 0, :, :, :],
2009+
mod=precip_models[n_model, :, :, :],
20192010
domain_mask=domain_mask,
20202011
)
20212012
for n_model in range(precip_models.shape[0])
@@ -2024,7 +2015,7 @@ def _compute_initial_nwp_skill(
20242015

20252016
# Ensure that the model skill decreases with increasing scale level.
20262017
for n_model in range(precip_models.shape[0]):
2027-
for i in range(1, precip_models.shape[2]):
2018+
for i in range(1, precip_models.shape[1]):
20282019
if rho_nwp_models[n_model, i] > rho_nwp_models[n_model, i - 1]:
20292020
# Set it equal to the previous scale level
20302021
rho_nwp_models[n_model, i] = rho_nwp_models[n_model, i - 1]

pysteps/tests/test_blending_steps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4),
3232
(1, 3, 4, 9, "incremental", "cdf", False, "spn", True, 4),
3333
(2, 3, 10, 8, "incremental", "cdf", False, "spn", True, 10),
34-
(5, 3, 4, 8, "incremental", "cdf", False, "spn", True, 5),
34+
(5, 3, 5, 8, "incremental", "cdf", False, "spn", True, 5),
3535
(1, 10, 1, 8, "incremental", "cdf", False, "spn", True, 1),
36-
(5, 3, 2, 8, "incremental", "cdf", True, "spn", True, 2),
36+
(2, 3, 2, 8, "incremental", "cdf", True, "spn", True, 2),
3737
]
3838

3939

0 commit comments

Comments
 (0)