@@ -150,12 +150,22 @@ def _train_model(self, i, s_id, df, model_kwargs):
150
150
logger .debug (forecast .tail ())
151
151
152
152
# TODO; could also extract trend and seasonality?
153
- cols_to_read = filter (
154
- lambda x : x .startswith ("future_regressor" ), forecast .columns
153
+ cols_to_read = set (
154
+ forecast .columns [forecast .columns .str .startswith ("future_regressor" )]
155
+ + ["ds" , "trend" ]
155
156
)
156
- self .explanations_info [s_id ] = (
157
- forecast [cols_to_read ].rename ({"ds" : "Date" }, axis = 1 ).set_index ("Date" )
157
+ cols_to_read = cols_to_read - {
158
+ "future_regressors_additive" ,
159
+ "future_regressors_multiplicative" ,
160
+ }
161
+ combine_terms = cols_to_read - set (self .accepted_regressors [s_id ])
162
+ temp_df = (
163
+ forecast [list (cols_to_read )]
164
+ .rename ({"ds" : "Date" }, axis = 1 )
165
+ .set_index ("Date" )
158
166
)
167
+ temp_df [self .spec .target_column ] = temp_df [combine_terms ].sum (axis = 1 )
168
+ self .explanations_info [s_id ] = temp_df .drop (combine_terms , axis = 1 )
159
169
160
170
self .outputs [s_id ] = forecast
161
171
self .forecast_output .populate_series_output (
@@ -457,19 +467,14 @@ def explain_model(self):
457
467
for s_id , expl_df in self .explanations_info .items ():
458
468
expl_df = expl_df .rename (rename_cols , axis = 1 )
459
469
# Local Expl
460
- self .local_explanation [s_id ] = self .get_horizon (expl_df ).drop (
461
- ["future_regressors_additive" ], axis = 1
462
- )
470
+ self .local_explanation [s_id ] = self .get_horizon (expl_df )
463
471
self .local_explanation [s_id ]["Series" ] = s_id
464
472
self .local_explanation [s_id ].index .rename (self .dt_column_name , inplace = True )
465
473
# Global Expl
466
474
g_expl = self .drop_horizon (expl_df ).mean ()
467
475
g_expl .name = s_id
468
476
global_expl .append (g_expl )
469
477
self .global_explanation = pd .concat (global_expl , axis = 1 )
470
- self .global_explanation = self .global_explanation .drop (
471
- index = ["future_regressors_additive" ], axis = 0
472
- )
473
478
self .formatted_global_explanation = (
474
479
self .global_explanation / self .global_explanation .sum (axis = 0 ) * 100
475
480
)
0 commit comments