
Commit 65eb5a3

add date to explanation csv, update anomaly docs

1 parent 71efc26

3 files changed (+83, -41 lines)


ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 22 additions & 5 deletions

@@ -257,7 +257,9 @@ def _generate_report(self):
         )

         self.formatted_global_explanation.rename(
-            columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
+            columns={
+                self.spec.datetime_column.name: ForecastOutputColumns.DATE
+            },
             inplace=True,
         )

@@ -462,14 +464,27 @@ def explain_model(self):
                     index="row", columns="Feature", values="Attribution"
                 )
                 explanations_df = explanations_df.reset_index(drop=True)
-
+                explanations_df[ForecastOutputColumns.DATE] = (
+                    self.datasets.get_horizon_at_series(
+                        s_id=s_id
+                    )[self.spec.datetime_column.name].reset_index(drop=True)
+                )
                 # Store the explanations in the local_explanation dictionary
                 self.local_explanation[s_id] = explanations_df

                 self.global_explanation[s_id] = dict(
                     zip(
-                        self.local_explanation[s_id].columns,
-                        np.nanmean(np.abs(self.local_explanation[s_id]), axis=0),
+                        self.local_explanation[s_id]
+                        .drop(ForecastOutputColumns.DATE, axis=1)
+                        .columns,
+                        np.nanmean(
+                            np.abs(
+                                self.local_explanation[s_id].drop(
+                                    ForecastOutputColumns.DATE, axis=1
+                                )
+                            ),
+                            axis=0,
+                        ),
                     )
                 )
             else:

@@ -478,7 +493,9 @@ def explain_model(self):
         except Exception as e:
             if s_id in self.errors_dict:
                 self.errors_dict[s_id]["explainer_error"] = str(e)
-                self.errors_dict[s_id]["explainer_error_trace"] = traceback.format_exc()
+                self.errors_dict[s_id]["explainer_error_trace"] = (
+                    traceback.format_exc()
+                )
             else:
                 self.errors_dict[s_id] = {
                     "model_name": self.spec.model,

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 59 additions & 34 deletions

@@ -46,6 +46,7 @@
     AUTO_SELECT,
     BACKTEST_REPORT_NAME,
     SUMMARY_METRICS_HORIZON_LIMIT,
+    ForecastOutputColumns,
     SpeedAccuracyMode,
     SupportedMetrics,
     SupportedModels,

@@ -743,43 +744,60 @@ def explain_model(self):
                 include_horizon=False
             ).items():
                 if s_id in self.models:
-                    explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
-                    data_trimmed = data_i.tail(
-                        max(int(len(data_i) * ratio), 5)
-                    ).reset_index(drop=True)
-                    data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(
-                        lambda x: x.timestamp()
-                    )
-
-                    # Explainer fails when boolean columns are passed
-
-                    _, data_trimmed_encoded = _label_encode_dataframe(
-                        data_trimmed,
-                        no_encode={datetime_col_name, self.original_target_column},
-                    )
-
-                    kernel_explnr = PermutationExplainer(
-                        model=explain_predict_fn, masker=data_trimmed_encoded
-                    )
-                    kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
-                    exp_end_time = time.time()
-                    global_ex_time = global_ex_time + exp_end_time - exp_start_time
-                    self.local_explainer(
-                        kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name
-                    )
-                    local_ex_time = local_ex_time + time.time() - exp_end_time
-
-                    if not len(kernel_explnr_vals):
-                        logger.warn(
-                            "No explanations generated. Ensure that additional data has been provided."
-                        )
-                    else:
-                        self.global_explanation[s_id] = dict(
-                            zip(
-                                data_trimmed.columns[1:],
-                                np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0),
-                            )
-                        )
+                    try:
+                        explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
+                        data_trimmed = data_i.tail(
+                            max(int(len(data_i) * ratio), 5)
+                        ).reset_index(drop=True)
+                        data_trimmed[datetime_col_name] = data_trimmed[
+                            datetime_col_name
+                        ].apply(lambda x: x.timestamp())
+
+                        # Explainer fails when boolean columns are passed
+
+                        _, data_trimmed_encoded = _label_encode_dataframe(
+                            data_trimmed,
+                            no_encode={datetime_col_name, self.original_target_column},
+                        )
+
+                        kernel_explnr = PermutationExplainer(
+                            model=explain_predict_fn, masker=data_trimmed_encoded
+                        )
+                        kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
+                        exp_end_time = time.time()
+                        global_ex_time = global_ex_time + exp_end_time - exp_start_time
+                        self.local_explainer(
+                            kernel_explnr,
+                            series_id=s_id,
+                            datetime_col_name=datetime_col_name,
+                        )
+                        local_ex_time = local_ex_time + time.time() - exp_end_time
+
+                        if not len(kernel_explnr_vals):
+                            logger.warn(
+                                "No explanations generated. Ensure that additional data has been provided."
+                            )
+                        else:
+                            self.global_explanation[s_id] = dict(
+                                zip(
+                                    data_trimmed.columns[1:],
+                                    np.average(
+                                        np.absolute(kernel_explnr_vals[:, 1:]), axis=0
+                                    ),
+                                )
+                            )
+                    except Exception as e:
+                        if s_id in self.errors_dict:
+                            self.errors_dict[s_id]["explainer_error"] = str(e)
+                            self.errors_dict[s_id]["explainer_error_trace"] = (
+                                traceback.format_exc()
+                            )
+                        else:
+                            self.errors_dict[s_id] = {
+                                "model_name": self.spec.model,
+                                "explainer_error": str(e),
+                                "explainer_error_trace": traceback.format_exc(),
+                            }
                 else:
                     logger.warn(
                         f"Skipping explanations for {s_id}, as forecast was not generated."

@@ -816,6 +834,13 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> None:
         local_kernel_explnr_df = pd.DataFrame(
             local_kernel_explnr_vals, columns=data.columns
         )
+
+        # Add date column to local explanation DataFrame
+        local_kernel_explnr_df[ForecastOutputColumns.DATE] = (
+            self.datasets.get_horizon_at_series(
+                s_id=series_id
+            )[self.spec.datetime_column.name].reset_index(drop=True)
+        )
         self.local_explanation[series_id] = local_kernel_explnr_df

     def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):
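The refactor above wraps the entire SHAP flow for each series in try/except, so one failing explainer no longer aborts explanations for the remaining series. A rough self-contained sketch of that flow, using the same shap calls as the diff (the model, data, and series key here are toy stand-ins, not the operator's own objects):

import traceback

import numpy as np
import pandas as pd
from shap import PermutationExplainer

rng = np.random.default_rng(0)
data = pd.DataFrame({"x1": rng.normal(size=20), "x2": rng.normal(size=20)})

def explain_predict_fn(X):
    # toy stand-in for the operator's per-series forecast function
    X = np.asarray(X)
    return 2.0 * X[:, 0] + X[:, 1]

errors_dict = {}
try:
    explainer = PermutationExplainer(model=explain_predict_fn, masker=data)
    shap_vals = explainer.shap_values(data)
    # mean absolute SHAP value per feature -> global importance
    global_expl = dict(zip(data.columns, np.average(np.absolute(shap_vals), axis=0)))
except Exception as e:
    # mirror the commit's error capture: message plus full traceback, keyed by series
    errors_dict["series_1"] = {
        "explainer_error": str(e),
        "explainer_error_trace": traceback.format_exc(),
    }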

docs/source/user_guide/operators/anomaly_detection_operator/index.rst

Lines changed: 2 additions & 2 deletions

@@ -17,9 +17,9 @@ The Anomaly Detection Operator accepts a dataset with:
 * (Optionally) 1 or more series columns (such that the target is indexed by datetime and series)
 * (Optionally) An arbitrary number of additional variables

-Besides this input data, the user can also specify validation data, if available. Validation data should have all the columns of the input data plus a binary column titled "anomaly". The "anomaly" column should be -1 for anomalies and 1 for normal rows.
+Besides this input data, the user can also specify validation data, if available. Validation data should have all the columns of the input data plus a binary column titled "anomaly". The "anomaly" column should be 1 for anomalies and 0 for normal rows.

-Finally, the user can provide "test_data" in order to receive test metrics and evaluate the Operator's performance more easily. Test data should be indexed by date and (optionally) series. Test data should have a -1 for anomalous rows and 1 for normal rows.
+Finally, the user can provide "test_data" in order to receive test metrics and evaluate the Operator's performance more easily. Test data should be indexed by date and (optionally) series. Test data should have a 1 for anomalous rows and 0 for normal rows.

 **Multivariate vs. Univariate Anomaly Detection**
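Under the corrected convention, a minimal validation file might look like the sketch below: every column except "anomaly" mirrors the input data, and "anomaly" is 1 for anomalous rows, 0 for normal ones (the column names other than "anomaly" are hypothetical):

timestamp,series_id,value,anomaly
2024-01-01 00:00,A,12.3,0
2024-01-01 01:00,A,98.7,1
2024-01-01 02:00,A,12.9,0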
