Skip to content

Commit 8bec82e

Browse files
authored
Fix support for timesteps list in case model has rain but radar does not (#411)
* Fix support for timesteps list in case model has rain but radar does not * One more small fix and update tests * black * remove accidentally added duplicate test cases
1 parent 07a5aa8 commit 8bec82e

File tree

2 files changed

+37
-36
lines changed

2 files changed

+37
-36
lines changed

pysteps/blending/steps.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,10 @@
5050
from scipy.linalg import inv
5151
from scipy.ndimage import binary_dilation, generate_binary_structure, iterate_structure
5252

53-
from pysteps import cascade
54-
from pysteps import extrapolation
55-
from pysteps import noise
56-
from pysteps import utils
53+
from pysteps import blending, cascade, extrapolation, noise, utils
5754
from pysteps.nowcasts import utils as nowcast_utils
5855
from pysteps.postprocessing import probmatching
5956
from pysteps.timeseries import autoregression, correlation
60-
from pysteps import blending
6157

6258
try:
6359
import dask
@@ -578,6 +574,14 @@ def forecast(
578574
precip_models_pm, precip_thr, norain_thr
579575
)
580576

577+
if isinstance(timesteps, int):
578+
timesteps = list(range(timesteps + 1))
579+
timestep_type = "int"
580+
else:
581+
original_timesteps = [0] + list(timesteps)
582+
timesteps = nowcast_utils.binned_timesteps(original_timesteps)
583+
timestep_type = "list"
584+
581585
# 2.3.1 If precip is below the norain threshold and precip_models_pm is zero,
582586
# we consider it as no rain in the domain.
583587
# The forecast will directly return an array filled with the minimum
@@ -591,14 +595,6 @@ def forecast(
591595
# Create the output list
592596
R_f = [[] for j in range(n_ens_members)]
593597

594-
if isinstance(timesteps, int):
595-
timesteps = range(timesteps + 1)
596-
timestep_type = "int"
597-
else:
598-
original_timesteps = [0] + list(timesteps)
599-
timesteps = nowcast_utils.binned_timesteps(original_timesteps)
600-
timestep_type = "list"
601-
602598
# Save per time step to ensure the array does not become too large if
603599
# no return_output is requested and callback is not None.
604600
for t, subtimestep_idx in enumerate(timesteps):
@@ -610,12 +606,13 @@ def forecast(
610606
R_f_ = np.full(
611607
(n_ens_members, precip_shape[0], precip_shape[1]), np.nanmin(precip)
612608
)
613-
if callback is not None:
614-
if R_f_.shape[1] > 0:
615-
callback(R_f_.squeeze())
616-
if return_output:
617-
for j in range(n_ens_members):
618-
R_f[j].append(R_f_[j])
609+
if subtimestep_idx:
610+
if callback is not None:
611+
if R_f_.shape[1] > 0:
612+
callback(R_f_.squeeze())
613+
if return_output:
614+
for j in range(n_ens_members):
615+
R_f[j].append(R_f_[j])
619616

620617
R_f_ = None
621618

@@ -680,7 +677,8 @@ def forecast(
680677
precip_models_pm, precip_thr, precip_models_pm.shape[0], timesteps
681678
)
682679
# Make sure precip_noise_input is three dimensional
683-
precip_noise_input = precip_noise_input[np.newaxis, :, :]
680+
if len(precip_noise_input.shape) != 3:
681+
precip_noise_input = precip_noise_input[np.newaxis, :, :]
684682
else:
685683
precip_noise_input = precip.copy()
686684

@@ -782,14 +780,6 @@ def forecast(
782780
if measure_time:
783781
starttime_mainloop = time.time()
784782

785-
if isinstance(timesteps, int):
786-
timesteps = range(timesteps + 1)
787-
timestep_type = "int"
788-
else:
789-
original_timesteps = [0] + list(timesteps)
790-
timesteps = nowcast_utils.binned_timesteps(original_timesteps)
791-
timestep_type = "list"
792-
793783
extrap_kwargs["return_displacement"] = True
794784
forecast_prev = precip_cascade
795785
noise_prev = noise_cascade
@@ -2498,7 +2488,7 @@ def _determine_max_nr_rainy_cells_nwp(
24982488
max_rain_pixels_j = -1
24992489
max_rain_pixels_t = -1
25002490
for j in range(n_models):
2501-
for t in range(timesteps):
2491+
for t in timesteps:
25022492
rain_pixels = precip_models_pm[j][t][
25032493
precip_models_pm[j][t] > precip_thr
25042494
].size

pysteps/tests/test_blending_steps.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# -*- coding: utf-8 -*-
22

3-
import numpy as np
43
import datetime
4+
5+
import numpy as np
56
import pytest
6-
import pysteps
7-
from pysteps import cascade, blending
87

8+
import pysteps
9+
from pysteps import blending, cascade
910

1011
steps_arg_values = [
1112
(1, 3, 4, 8, None, None, False, "spn", True, 4, False, False, 0, False),
@@ -14,6 +15,7 @@
1415
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, False),
1516
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, True),
1617
(1, 3, 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False),
18+
(1, [1, 2, 3], 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False),
1719
(1, 3, 4, 8, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False),
1820
(1, 3, 4, 6, "incremental", "cdf", False, "bps", True, 4, False, False, 0, False),
1921
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, False),
@@ -42,11 +44,13 @@
4244
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 80, False),
4345
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 80, True),
4446
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
47+
(5, [1, 2, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
48+
(5, [1, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
4549
]
4650

4751
steps_arg_names = (
4852
"n_models",
49-
"n_timesteps",
53+
"timesteps",
5054
"n_ens_members",
5155
"n_cascade_levels",
5256
"mask_method",
@@ -65,7 +69,7 @@
6569
@pytest.mark.parametrize(steps_arg_names, steps_arg_values)
6670
def test_steps_blending(
6771
n_models,
68-
n_timesteps,
72+
timesteps,
6973
n_ens_members,
7074
n_cascade_levels,
7175
mask_method,
@@ -85,7 +89,14 @@ def test_steps_blending(
8589
# The input data
8690
###
8791
# Initialise dummy NWP data
88-
nwp_precip = np.zeros((n_models, n_timesteps + 1, 200, 200))
92+
if not isinstance(timesteps, int):
93+
n_timesteps = len(timesteps)
94+
last_timestep = timesteps[-1]
95+
else:
96+
n_timesteps = timesteps
97+
last_timestep = timesteps
98+
99+
nwp_precip = np.zeros((n_models, last_timestep + 1, 200, 200))
89100

90101
if not zero_nwp:
91102
for n_model in range(n_models):
@@ -250,7 +261,7 @@ def test_steps_blending(
250261
precip_models=nwp_precip_decomp,
251262
velocity=radar_velocity,
252263
velocity_models=nwp_velocity,
253-
timesteps=n_timesteps,
264+
timesteps=timesteps,
254265
timestep=5.0,
255266
issuetime=datetime.datetime.strptime("202112012355", "%Y%m%d%H%M"),
256267
n_ens_members=n_ens_members,

0 commit comments

Comments
 (0)