25 | 25 |
26 | 26 | GRIB_FILENAME = "test_stats_grib.grib" |
27 | 27 | STATS_FILE_NAMES = "test_stats.csv" |
| 28 | +NC_FILE_NAME = "test_stats.nc" |
| 29 | +NC_FILE_GLOB = "test_s*.nc" |
28 | 30 |
29 | 31 |
30 | 32 | def initialize_dummy_netcdf_file(name): |
@@ -129,80 +131,74 @@ def test_stats_grib(tmp_dir): |
129 | 131 | ), f"Stats dataframe incorrect. Difference:\n{df.values == expected}" |
130 | 132 |
131 | 133 |
132 | | -class TestStatsNetcdf(unittest.TestCase): |
133 | | - """ |
134 | | - Unit test class for validating statistical calculations from NetCDF files. |
| 134 | +@pytest.fixture(name="setup_netcdf_file") |
| 135 | +def fixture_setup_netcdf_file(tmp_dir): |
| 136 | + """Fixture to create and initialize a dummy NetCDF file for testing.""" |
135 | 137 |
136 | | - This class tests the accuracy of statistical calculations (mean, max, min) |
137 | | - performed on data extracted from NetCDF files. |
138 | | - It ensures that the statistics DataFrame produced from the NetCDF data |
139 | | - matches expected values. |
140 | | - """ |
| 138 | + data = initialize_dummy_netcdf_file(os.path.join(tmp_dir, NC_FILE_NAME)) |
141 | 139 |
142 | | - nc_file_name = "test_stats.nc" |
143 | | - nc_file_glob = "test_s*.nc" |
| 140 | + # Creating variable "v1" with specified dimensions and setting its values |
| 141 | + data.createVariable("v1", np.float64, dimensions=("t", "z", "x")) |
| 142 | + data.variables["v1"][:] = np.ones((TIME_DIM_SIZE, HEIGHT_DIM_SIZE, HOR_DIM_SIZE)) |
| 143 | + data.variables["v1"][:, :, 0] = 0 |
| 144 | + data.variables["v1"][:, :, -1] = 2 |
144 | 145 |
145 | | - def setUp(self): |
146 | | - data = initialize_dummy_netcdf_file(self.nc_file_name) |
| 146 | + # Creating variable "v2" with fill_value, and setting its values |
| 147 | + data.createVariable("v2", np.float64, dimensions=("t", "x"), fill_value=42) |
| 148 | + data.variables["v2"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 2 |
| 149 | + data.variables["v2"][:, 0] = 1 |
| 150 | + data.variables["v2"][:, 1] = 42 # should be ignored in max-statistic |
| 151 | + data.variables["v2"][:, -1] = 3 |
147 | 152 |
148 | | - data.createVariable("v1", np.float64, dimensions=("t", "z", "x")) |
149 | | - data.variables["v1"][:] = np.ones( |
150 | | - (TIME_DIM_SIZE, HEIGHT_DIM_SIZE, HOR_DIM_SIZE) |
151 | | - ) |
152 | | - data.variables["v1"][:, :, 0] = 0 |
153 | | - data.variables["v1"][:, :, -1] = 2 |
| 153 | + # Creating variable "v3" and setting its values |
| 154 | + data.createVariable("v3", np.float64, dimensions=("t", "x")) |
| 155 | + data.variables["v3"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 3 |
| 156 | + data.variables["v3"][:, 0] = 2 |
| 157 | + data.variables["v3"][:, -1] = 4 |
154 | 158 |
155 | | - data.createVariable("v2", np.float64, dimensions=("t", "x"), fill_value=42) |
156 | | - data.variables["v2"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 2 |
157 | | - data.variables["v2"][:, 0] = 1 |
158 | | - data.variables["v2"][:, 1] = 42 # shall be ignored in max-statistic |
159 | | - data.variables["v2"][:, -1] = 3 |
| 159 | + data.close() |
160 | 160 |
161 | | - data.createVariable("v3", np.float64, dimensions=("t", "x")) |
162 | | - data.variables["v3"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 3 |
163 | | - data.variables["v3"][:, 0] = 2 |
164 | | - data.variables["v3"][:, -1] = 4 |
| 161 | + yield |
165 | 162 |
166 | | - data.close() |
167 | 163 |
168 | | - def tear_down(self): |
169 | | - os.remove(self.nc_file_name) |
170 | | - os.remove(STATS_FILE_NAMES) |
| 164 | +def test_stats_netcdf(setup_netcdf_file, tmp_dir): # pylint: disable=unused-argument |
| 165 | + """Test that the statistics generated from the NetCDF file match the |
| 166 | + expected values.""" |
171 | 167 |
172 | | - def test_stats(self): |
173 | | - file_specification = { |
174 | | - "Test data": { |
175 | | - "format": "netcdf", |
176 | | - "time_dim": "t", |
177 | | - "horizontal_dims": ["x"], |
178 | | - "fill_value_key": "_FillValue", # should be the name for fill_value |
179 | | - }, |
180 | | - } |
| 168 | + file_specification = { |
| 169 | + "Test data": { |
| 170 | + "format": "netcdf", |
| 171 | + "time_dim": "t", |
| 172 | + "horizontal_dims": ["x"], |
| 173 | + "fill_value_key": "_FillValue", # should be the name for fill_value |
| 174 | + }, |
| 175 | + } |
181 | 176 |
182 | | - df = create_stats_dataframe( |
183 | | - input_dir=".", |
184 | | - file_id=[["Test data", self.nc_file_glob]], |
185 | | - stats_file_name=STATS_FILE_NAMES, |
186 | | - file_specification=file_specification, |
187 | | - ) |
| 177 | + # Call the function to generate the statistics dataframe |
| 178 | + df = create_stats_dataframe( |
| 179 | + input_dir=tmp_dir, |
| 180 | + file_id=[["Test data", NC_FILE_GLOB]], |
| 181 | + stats_file_name=STATS_FILE_NAMES, |
| 182 | + file_specification=file_specification, |
| 183 | + ) |
188 | 184 |
189 | | - # check that the mean/max/min are correct |
190 | | - expected = np.array( |
191 | | - [ |
192 | | - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
193 | | - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
194 | | - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
195 | | - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
196 | | - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
197 | | - [2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0], |
198 | | - [3.0, 4.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 2.0], |
199 | | - ] |
200 | | - ) |
| 185 | + # Define the expected values for comparison |
| 186 | + expected = np.array( |
| 187 | + [ |
| 188 | + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
| 189 | + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
| 190 | + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
| 191 | + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
| 192 | + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], |
| 193 | + [2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0], |
| 194 | + [3.0, 4.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 2.0], |
| 195 | + ] |
| 196 | + ) |
201 | 197 |
202 | | - self.assertTrue( |
203 | | - np.array_equal(df.values, expected), |
204 | | - f"stats dataframe incorrect. Difference:\n{df.values == expected}", |
205 | | - ) |
| 198 | + # Check that the dataframe matches the expected values |
| 199 | + assert np.array_equal( |
| 200 | + df.values, expected |
| 201 | + ), f"Stats dataframe incorrect. Difference:\n{df.values == expected}" |
206 | 202 |
207 | 203 |
208 | 204 | class TestStatsCsv(unittest.TestCase): |
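Note: the converted tests above consume a tmp_dir fixture that is defined outside this hunk. As a minimal sketch only, assuming pytest's built-in tmp_path fixture and an illustrative helper name (fixture_tmp_dir is not part of this PR), such a fixture could look like:

    import pytest


    @pytest.fixture(name="tmp_dir")
    def fixture_tmp_dir(tmp_path):
        """Illustrative sketch: expose pytest's built-in tmp_path as a plain
        string path under the name "tmp_dir", as the tests above expect."""
        return str(tmp_path)

Using a per-test temporary directory this way removes the need for the manual os.remove() cleanup that the deleted tear_down method performed.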