@@ -649,18 +649,19 @@ def _save_model(self, output_dir, storage_options):
649
649
storage_options = storage_options ,
650
650
)
651
651
652
+ def _validate_automlx_explanation_mode (self ):
653
+ if self .spec .model != SupportedModels .AutoMLX and self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX :
654
+ raise ValueError (
655
+ "AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
656
+ "Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
657
+ )
658
+
652
659
@runtime_dependency (
653
660
module = "shap" ,
654
661
err_msg = (
655
662
"Please run `python3 -m pip install shap` to install the required dependencies for model explanation."
656
663
),
657
664
)
658
- @runtime_dependency (
659
- module = "automlx" ,
660
- err_msg = (
661
- "Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
662
- ),
663
- )
664
665
def explain_model (self ):
665
666
"""
666
667
Generates an explanation for the model by using the SHAP (Shapley Additive exPlanations) library.
@@ -683,53 +684,13 @@ def explain_model(self):
683
684
)
684
685
ratio = SpeedAccuracyMode .ratio [self .spec .explanations_accuracy_mode ]
685
686
687
+ # validate the automlx mode is use for automlx model
688
+ self ._validate_automlx_explanation_mode ()
689
+
686
690
for s_id , data_i in self .datasets .get_data_by_series (
687
691
include_horizon = False
688
692
).items ():
689
- if (
690
- self .spec .model == SupportedModels .AutoMLX
691
- and self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX
692
- ):
693
- import automlx
694
-
695
- explainer = automlx .MLExplainer (
696
- self .models [s_id ],
697
- self .datasets .additional_data .get_data_for_series (series_id = s_id )
698
- .drop (self .spec .datetime_column .name , axis = 1 )
699
- .head (- self .spec .horizon )
700
- if self .spec .additional_data
701
- else None ,
702
- pd .DataFrame (data_i [self .spec .target_column ]),
703
- task = "forecasting" ,
704
- )
705
-
706
- explanations = explainer .explain_prediction (
707
- X = self .datasets .additional_data .get_data_for_series (series_id = s_id )
708
- .drop (self .spec .datetime_column .name , axis = 1 )
709
- .tail (self .spec .horizon )
710
- if self .spec .additional_data
711
- else None ,
712
- forecast_timepoints = list (range (self .spec .horizon + 1 )),
713
- )
714
-
715
- explanations_df = pd .concat (
716
- [exp .to_dataframe () for exp in explanations ]
717
- )
718
- explanations_df ["row" ] = explanations_df .groupby ("Feature" ).cumcount ()
719
- explanations_df = explanations_df .pivot (
720
- index = "row" , columns = "Feature" , values = "Attribution"
721
- )
722
- explanations_df = explanations_df .reset_index (drop = True )
723
- self .local_explanation [s_id ] = explanations_df
724
- elif (
725
- self .spec .explanations_accuracy_mode == SpeedAccuracyMode .AUTOMLX
726
- and self .spec .model != SupportedModels .AutoMLX
727
- ):
728
- raise ValueError (
729
- "AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
730
- "Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
731
- )
732
- elif s_id in self .models :
693
+ if s_id in self .models :
733
694
explain_predict_fn = self .get_explain_predict_fn (series_id = s_id )
734
695
data_trimmed = data_i .tail (
735
696
max (int (len (data_i ) * ratio ), 5 )
0 commit comments