|
46 | 46 | AUTO_SELECT,
|
47 | 47 | BACKTEST_REPORT_NAME,
|
48 | 48 | SUMMARY_METRICS_HORIZON_LIMIT,
|
| 49 | + ForecastOutputColumns, |
49 | 50 | SpeedAccuracyMode,
|
50 | 51 | SupportedMetrics,
|
51 | 52 | SupportedModels,
|
@@ -743,43 +744,60 @@ def explain_model(self):
|
743 | 744 | include_horizon=False
|
744 | 745 | ).items():
|
745 | 746 | if s_id in self.models:
|
746 |
| - explain_predict_fn = self.get_explain_predict_fn(series_id=s_id) |
747 |
| - data_trimmed = data_i.tail( |
748 |
| - max(int(len(data_i) * ratio), 5) |
749 |
| - ).reset_index(drop=True) |
750 |
| - data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply( |
751 |
| - lambda x: x.timestamp() |
752 |
| - ) |
753 |
| - |
754 |
| - # Explainer fails when boolean columns are passed |
755 |
| - |
756 |
| - _, data_trimmed_encoded = _label_encode_dataframe( |
757 |
| - data_trimmed, |
758 |
| - no_encode={datetime_col_name, self.original_target_column}, |
759 |
| - ) |
760 |
| - |
761 |
| - kernel_explnr = PermutationExplainer( |
762 |
| - model=explain_predict_fn, masker=data_trimmed_encoded |
763 |
| - ) |
764 |
| - kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded) |
765 |
| - exp_end_time = time.time() |
766 |
| - global_ex_time = global_ex_time + exp_end_time - exp_start_time |
767 |
| - self.local_explainer( |
768 |
| - kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name |
769 |
| - ) |
770 |
| - local_ex_time = local_ex_time + time.time() - exp_end_time |
| 747 | + try: |
| 748 | + explain_predict_fn = self.get_explain_predict_fn(series_id=s_id) |
| 749 | + data_trimmed = data_i.tail( |
| 750 | + max(int(len(data_i) * ratio), 5) |
| 751 | + ).reset_index(drop=True) |
| 752 | + data_trimmed[datetime_col_name] = data_trimmed[ |
| 753 | + datetime_col_name |
| 754 | + ].apply(lambda x: x.timestamp()) |
| 755 | + |
| 756 | + # Explainer fails when boolean columns are passed |
| 757 | + |
| 758 | + _, data_trimmed_encoded = _label_encode_dataframe( |
| 759 | + data_trimmed, |
| 760 | + no_encode={datetime_col_name, self.original_target_column}, |
| 761 | + ) |
771 | 762 |
|
772 |
| - if not len(kernel_explnr_vals): |
773 |
| - logger.warning( |
774 |
| - "No explanations generated. Ensure that additional data has been provided." |
| 763 | + kernel_explnr = PermutationExplainer( |
| 764 | + model=explain_predict_fn, masker=data_trimmed_encoded |
775 | 765 | )
|
776 |
| - else: |
777 |
| - self.global_explanation[s_id] = dict( |
778 |
| - zip( |
779 |
| - data_trimmed.columns[1:], |
780 |
| - np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0), |
781 |
| - ) |
| 766 | + kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded) |
| 767 | + exp_end_time = time.time() |
| 768 | + global_ex_time = global_ex_time + exp_end_time - exp_start_time |
| 769 | + self.local_explainer( |
| 770 | + kernel_explnr, |
| 771 | + series_id=s_id, |
| 772 | + datetime_col_name=datetime_col_name, |
782 | 773 | )
|
| 774 | + local_ex_time = local_ex_time + time.time() - exp_end_time |
| 775 | + |
| 776 | + if not len(kernel_explnr_vals): |
| 777 | + logger.warning( |
| 778 | + "No explanations generated. Ensure that additional data has been provided." |
| 779 | + ) |
| 780 | + else: |
| 781 | + self.global_explanation[s_id] = dict( |
| 782 | + zip( |
| 783 | + data_trimmed.columns[1:], |
| 784 | + np.average( |
| 785 | + np.absolute(kernel_explnr_vals[:, 1:]), axis=0 |
| 786 | + ), |
| 787 | + ) |
| 788 | + ) |
| 789 | + except Exception as e: |
| 790 | + if s_id in self.errors_dict: |
| 791 | + self.errors_dict[s_id]["explainer_error"] = str(e) |
| 792 | + self.errors_dict[s_id]["explainer_error_trace"] = ( |
| 793 | + traceback.format_exc() |
| 794 | + ) |
| 795 | + else: |
| 796 | + self.errors_dict[s_id] = { |
| 797 | + "model_name": self.spec.model, |
| 798 | + "explainer_error": str(e), |
| 799 | + "explainer_error_trace": traceback.format_exc(), |
| 800 | + } |
783 | 801 | else:
|
784 | 802 | logger.warning(
|
785 | 803 | f"Skipping explanations for {s_id}, as forecast was not generated."
|
@@ -816,6 +834,13 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
|
816 | 834 | local_kernel_explnr_df = pd.DataFrame(
|
817 | 835 | local_kernel_explnr_vals, columns=data.columns
|
818 | 836 | )
|
| 837 | + |
| 838 | + # Add date column to local explanation DataFrame |
| 839 | + local_kernel_explnr_df[ForecastOutputColumns.DATE] = ( |
| 840 | + self.datasets.get_horizon_at_series( |
| 841 | + s_id=series_id |
| 842 | + )[self.spec.datetime_column.name].reset_index(drop=True) |
| 843 | + ) |
819 | 844 | self.local_explanation[series_id] = local_kernel_explnr_df
|
820 | 845 |
|
821 | 846 | def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):
|
|
0 commit comments