@@ -96,7 +96,9 @@ def test_generate_datasets():
96
96
assert "target" not in additional_columns
97
97
98
98
99
- def setup_test_data (model , freq , num_series , horizon = 5 , num_points = 100 , seed = 42 , include_additional = True ):
99
+ def setup_test_data (
100
+ model , freq , num_series , horizon = 5 , num_points = 100 , seed = 42 , include_additional = True
101
+ ):
100
102
"""
101
103
Setup test data for the given parameters.
102
104
@@ -113,17 +115,21 @@ def setup_test_data(model, freq, num_series, horizon=5, num_points=100, seed=42,
113
115
- Tuple containing primary, additional datasets and the operator configuration.
114
116
"""
115
117
primary , additional , _ , _ = generate_datasets (
116
- freq = freq , horizon = horizon , num_series = num_series , num_points = num_points , seed = seed
118
+ freq = freq ,
119
+ horizon = horizon ,
120
+ num_series = num_series ,
121
+ num_points = num_points ,
122
+ seed = seed ,
117
123
)
118
124
119
125
yaml_i = deepcopy (TEMPLATE_YAML )
120
126
yaml_i ["spec" ]["historical_data" ].pop ("url" )
121
127
yaml_i ["spec" ]["historical_data" ]["data" ] = primary
122
128
yaml_i ["spec" ]["historical_data" ]["format" ] = "pandas"
123
-
129
+
124
130
if include_additional :
125
131
yaml_i ["spec" ]["additional_data" ] = {"data" : additional , "format" : "pandas" }
126
-
132
+
127
133
yaml_i ["spec" ]["model" ] = model
128
134
yaml_i ["spec" ]["target_column" ] = "target"
129
135
yaml_i ["spec" ]["datetime_column" ]["name" ] = "ds"
@@ -177,10 +183,18 @@ def test_explanations_output_and_columns(model, freq, num_series):
177
183
not (local_explanations == 0 ).all ().all ()
178
184
), "Local explanations contain only 0 values"
179
185
180
- additional_columns = additional .columns .tolist ()
186
+ additional_columns = list (
187
+ set (additional .columns .tolist ())
188
+ - set (operator_config .spec .target_category_columns )
189
+ - {operator_config .spec .datetime_column .name }
190
+ )
181
191
for column in additional_columns :
182
- assert column in global_explanations .columns , f"Column { column } missing in global explanations"
183
- assert column in local_explanations .columns , f"Column { column } missing in local explanations"
192
+ assert (
193
+ column in global_explanations .T .columns
194
+ ), f"Column { column } missing in global explanations"
195
+ assert (
196
+ column in local_explanations .columns
197
+ ), f"Column { column } missing in local explanations"
184
198
185
199
186
200
@pytest .mark .parametrize ("model" , MODELS )
@@ -208,11 +222,19 @@ def test_explanations_filenames(model, num_series):
208
222
209
223
results = forecast_operate (operator_config )
210
224
211
- global_explanation_path = os .path .join (output_directory , global_explanation_filename )
212
- local_explanation_path = os .path .join (output_directory , local_explanation_filename )
225
+ global_explanation_path = os .path .join (
226
+ output_directory , global_explanation_filename
227
+ )
228
+ local_explanation_path = os .path .join (
229
+ output_directory , local_explanation_filename
230
+ )
213
231
214
- assert os .path .exists (global_explanation_path ), f"Global explanation file not found at { global_explanation_path } "
215
- assert os .path .exists (local_explanation_path ), f"Local explanation file not found at { local_explanation_path } "
232
+ assert os .path .exists (
233
+ global_explanation_path
234
+ ), f"Global explanation file not found at { global_explanation_path } "
235
+ assert os .path .exists (
236
+ local_explanation_path
237
+ ), f"Local explanation file not found at { local_explanation_path } "
216
238
217
239
218
240
@pytest .mark .parametrize ("model" , MODELS )
@@ -231,19 +253,23 @@ def test_explanations_no_additional_data(model, num_series, caplog):
231
253
with tempfile .TemporaryDirectory () as tmpdirname :
232
254
output_directory = tmpdirname
233
255
234
- _ , _ , operator_config = setup_test_data (model , "D" , num_series , include_additional = False )
256
+ _ , _ , operator_config = setup_test_data (
257
+ model , "D" , num_series , include_additional = False
258
+ )
235
259
operator_config .spec .output_directory .url = output_directory
236
260
237
261
forecast_operate (operator_config )
238
262
239
263
assert any (
240
264
"Unable to generate explanations as there is no additional data passed in. Either set generate_explanations to False, or pass in additional data."
241
- in message for message in caplog .messages
265
+ in message
266
+ for message in caplog .messages
242
267
), "Required warning message not found in logs"
243
268
244
269
245
270
MODES = ["BALANCED" , "HIGH_ACCURACY" ]
246
271
272
+
247
273
@pytest .mark .skip (reason = "Disabled by default. Enable to run this test." )
248
274
@pytest .mark .parametrize ("mode" , MODES )
249
275
@pytest .mark .parametrize ("model" , MODELS )
@@ -269,11 +295,19 @@ def test_explanations_accuracy_mode(mode, model, num_series):
269
295
270
296
results = forecast_operate (operator_config )
271
297
272
- global_explanation_path = os .path .join (output_directory , operator_config .spec .global_explanation_filename )
273
- local_explanation_path = os .path .join (output_directory , operator_config .spec .local_explanation_filename )
298
+ global_explanation_path = os .path .join (
299
+ output_directory , operator_config .spec .global_explanation_filename
300
+ )
301
+ local_explanation_path = os .path .join (
302
+ output_directory , operator_config .spec .local_explanation_filename
303
+ )
274
304
275
- assert os .path .exists (global_explanation_path ), f"Global explanation file not found at { global_explanation_path } "
276
- assert os .path .exists (local_explanation_path ), f"Local explanation file not found at { local_explanation_path } "
305
+ assert os .path .exists (
306
+ global_explanation_path
307
+ ), f"Global explanation file not found at { global_explanation_path } "
308
+ assert os .path .exists (
309
+ local_explanation_path
310
+ ), f"Local explanation file not found at { local_explanation_path } "
277
311
278
312
279
313
@pytest .mark .parametrize ("model" , MODELS )
0 commit comments