30
30
31
31
MODELS = [
32
32
"arima" ,
33
- "automlx" ,
33
+ # "automlx",
34
34
"prophet" ,
35
35
"neuralprophet" ,
36
36
]
@@ -140,9 +140,10 @@ def setup_test_data(model, freq, num_series, horizon=5, num_points=100, seed=42,
140
140
@pytest .mark .parametrize ("model" , MODELS )
141
141
@pytest .mark .parametrize ("freq" , ["D" , "W" , "M" , "H" , "T" ])
142
142
@pytest .mark .parametrize ("num_series" , [1 , 3 ])
143
- def test_explanations_output (model , freq , num_series ):
143
+ def test_explanations_output_and_columns (model , freq , num_series ):
144
144
"""
145
145
Test the global and local explanations for different models, frequencies, and number of series.
146
+ Also test that the explanation output contains all the columns from the additional dataset.
146
147
147
148
Parameters:
148
149
- model: The forecasting model to use.
@@ -156,7 +157,7 @@ def test_explanations_output(model, freq, num_series):
156
157
if model == "neuralprophet" :
157
158
pytest .skip ("Skipping 'neuralprophet' model as it takes a long time to finish" )
158
159
159
- _ , _ , operator_config = setup_test_data (model , freq , num_series )
160
+ _ , additional , operator_config = setup_test_data (model , freq , num_series )
160
161
161
162
results = forecast_operate (operator_config )
162
163
@@ -176,33 +177,6 @@ def test_explanations_output(model, freq, num_series):
176
177
not (local_explanations == 0 ).all ().all ()
177
178
), "Local explanations contain only 0 values"
178
179
179
-
180
- @pytest .mark .parametrize ("model" , MODELS )
181
- @pytest .mark .parametrize ("freq" , ["D" , "W" , "M" , "H" , "T" ])
182
- @pytest .mark .parametrize ("num_series" , [1 , 3 ])
183
- def test_explanations_columns (model , freq , num_series ):
184
- """
185
- Test that the explanation output contains all the columns from the additional dataset.
186
-
187
- Parameters:
188
- - model: The forecasting model to use.
189
- - freq: Frequency of the datetime column.
190
- - num_series: Number of different time series to generate.
191
- """
192
- if model == "automlx" and freq == "T" :
193
- pytest .skip (
194
- "Skipping 'T' frequency for 'automlx' model. automlx requires data with a frequency of at least one hour"
195
- )
196
- if model == "neuralprophet" :
197
- pytest .skip ("Skipping 'neuralprophet' model as it takes a long time to finish" )
198
-
199
- _ , additional , operator_config = setup_test_data (model , freq , num_series )
200
-
201
- results = forecast_operate (operator_config )
202
-
203
- global_explanations = results .get_global_explanations ()
204
- local_explanations = results .get_local_explanations ()
205
-
206
180
additional_columns = additional .columns .tolist ()
207
181
for column in additional_columns :
208
182
assert column in global_explanations .columns , f"Column { column } missing in global explanations"
0 commit comments