Skip to content

Commit b44c945

Browse files
authored
Minor exporters adjustment (#272)
Adjustment to the netcdf exporter to save the forecast correctly when a deterministic forecast is made.
1 parent 2dbd3de commit b44c945

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

pysteps/io/exporters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ def _export_netcdf(field, exporter):
804804
if exporter["num_ens_members"] > 1:
805805
var_f[:, var_f.shape[1], :, :] = field
806806
else:
807-
var_f[var_f.shape[1], :, :] = field
807+
var_f[var_f.shape[0], :, :] = field
808808
var_time = exporter["var_time"]
809809
var_time[len(var_time) - 1] = len(var_time) * exporter["timestep"] * 60
810810
else:

pysteps/tests/test_exporters.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@
1515
from pysteps.io.exporters import initialize_forecast_exporter_netcdf
1616
from pysteps.tests.helpers import get_precipitation_fields, get_invalid_mask
1717

18+
# Test arguments
19+
exporter_arg_names = ("n_ens_members", "incremental")
20+
21+
exporter_arg_values = [
22+
(1, None),
23+
(1, "timestep"),
24+
(2, None),
25+
(2, "timestep"),
26+
(2, "member"),
27+
]
28+
1829

1930
def test_get_geotiff_filename():
2031
"""Test the geotif name generator."""
@@ -34,7 +45,8 @@ def test_get_geotiff_filename():
3445
assert expected == file_name
3546

3647

37-
def test_io_export_netcdf_one_member_one_time_step():
48+
@pytest.mark.parametrize(exporter_arg_names, exporter_arg_values)
49+
def test_io_export_netcdf_one_member_one_time_step(n_ens_members, incremental):
3850
"""
3951
Test the export netcdf.
4052
Also, test that the exported file can be read by the importer.
@@ -43,20 +55,20 @@ def test_io_export_netcdf_one_member_one_time_step():
4355
pytest.importorskip("pyproj")
4456

4557
precip, metadata = get_precipitation_fields(
46-
return_raw=True, metadata=True, source="fmi"
58+
num_prev_files=2, return_raw=True, metadata=True, source="fmi"
4759
)
48-
precip = precip.squeeze()
4960

5061
invalid_mask = get_invalid_mask(precip)
5162

52-
# save it back to disk
5363
with tempfile.TemporaryDirectory() as outpath:
64+
# save it back to disk
5465
outfnprefix = "test_netcdf_out"
5566
file_path = os.path.join(outpath, outfnprefix + ".nc")
5667
startdate = metadata["timestamps"][0]
5768
timestep = metadata["accutime"]
58-
n_timesteps = 1
59-
shape = precip.shape
69+
n_timesteps = 3
70+
shape = precip.shape[1:]
71+
6072
exporter = initialize_forecast_exporter_netcdf(
6173
outpath,
6274
outfnprefix,
@@ -65,9 +77,25 @@ def test_io_export_netcdf_one_member_one_time_step():
6577
n_timesteps,
6678
shape,
6779
metadata,
68-
n_ens_members=1,
80+
n_ens_members=n_ens_members,
81+
incremental=incremental,
6982
)
70-
export_forecast_dataset(precip[np.newaxis, :], exporter)
83+
84+
if n_ens_members > 1:
85+
precip = np.repeat(precip[np.newaxis, :, :, :], n_ens_members, axis=0)
86+
87+
if incremental == None:
88+
export_forecast_dataset(precip, exporter)
89+
if incremental == "timestep":
90+
for t in range(n_timesteps):
91+
if n_ens_members > 1:
92+
export_forecast_dataset(precip[:, t, :, :], exporter)
93+
else:
94+
export_forecast_dataset(precip[t, :, :], exporter)
95+
if incremental == "member":
96+
for ens_mem in range(n_ens_members):
97+
export_forecast_dataset(precip[ens_mem, :, :, :], exporter)
98+
7199
close_forecast_files(exporter)
72100

73101
# assert if netcdf file was saved and file size is not zero
@@ -78,11 +106,11 @@ def test_io_export_netcdf_one_member_one_time_step():
78106

79107
precip_new, _ = import_netcdf_pysteps(output_file_path)
80108

81-
assert_array_almost_equal(precip, precip_new.data)
109+
assert_array_almost_equal(precip.squeeze(), precip_new.data)
82110
assert precip_new.dtype == "single"
83111

84112
precip_new, _ = import_netcdf_pysteps(output_file_path, dtype="double")
85-
assert_array_almost_equal(precip, precip_new.data)
113+
assert_array_almost_equal(precip.squeeze(), precip_new.data)
86114
assert precip_new.dtype == "double"
87115

88116
precip_new, _ = import_netcdf_pysteps(output_file_path, fillna=-1000)

0 commit comments

Comments
 (0)