|
46 | 46 | AUTO_SELECT,
|
47 | 47 | BACKTEST_REPORT_NAME,
|
48 | 48 | SUMMARY_METRICS_HORIZON_LIMIT,
|
| 49 | + ForecastOutputColumns, |
49 | 50 | SpeedAccuracyMode,
|
50 | 51 | SupportedMetrics,
|
51 | 52 | SupportedModels,
|
@@ -742,43 +743,60 @@ def explain_model(self):
|
742 | 743 | include_horizon=False
|
743 | 744 | ).items():
|
744 | 745 | if s_id in self.models:
|
745 |
| - explain_predict_fn = self.get_explain_predict_fn(series_id=s_id) |
746 |
| - data_trimmed = data_i.tail( |
747 |
| - max(int(len(data_i) * ratio), 5) |
748 |
| - ).reset_index(drop=True) |
749 |
| - data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply( |
750 |
| - lambda x: x.timestamp() |
751 |
| - ) |
752 |
| - |
753 |
| - # Explainer fails when boolean columns are passed |
754 |
| - |
755 |
| - _, data_trimmed_encoded = _label_encode_dataframe( |
756 |
| - data_trimmed, |
757 |
| - no_encode={datetime_col_name, self.original_target_column}, |
758 |
| - ) |
759 |
| - |
760 |
| - kernel_explnr = PermutationExplainer( |
761 |
| - model=explain_predict_fn, masker=data_trimmed_encoded |
762 |
| - ) |
763 |
| - kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded) |
764 |
| - exp_end_time = time.time() |
765 |
| - global_ex_time = global_ex_time + exp_end_time - exp_start_time |
766 |
| - self.local_explainer( |
767 |
| - kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name |
768 |
| - ) |
769 |
| - local_ex_time = local_ex_time + time.time() - exp_end_time |
| 746 | + try: |
| 747 | + explain_predict_fn = self.get_explain_predict_fn(series_id=s_id) |
| 748 | + data_trimmed = data_i.tail( |
| 749 | + max(int(len(data_i) * ratio), 5) |
| 750 | + ).reset_index(drop=True) |
| 751 | + data_trimmed[datetime_col_name] = data_trimmed[ |
| 752 | + datetime_col_name |
| 753 | + ].apply(lambda x: x.timestamp()) |
| 754 | + |
| 755 | + # Explainer fails when boolean columns are passed |
| 756 | + |
| 757 | + _, data_trimmed_encoded = _label_encode_dataframe( |
| 758 | + data_trimmed, |
| 759 | + no_encode={datetime_col_name, self.original_target_column}, |
| 760 | + ) |
770 | 761 |
|
771 |
| - if not len(kernel_explnr_vals): |
772 |
| - logger.warn( |
773 |
| - "No explanations generated. Ensure that additional data has been provided." |
| 762 | + kernel_explnr = PermutationExplainer( |
| 763 | + model=explain_predict_fn, masker=data_trimmed_encoded |
774 | 764 | )
|
775 |
| - else: |
776 |
| - self.global_explanation[s_id] = dict( |
777 |
| - zip( |
778 |
| - data_trimmed.columns[1:], |
779 |
| - np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0), |
780 |
| - ) |
| 765 | + kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded) |
| 766 | + exp_end_time = time.time() |
| 767 | + global_ex_time = global_ex_time + exp_end_time - exp_start_time |
| 768 | + self.local_explainer( |
| 769 | + kernel_explnr, |
| 770 | + series_id=s_id, |
| 771 | + datetime_col_name=datetime_col_name, |
781 | 772 | )
|
| 773 | + local_ex_time = local_ex_time + time.time() - exp_end_time |
| 774 | + |
| 775 | + if not len(kernel_explnr_vals): |
| 776 | + logger.warn( |
| 777 | + "No explanations generated. Ensure that additional data has been provided." |
| 778 | + ) |
| 779 | + else: |
| 780 | + self.global_explanation[s_id] = dict( |
| 781 | + zip( |
| 782 | + data_trimmed.columns[1:], |
| 783 | + np.average( |
| 784 | + np.absolute(kernel_explnr_vals[:, 1:]), axis=0 |
| 785 | + ), |
| 786 | + ) |
| 787 | + ) |
| 788 | + except Exception as e: |
| 789 | + if s_id in self.errors_dict: |
| 790 | + self.errors_dict[s_id]["explainer_error"] = str(e) |
| 791 | + self.errors_dict[s_id]["explainer_error_trace"] = ( |
| 792 | + traceback.format_exc() |
| 793 | + ) |
| 794 | + else: |
| 795 | + self.errors_dict[s_id] = { |
| 796 | + "model_name": self.spec.model, |
| 797 | + "explainer_error": str(e), |
| 798 | + "explainer_error_trace": traceback.format_exc(), |
| 799 | + } |
782 | 800 | else:
|
783 | 801 | logger.warn(
|
784 | 802 | f"Skipping explanations for {s_id}, as forecast was not generated."
|
@@ -815,6 +833,13 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
|
815 | 833 | local_kernel_explnr_df = pd.DataFrame(
|
816 | 834 | local_kernel_explnr_vals, columns=data.columns
|
817 | 835 | )
|
| 836 | + |
| 837 | + # Add date column to local explanation DataFrame |
| 838 | + local_kernel_explnr_df[ForecastOutputColumns.DATE] = ( |
| 839 | + self.datasets.get_horizon_at_series( |
| 840 | + s_id=series_id |
| 841 | + )[self.spec.datetime_column.name].reset_index(drop=True) |
| 842 | + ) |
818 | 843 | self.local_explanation[series_id] = local_kernel_explnr_df
|
819 | 844 |
|
820 | 845 | def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):
|
|
0 commit comments