
Commit 724f423

Merge branch 'main' into add_min_max_prophet
2 parents 5225e65 + c5edd1c

3 files changed: +77 -39 lines changed

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 16 additions & 3 deletions
@@ -469,14 +469,27 @@ def explain_model(self):
                     index="row", columns="Feature", values="Attribution"
                 )
                 explanations_df = explanations_df.reset_index(drop=True)
-
+                explanations_df[ForecastOutputColumns.DATE] = (
+                    self.datasets.get_horizon_at_series(
+                        s_id=s_id
+                    )[self.spec.datetime_column.name].reset_index(drop=True)
+                )
                 # Store the explanations in the local_explanation dictionary
                 self.local_explanation[s_id] = explanations_df
 
                 self.global_explanation[s_id] = dict(
                     zip(
-                        self.local_explanation[s_id].columns,
-                        np.nanmean(np.abs(self.local_explanation[s_id]), axis=0),
+                        self.local_explanation[s_id]
+                        .drop(ForecastOutputColumns.DATE, axis=1)
+                        .columns,
+                        np.nanmean(
+                            np.abs(
+                                self.local_explanation[s_id].drop(
+                                    ForecastOutputColumns.DATE, axis=1
+                                )
+                            ),
+                            axis=0,
+                        ),
                     )
                 )
             else:
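
Taken together, this hunk attaches the forecast-horizon dates to each series' attribution frame and then excludes that date column when averaging absolute attributions into the global explanation. Below is a minimal, self-contained sketch of that pattern; the DataFrame contents and the literal "Date" column name are invented for illustration (in the operator the name comes from ForecastOutputColumns.DATE):

import numpy as np
import pandas as pd

DATE_COL = "Date"  # illustrative stand-in for ForecastOutputColumns.DATE

# Invented per-row feature attributions for one series
explanations_df = pd.DataFrame(
    {"temp": [0.4, -0.2, 0.1], "holiday": [0.0, 0.3, -0.1]}
)

# Attach the horizon dates, as the hunk does via get_horizon_at_series()
horizon_dates = pd.Series(pd.date_range("2024-01-01", periods=3, freq="D"))
explanations_df[DATE_COL] = horizon_dates.reset_index(drop=True)

# Global explanation: mean absolute attribution per feature, with the date
# column dropped so it is not averaged as if it were a feature
features_only = explanations_df.drop(DATE_COL, axis=1)
global_explanation = dict(
    zip(features_only.columns, np.nanmean(np.abs(features_only), axis=0))
)
print(global_explanation)  # {'temp': 0.233..., 'holiday': 0.133...}

The same date-attachment step appears again in base_model.py's local_explainer below.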

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 59 additions & 34 deletions
@@ -46,6 +46,7 @@
     AUTO_SELECT,
     BACKTEST_REPORT_NAME,
     SUMMARY_METRICS_HORIZON_LIMIT,
+    ForecastOutputColumns,
     SpeedAccuracyMode,
     SupportedMetrics,
     SupportedModels,
@@ -743,43 +744,60 @@ def explain_model(self):
             include_horizon=False
         ).items():
             if s_id in self.models:
-                explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
-                data_trimmed = data_i.tail(
-                    max(int(len(data_i) * ratio), 5)
-                ).reset_index(drop=True)
-                data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(
-                    lambda x: x.timestamp()
-                )
-
-                # Explainer fails when boolean columns are passed
-
-                _, data_trimmed_encoded = _label_encode_dataframe(
-                    data_trimmed,
-                    no_encode={datetime_col_name, self.original_target_column},
-                )
-
-                kernel_explnr = PermutationExplainer(
-                    model=explain_predict_fn, masker=data_trimmed_encoded
-                )
-                kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
-                exp_end_time = time.time()
-                global_ex_time = global_ex_time + exp_end_time - exp_start_time
-                self.local_explainer(
-                    kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name
-                )
-                local_ex_time = local_ex_time + time.time() - exp_end_time
+                try:
+                    explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
+                    data_trimmed = data_i.tail(
+                        max(int(len(data_i) * ratio), 5)
+                    ).reset_index(drop=True)
+                    data_trimmed[datetime_col_name] = data_trimmed[
+                        datetime_col_name
+                    ].apply(lambda x: x.timestamp())
+
+                    # Explainer fails when boolean columns are passed
+
+                    _, data_trimmed_encoded = _label_encode_dataframe(
+                        data_trimmed,
+                        no_encode={datetime_col_name, self.original_target_column},
+                    )
 
-                if not len(kernel_explnr_vals):
-                    logger.warning(
-                        "No explanations generated. Ensure that additional data has been provided."
+                    kernel_explnr = PermutationExplainer(
+                        model=explain_predict_fn, masker=data_trimmed_encoded
                     )
-                else:
-                    self.global_explanation[s_id] = dict(
-                        zip(
-                            data_trimmed.columns[1:],
-                            np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0),
-                        )
+                    kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
+                    exp_end_time = time.time()
+                    global_ex_time = global_ex_time + exp_end_time - exp_start_time
+                    self.local_explainer(
+                        kernel_explnr,
+                        series_id=s_id,
+                        datetime_col_name=datetime_col_name,
                     )
+                    local_ex_time = local_ex_time + time.time() - exp_end_time
+
+                    if not len(kernel_explnr_vals):
+                        logger.warning(
+                            "No explanations generated. Ensure that additional data has been provided."
+                        )
+                    else:
+                        self.global_explanation[s_id] = dict(
+                            zip(
+                                data_trimmed.columns[1:],
+                                np.average(
+                                    np.absolute(kernel_explnr_vals[:, 1:]), axis=0
+                                ),
+                            )
+                        )
+                except Exception as e:
+                    if s_id in self.errors_dict:
+                        self.errors_dict[s_id]["explainer_error"] = str(e)
+                        self.errors_dict[s_id]["explainer_error_trace"] = (
+                            traceback.format_exc()
+                        )
+                    else:
+                        self.errors_dict[s_id] = {
+                            "model_name": self.spec.model,
+                            "explainer_error": str(e),
+                            "explainer_error_trace": traceback.format_exc(),
+                        }
             else:
                 logger.warning(
                     f"Skipping explanations for {s_id}, as forecast was not generated."
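
The substantive change in this hunk is wrapping the per-series explainer in try/except, so one failing series no longer aborts explanations for the rest; the error message and traceback are recorded under that series id in self.errors_dict instead. A hedged, standalone sketch of that bookkeeping pattern (errors_dict, model_name, and record_explainer_error are illustrative stand-ins, not the operator's API):

import traceback

errors_dict = {}
model_name = "prophet"  # stands in for self.spec.model

def record_explainer_error(s_id, exc):
    # Merge the explainer failure into any existing error entry for this series
    if s_id in errors_dict:
        errors_dict[s_id]["explainer_error"] = str(exc)
        errors_dict[s_id]["explainer_error_trace"] = traceback.format_exc()
    else:
        errors_dict[s_id] = {
            "model_name": model_name,
            "explainer_error": str(exc),
            "explainer_error_trace": traceback.format_exc(),
        }

for s_id in ["series_a", "series_b"]:
    try:
        raise RuntimeError(f"explainer failed for {s_id}")  # simulated failure
    except Exception as e:
        record_explainer_error(s_id, e)

print(sorted(errors_dict))  # ['series_a', 'series_b']

Merging into an existing entry preserves any earlier error already recorded for the same series, such as a model-fitting failure.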
@@ -816,6 +834,13 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> None:
         local_kernel_explnr_df = pd.DataFrame(
             local_kernel_explnr_vals, columns=data.columns
         )
+
+        # Add date column to local explanation DataFrame
+        local_kernel_explnr_df[ForecastOutputColumns.DATE] = (
+            self.datasets.get_horizon_at_series(
+                s_id=series_id
+            )[self.spec.datetime_column.name].reset_index(drop=True)
+        )
         self.local_explanation[series_id] = local_kernel_explnr_df
 
     def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):

docs/source/user_guide/operators/anomaly_detection_operator/index.rst

Lines changed: 2 additions & 2 deletions
@@ -17,9 +17,9 @@ The Anomaly Detection Operator accepts a dataset with:
 * (Optionally) 1 or more seires columns (such that the target is indexed by datetime and series)
 * (Optionall) An arbitrary number of additional variables
 
-Besides this input data, the user can also specify validation data, if available. Validation data should have all the columns of the input data plus a binary column titled "anomaly". The "anomaly" column should be -1 for anomalies and 1 for normal rows.
+Besides this input data, the user can also specify validation data, if available. Validation data should have all the columns of the input data plus a binary column titled "anomaly". The "anomaly" column should be 1 for anomalies and 0 for normal rows.
 
-Finally the user can provide "test_data" in order to recieve test metrics and evaluate the Operator's performance more easily. Test data should indexed by date and (optionally) series. Test data should have a -1 for anomalous rows and 1 for normal rows.
+Finally the user can provide "test_data" in order to recieve test metrics and evaluate the Operator's performance more easily. Test data should indexed by date and (optionally) series. Test data should have a 1 for anomalous rows and 0 for normal rows.
 
 **Multivariate vs. Univariate Anomaly Detection**
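
The documentation fix above inverts the label convention: anomalous rows are now marked 1 and normal rows 0, in both validation and test data. A small illustrative frame matching the corrected convention (the column names other than "anomaly" are examples, not a required schema):

import pandas as pd

validation_data = pd.DataFrame(
    {
        "timestamp": pd.date_range("2024-01-01", periods=4, freq="h"),
        "value": [10.1, 9.8, 55.0, 10.3],
        "anomaly": [0, 0, 1, 0],  # 1 = anomaly, 0 = normal
    }
)
print(validation_data)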