47
47
SpeedAccuracyMode ,
48
48
SupportedMetrics ,
49
49
SupportedModels ,
50
- BACKTEST_REPORT_NAME
50
+ BACKTEST_REPORT_NAME ,
51
51
)
52
52
from ..operator_config import ForecastOperatorConfig , ForecastOperatorSpec
53
53
from .forecast_datasets import ForecastDatasets
@@ -259,7 +259,11 @@ def generate_report(self):
259
259
output_dir = self .spec .output_directory .url
260
260
file_path = f"{ output_dir } /{ BACKTEST_REPORT_NAME } "
261
261
if self .spec .model == AUTO_SELECT :
262
- backtest_sections .append (rc .Heading ("Auto-Select Backtesting and Performance Metrics" , level = 2 ))
262
+ backtest_sections .append (
263
+ rc .Heading (
264
+ "Auto-Select Backtesting and Performance Metrics" , level = 2
265
+ )
266
+ )
263
267
if not os .path .exists (file_path ):
264
268
failure_msg = rc .Text (
265
269
"auto-select could not be executed. Please check the "
@@ -268,15 +272,23 @@ def generate_report(self):
268
272
backtest_sections .append (failure_msg )
269
273
else :
270
274
backtest_stats = pd .read_csv (file_path )
271
- model_metric_map = backtest_stats .drop (columns = ['metric' , 'backtest' ])
272
- average_dict = {k : round (v , 4 ) for k , v in model_metric_map .mean ().to_dict ().items ()}
275
+ model_metric_map = backtest_stats .drop (
276
+ columns = ["metric" , "backtest" ]
277
+ )
278
+ average_dict = {
279
+ k : round (v , 4 )
280
+ for k , v in model_metric_map .mean ().to_dict ().items ()
281
+ }
273
282
best_model = min (average_dict , key = average_dict .get )
274
283
summary_text = rc .Text (
275
284
f"Overall, the average { self .spec .metric } scores for the models are { average_dict } , with"
276
- f" { best_model } being identified as the top-performing model during backtesting." )
285
+ f" { best_model } being identified as the top-performing model during backtesting."
286
+ )
277
287
backtest_table = rc .DataTable (backtest_stats , index = True )
278
288
liner_plot = get_auto_select_plot (backtest_stats )
279
- backtest_sections .extend ([backtest_table , summary_text , liner_plot ])
289
+ backtest_sections .extend (
290
+ [backtest_table , summary_text , liner_plot ]
291
+ )
280
292
281
293
forecast_plots = []
282
294
if len (self .forecast_output .list_series_ids ()) > 0 :
@@ -643,6 +655,12 @@ def _save_model(self, output_dir, storage_options):
643
655
"Please run `python3 -m pip install shap` to install the required dependencies for model explanation."
644
656
),
645
657
)
658
+ @runtime_dependency (
659
+ module = "automlx" ,
660
+ err_msg = (
661
+ "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
662
+ ),
663
+ )
646
664
def explain_model (self ):
647
665
"""
648
666
Generates an explanation for the model by using the SHAP (Shapley Additive exPlanations) library.
@@ -668,7 +686,44 @@ def explain_model(self):
668
686
for s_id , data_i in self .datasets .get_data_by_series (
669
687
include_horizon = False
670
688
).items ():
671
- if s_id in self .models :
689
+ if (
690
+ self .spec .model == SupportedModels .AutoMLX
691
+ and self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX
692
+ ):
693
+ import automlx
694
+
695
+ explainer = automlx .MLExplainer (
696
+ self .models [s_id ],
697
+ self .datasets .additional_data .get_data_for_series (series_id = s_id )
698
+ .drop (self .spec .datetime_column .name , axis = 1 )
699
+ .head (- self .spec .horizon )
700
+ if self .spec .additional_data
701
+ else None ,
702
+ pd .DataFrame (data_i [self .spec .target_column ]),
703
+ task = "forecasting" ,
704
+ )
705
+
706
+ explanations = explainer .explain_prediction (
707
+ X = self .datasets .additional_data .get_data_for_series (series_id = s_id )
708
+ .drop (self .spec .datetime_column .name , axis = 1 )
709
+ .tail (self .spec .horizon )
710
+ if self .spec .additional_data
711
+ else None ,
712
+ forecast_timepoints = list (range (self .spec .horizon + 1 )),
713
+ )
714
+
715
+ explanations_df = pd .concat (
716
+ [exp .to_dataframe () for exp in explanations ]
717
+ )
718
+ explanations_df ["row" ] = explanations_df .groupby ("Feature" ).cumcount ()
719
+ explanations_df = explanations_df .pivot (
720
+ index = "row" , columns = "Feature" , values = "Attribution"
721
+ )
722
+ explanations_df = explanations_df .reset_index (drop = True )
723
+ # explanations_df[self.spec.datetime_column.name]=self.datasets.additional_data.get_data_for_series(series_id=s_id).tail(self.spec.horizon)[self.spec.datetime_column.name].reset_index(drop=True)
724
+ self .local_explanation [s_id ] = explanations_df
725
+
726
+ elif s_id in self .models :
672
727
explain_predict_fn = self .get_explain_predict_fn (series_id = s_id )
673
728
data_trimmed = data_i .tail (
674
729
max (int (len (data_i ) * ratio ), 5 )
@@ -699,6 +754,14 @@ def explain_model(self):
699
754
logger .warn (
700
755
"No explanations generated. Ensure that additional data has been provided."
701
756
)
757
+ elif (
758
+ self .spec .model == SupportedModels .AutoMLX
759
+ and self .spec .explanations_accuracy_mode
760
+ == SpeedAccuracyMode .AUTOMLX
761
+ ):
762
+ logger .warning (
763
+ "Global explanations not available for AutoMLX models with inherent explainability"
764
+ )
702
765
else :
703
766
self .global_explanation [s_id ] = dict (
704
767
zip (
0 commit comments