Skip to content

Commit 953f799

Browse files
authored
fix: make sure netcdf exporters can handle list of timesteps (#369)
1 parent 3121cbf commit 953f799

File tree

2 files changed

+38
-16
lines changed

2 files changed

+38
-16
lines changed

pysteps/io/exporters.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,10 @@ def initialize_forecast_exporter_netcdf(
392392
Start date of the forecast.
393393
timestep: int
394394
Time step of the forecast (minutes).
395-
n_timesteps: int
396-
Number of time steps in the forecast this argument is ignored if
397-
incremental is set to 'timestep'.
395+
n_timesteps: int or list of integers
396+
Number of time steps to forecast or a list of time steps for which the
397+
forecasts are computed (relative to the input time step). The elements of
398+
the list are required to be in ascending order.
398399
shape: tuple of int
399400
Two-element tuple defining the shape (height,width) of the forecast
400401
grids.
@@ -460,8 +461,14 @@ def initialize_forecast_exporter_netcdf(
460461
+ "'timestep' or 'member'"
461462
)
462463

464+
n_timesteps_is_list = isinstance(n_timesteps, list)
465+
if n_timesteps_is_list:
466+
num_timesteps = len(n_timesteps)
467+
else:
468+
num_timesteps = n_timesteps
469+
463470
if incremental == "timestep":
464-
n_timesteps = None
471+
num_timesteps = None
465472
elif incremental == "member":
466473
n_ens_members = None
467474
elif incremental is not None:
@@ -498,7 +505,7 @@ def initialize_forecast_exporter_netcdf(
498505
h, w = shape
499506

500507
ncf.createDimension("ens_number", size=n_ens_members)
501-
ncf.createDimension("time", size=n_timesteps)
508+
ncf.createDimension("time", size=num_timesteps)
502509
ncf.createDimension("y", size=h)
503510
ncf.createDimension("x", size=w)
504511

@@ -585,7 +592,10 @@ def initialize_forecast_exporter_netcdf(
585592

586593
var_time = ncf.createVariable("time", int, dimensions=("time",))
587594
if incremental != "timestep":
588-
var_time[:] = [i * timestep * 60 for i in range(1, n_timesteps + 1)]
595+
if n_timesteps_is_list:
596+
var_time[:] = np.array(n_timesteps) * timestep * 60
597+
else:
598+
var_time[:] = [i * timestep * 60 for i in range(1, n_timesteps + 1)]
589599
var_time.long_name = "forecast time"
590600
startdate_str = datetime.strftime(startdate, "%Y-%m-%d %H:%M:%S")
591601
var_time.units = "seconds since %s" % startdate_str
@@ -635,7 +645,8 @@ def initialize_forecast_exporter_netcdf(
635645
exporter["timestep"] = timestep
636646
exporter["metadata"] = metadata
637647
exporter["incremental"] = incremental
638-
exporter["num_timesteps"] = n_timesteps
648+
exporter["num_timesteps"] = num_timesteps
649+
exporter["timesteps"] = n_timesteps
639650
exporter["num_ens_members"] = n_ens_members
640651
exporter["shape"] = shape
641652

@@ -853,7 +864,12 @@ def _export_netcdf(field, exporter):
853864
else:
854865
var_f[var_f.shape[0], :, :] = field
855866
var_time = exporter["var_time"]
856-
var_time[len(var_time) - 1] = len(var_time) * exporter["timestep"] * 60
867+
if isinstance(exporter["timesteps"], list):
868+
var_time[len(var_time) - 1] = (
869+
exporter["timesteps"][len(var_time) - 1] * exporter["timestep"] * 60
870+
)
871+
else:
872+
var_time[len(var_time) - 1] = len(var_time) * exporter["timestep"] * 60
857873
else:
858874
var_f[var_f.shape[0], :, :, :] = field
859875
var_ens_num = exporter["var_ens_num"]

pysteps/tests/test_exporters.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,17 @@
2323
"fill_value",
2424
"scale_factor",
2525
"offset",
26+
"n_timesteps",
2627
)
2728

2829
exporter_arg_values = [
29-
(1, None, np.float32, None, None, None),
30-
(1, "timestep", np.float32, 65535, None, None),
31-
(2, None, np.float32, 65535, None, None),
32-
(2, "timestep", np.float32, None, None, None),
33-
(2, "member", np.float64, None, 0.01, 1.0),
30+
(1, None, np.float32, None, None, None, 3),
31+
(1, "timestep", np.float32, 65535, None, None, 3),
32+
(2, None, np.float32, 65535, None, None, 3),
33+
(2, None, np.float32, 65535, None, None, [1, 2, 4]),
34+
(2, "timestep", np.float32, None, None, None, 3),
35+
(2, "timestep", np.float32, None, None, None, [1, 2, 4]),
36+
(2, "member", np.float64, None, 0.01, 1.0, 3),
3437
]
3538

3639

@@ -54,7 +57,7 @@ def test_get_geotiff_filename():
5457

5558
@pytest.mark.parametrize(exporter_arg_names, exporter_arg_values)
5659
def test_io_export_netcdf_one_member_one_time_step(
57-
n_ens_members, incremental, datatype, fill_value, scale_factor, offset
60+
n_ens_members, incremental, datatype, fill_value, scale_factor, offset, n_timesteps
5861
):
5962
"""
6063
Test the export netcdf.
@@ -75,7 +78,6 @@ def test_io_export_netcdf_one_member_one_time_step(
7578
file_path = os.path.join(outpath, outfnprefix + ".nc")
7679
startdate = metadata["timestamps"][0]
7780
timestep = metadata["accutime"]
78-
n_timesteps = 3
7981
shape = precip.shape[1:]
8082

8183
exporter = initialize_forecast_exporter_netcdf(
@@ -100,7 +102,11 @@ def test_io_export_netcdf_one_member_one_time_step(
100102
if incremental == None:
101103
export_forecast_dataset(precip, exporter)
102104
if incremental == "timestep":
103-
for t in range(n_timesteps):
105+
if isinstance(n_timesteps, list):
106+
timesteps = len(n_timesteps)
107+
else:
108+
timesteps = n_timesteps
109+
for t in range(timesteps):
104110
if n_ens_members > 1:
105111
export_forecast_dataset(precip[:, t, :, :], exporter)
106112
else:

0 commit comments

Comments
 (0)