Skip to content

Commit 7be1ac4

Browse files
committed
change to pytest
1 parent 13ce70c commit 7be1ac4

File tree

1 file changed

+58
-62
lines changed

1 file changed

+58
-62
lines changed

tests/engine/test_stats.py

Lines changed: 58 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
GRIB_FILENAME = "test_stats_grib.grib"
2727
STATS_FILE_NAMES = "test_stats.csv"
28+
NC_FILE_NAME = "test_stats.nc"
29+
NC_FILE_GLOB = "test_s*.nc"
2830

2931

3032
def initialize_dummy_netcdf_file(name):
@@ -129,80 +131,74 @@ def test_stats_grib(tmp_dir):
129131
), f"Stats dataframe incorrect. Difference:\n{df.values == expected}"
130132

131133

132-
class TestStatsNetcdf(unittest.TestCase):
133-
"""
134-
Unit test class for validating statistical calculations from NetCDF files.
134+
@pytest.fixture(name="setup_netcdf_file")
135+
def fixture_setup_netcdf_file(tmp_dir):
136+
"""Fixture to create and initialize a dummy NetCDF file for testing."""
135137

136-
This class tests the accuracy of statistical calculations (mean, max, min)
137-
performed on data extracted from NetCDF files.
138-
It ensures that the statistics DataFrame produced from the NetCDF data
139-
matches expected values.
140-
"""
138+
data = initialize_dummy_netcdf_file(os.path.join(tmp_dir, NC_FILE_NAME))
141139

142-
nc_file_name = "test_stats.nc"
143-
nc_file_glob = "test_s*.nc"
140+
# Creating variable "v1" with specified dimensions and setting its values
141+
data.createVariable("v1", np.float64, dimensions=("t", "z", "x"))
142+
data.variables["v1"][:] = np.ones((TIME_DIM_SIZE, HEIGHT_DIM_SIZE, HOR_DIM_SIZE))
143+
data.variables["v1"][:, :, 0] = 0
144+
data.variables["v1"][:, :, -1] = 2
144145

145-
def setUp(self):
146-
data = initialize_dummy_netcdf_file(self.nc_file_name)
146+
# Creating variable "v2" with fill_value, and setting its values
147+
data.createVariable("v2", np.float64, dimensions=("t", "x"), fill_value=42)
148+
data.variables["v2"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 2
149+
data.variables["v2"][:, 0] = 1
150+
data.variables["v2"][:, 1] = 42 # should be ignored in max-statistic
151+
data.variables["v2"][:, -1] = 3
147152

148-
data.createVariable("v1", np.float64, dimensions=("t", "z", "x"))
149-
data.variables["v1"][:] = np.ones(
150-
(TIME_DIM_SIZE, HEIGHT_DIM_SIZE, HOR_DIM_SIZE)
151-
)
152-
data.variables["v1"][:, :, 0] = 0
153-
data.variables["v1"][:, :, -1] = 2
153+
# Creating variable "v3" and setting its values
154+
data.createVariable("v3", np.float64, dimensions=("t", "x"))
155+
data.variables["v3"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 3
156+
data.variables["v3"][:, 0] = 2
157+
data.variables["v3"][:, -1] = 4
154158

155-
data.createVariable("v2", np.float64, dimensions=("t", "x"), fill_value=42)
156-
data.variables["v2"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 2
157-
data.variables["v2"][:, 0] = 1
158-
data.variables["v2"][:, 1] = 42 # shall be ignored in max-statistic
159-
data.variables["v2"][:, -1] = 3
159+
data.close()
160160

161-
data.createVariable("v3", np.float64, dimensions=("t", "x"))
162-
data.variables["v3"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 3
163-
data.variables["v3"][:, 0] = 2
164-
data.variables["v3"][:, -1] = 4
161+
yield
165162

166-
data.close()
167163

168-
def tear_down(self):
169-
os.remove(self.nc_file_name)
170-
os.remove(STATS_FILE_NAMES)
164+
def test_stats_netcdf(setup_netcdf_file, tmp_dir): # pylint: disable=unused-argument
165+
"""Test that the statistics generated from the NetCDF file match the
166+
expected values."""
171167

172-
def test_stats(self):
173-
file_specification = {
174-
"Test data": {
175-
"format": "netcdf",
176-
"time_dim": "t",
177-
"horizontal_dims": ["x"],
178-
"fill_value_key": "_FillValue", # should be the name for fill_value
179-
},
180-
}
168+
file_specification = {
169+
"Test data": {
170+
"format": "netcdf",
171+
"time_dim": "t",
172+
"horizontal_dims": ["x"],
173+
"fill_value_key": "_FillValue", # should be the name for fill_value
174+
},
175+
}
181176

182-
df = create_stats_dataframe(
183-
input_dir=".",
184-
file_id=[["Test data", self.nc_file_glob]],
185-
stats_file_name=STATS_FILE_NAMES,
186-
file_specification=file_specification,
187-
)
177+
# Call the function to generate the statistics dataframe
178+
df = create_stats_dataframe(
179+
input_dir=tmp_dir,
180+
file_id=[["Test data", NC_FILE_GLOB]],
181+
stats_file_name=STATS_FILE_NAMES,
182+
file_specification=file_specification,
183+
)
188184

189-
# check that the mean/max/min are correct
190-
expected = np.array(
191-
[
192-
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
193-
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
194-
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
195-
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
196-
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
197-
[2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0],
198-
[3.0, 4.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 2.0],
199-
]
200-
)
185+
# Define the expected values for comparison
186+
expected = np.array(
187+
[
188+
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
189+
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
190+
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
191+
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
192+
[1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0],
193+
[2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0],
194+
[3.0, 4.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 2.0],
195+
]
196+
)
201197

202-
self.assertTrue(
203-
np.array_equal(df.values, expected),
204-
f"stats dataframe incorrect. Difference:\n{df.values == expected}",
205-
)
198+
# Check that the dataframe matches the expected values
199+
assert np.array_equal(
200+
df.values, expected
201+
), f"Stats dataframe incorrect. Difference:\n{df.values == expected}"
206202

207203

208204
class TestStatsCsv(unittest.TestCase):

0 commit comments

Comments
 (0)