@@ -666,10 +666,11 @@ def explain_model(self):
666
666
lambda x : x .timestamp ()
667
667
)
668
668
669
- # Explainer fails when boolean columns are passed
670
- _ , data_trimmed_encoded = _label_encode_dataframe (
671
- data_trimmed , no_encode = {datetime_col_name , self .original_target_column }
672
- )
669
+ # Explainer fails when boolean columns are passed for arima
670
+ if self .spec .model == SupportedModels .Arima :
671
+ _ , data_trimmed_encoded = _label_encode_dataframe (
672
+ data_trimmed , no_encode = {datetime_col_name , self .original_target_column }
673
+ )
673
674
674
675
kernel_explnr = PermutationExplainer (
675
676
model = explain_predict_fn , masker = data_trimmed_encoded
@@ -714,15 +715,17 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
714
715
kernel_explainer: The kernel explainer object to use for generating explanations.
715
716
"""
716
717
data = self .datasets .get_horizon_at_series (s_id = series_id )
718
+ # columns that were dropped in train_model in arima, should be dropped here as well
717
719
if self .spec .model == SupportedModels .Arima and series_id in self .constant_cols :
718
720
data = data .drop (columns = self .constant_cols [series_id ])
719
721
data [datetime_col_name ] = datetime_to_seconds (data [datetime_col_name ])
720
722
data = data .reset_index (drop = True )
721
723
722
- # Explainer fails when boolean columns are passed
723
- _ , data = _label_encode_dataframe (
724
- data , no_encode = {datetime_col_name , self .original_target_column }
725
- )
724
+ # Explainer fails when boolean columns are passed for arima
725
+ if self .spec .model == SupportedModels .Arima :
726
+ _ , data = _label_encode_dataframe (
727
+ data , no_encode = {datetime_col_name , self .original_target_column }
728
+ )
726
729
# Generate local SHAP values using the kernel explainer
727
730
local_kernel_explnr_vals = kernel_explainer .shap_values (data )
728
731
0 commit comments