22
22
from ..const import (
23
23
DEFAULT_TRIALS ,
24
24
PROPHET_INTERNAL_DATE_COL ,
25
- ForecastOutputColumns ,
26
25
SupportedModels ,
27
26
)
28
27
from .base_model import ForecastOperatorBaseModel
@@ -123,6 +122,14 @@ def _train_model(self, i, series_id, df, model_kwargs):
123
122
upper_bound = self .get_horizon (forecast ["yhat_upper" ]).values ,
124
123
lower_bound = self .get_horizon (forecast ["yhat_lower" ]).values ,
125
124
)
125
+ # Get all features that make up the forecast. Exclude CI (upper/lower) and drop yhat ([:-1])
126
+ core_columns = forecast .columns [
127
+ ~ forecast .columns .str .endswith ("_lower" )
128
+ & ~ forecast .columns .str .endswith ("_upper" )
129
+ ][:- 1 ]
130
+ self .explanations_info [series_id ] = (
131
+ forecast [core_columns ].rename ({"ds" : "Date" }, axis = 1 ).set_index ("Date" )
132
+ )
126
133
127
134
self .models [series_id ] = {}
128
135
self .models [series_id ]["model" ] = model
@@ -151,6 +158,7 @@ def _build_model(self) -> pd.DataFrame:
151
158
full_data_dict = self .datasets .get_data_by_series ()
152
159
self .models = {}
153
160
self .outputs = {}
161
+ self .explanations_info = {}
154
162
self .additional_regressors = self .datasets .get_additional_data_column_names ()
155
163
model_kwargs = self .set_kwargs ()
156
164
self .forecast_output = ForecastOutput (
@@ -257,6 +265,25 @@ def objective(trial):
257
265
model_kwargs_i = study .best_params
258
266
return model_kwargs_i
259
267
268
+ def explain_model (self ):
269
+ self .local_explanation = {}
270
+ global_expl = []
271
+
272
+ for s_id , expl_df in self .explanations_info .items ():
273
+ # Local Expl
274
+ self .local_explanation [s_id ] = self .get_horizon (expl_df )
275
+ self .local_explanation [s_id ]["Series" ] = s_id
276
+ self .local_explanation [s_id ].index .rename (self .dt_column_name , inplace = True )
277
+ # Global Expl
278
+ g_expl = self .drop_horizon (expl_df ).mean ()
279
+ g_expl .name = s_id
280
+ global_expl .append (g_expl )
281
+ self .global_explanation = pd .concat (global_expl , axis = 1 )
282
+ self .formatted_global_explanation = (
283
+ self .global_explanation / self .global_explanation .sum (axis = 0 ) * 100
284
+ )
285
+ self .formatted_local_explanation = pd .concat (self .local_explanation .values ())
286
+
260
287
def _generate_report (self ):
261
288
import report_creator as rc
262
289
from prophet .plot import add_changepoints_to_plot
@@ -335,22 +362,6 @@ def _generate_report(self):
335
362
# If the key is present, call the "explain_model" method
336
363
self .explain_model ()
337
364
338
- # Convert the global explanation data to a DataFrame
339
- global_explanation_df = pd .DataFrame (self .global_explanation )
340
-
341
- self .formatted_global_explanation = (
342
- global_explanation_df / global_explanation_df .sum (axis = 0 ) * 100
343
- )
344
-
345
- aggregate_local_explanations = pd .DataFrame ()
346
- for s_id , local_ex_df in self .local_explanation .items ():
347
- local_ex_df_copy = local_ex_df .copy ()
348
- local_ex_df_copy [ForecastOutputColumns .SERIES ] = s_id
349
- aggregate_local_explanations = pd .concat (
350
- [aggregate_local_explanations , local_ex_df_copy ], axis = 0
351
- )
352
- self .formatted_local_explanation = aggregate_local_explanations
353
-
354
365
if not self .target_cat_col :
355
366
self .formatted_global_explanation = (
356
367
self .formatted_global_explanation .rename (
@@ -364,7 +375,7 @@ def _generate_report(self):
364
375
365
376
# Create a markdown section for the global explainability
366
377
global_explanation_section = rc .Block (
367
- rc .Heading ("Global Explanation of Models " , level = 2 ),
378
+ rc .Heading ("Global Explainability " , level = 2 ),
368
379
rc .Text (
369
380
"The following tables provide the feature attribution for the global explainability."
370
381
),
@@ -373,7 +384,7 @@ def _generate_report(self):
373
384
374
385
blocks = [
375
386
rc .DataTable (
376
- local_ex_df .div ( local_ex_df . abs (). sum ( axis = 1 ) , axis = 0 ) * 100 ,
387
+ local_ex_df .drop ( "Series" , axis = 1 ) ,
377
388
label = s_id if self .target_cat_col else None ,
378
389
index = True ,
379
390
)
@@ -393,6 +404,8 @@ def _generate_report(self):
393
404
# Do not fail the whole run due to explanations failure
394
405
logger .warning (f"Failed to generate Explanations with error: { e } ." )
395
406
logger .debug (f"Full Traceback: { traceback .format_exc ()} " )
407
+ self .errors_dict ["explainer_error" ] = str (e )
408
+ self .errors_dict ["explainer_error_error" ] = traceback .format_exc ()
396
409
397
410
model_description = rc .Text (
398
411
"""Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly, weekly, and daily seasonality, plus holiday effects. It works best with time series that have strong seasonal effects and several seasons of historical data. Prophet is robust to missing data and shifts in the trend, and typically handles outliers well."""
0 commit comments