
Commit e684c5d

Merge branch 'main' into ODSC-70287/recommendation-doc-update

2 parents: 43fcdfa + c5edd1c

File tree

3 files changed (+83, -41 lines)

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 22 additions & 5 deletions
@@ -257,7 +257,9 @@ def _generate_report(self):
         )

         self.formatted_global_explanation.rename(
-            columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
+            columns={
+                self.spec.datetime_column.name: ForecastOutputColumns.DATE
+            },
             inplace=True,
         )

@@ -462,14 +464,27 @@ def explain_model(self):
                 index="row", columns="Feature", values="Attribution"
             )
             explanations_df = explanations_df.reset_index(drop=True)
-
+            explanations_df[ForecastOutputColumns.DATE] = (
+                self.datasets.get_horizon_at_series(
+                    s_id=s_id
+                )[self.spec.datetime_column.name].reset_index(drop=True)
+            )
             # Store the explanations in the local_explanation dictionary
             self.local_explanation[s_id] = explanations_df

             self.global_explanation[s_id] = dict(
                 zip(
-                    self.local_explanation[s_id].columns,
-                    np.nanmean(np.abs(self.local_explanation[s_id]), axis=0),
+                    self.local_explanation[s_id]
+                    .drop(ForecastOutputColumns.DATE, axis=1)
+                    .columns,
+                    np.nanmean(
+                        np.abs(
+                            self.local_explanation[s_id].drop(
+                                ForecastOutputColumns.DATE, axis=1
+                            )
+                        ),
+                        axis=0,
+                    ),
                 )
             )
         else:
@@ -478,7 +493,9 @@ def explain_model(self):
        except Exception as e:
            if s_id in self.errors_dict:
                self.errors_dict[s_id]["explainer_error"] = str(e)
-                self.errors_dict[s_id]["explainer_error_trace"] = traceback.format_exc()
+                self.errors_dict[s_id]["explainer_error_trace"] = (
+                    traceback.format_exc()
+                )
            else:
                self.errors_dict[s_id] = {
                    "model_name": self.spec.model,

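Note: both hunks above exclude the newly added date column before aggregating local attributions into global importances; otherwise the datetime values would be averaged in alongside the feature attributions. A minimal sketch of that aggregation pattern (toy data; the column names are illustrative, not the operator's):

import numpy as np
import pandas as pd

# Toy local explanations for one series: one row per forecast step,
# one attribution column per feature, plus the date column this commit adds.
local_explanation = pd.DataFrame(
    {
        "Date": pd.date_range("2024-01-01", periods=3),
        "temp": [0.5, -0.7, 0.6],
        "promo": [0.1, 0.2, np.nan],
    }
)

# Drop the date column before aggregating, as the diff does.
features = local_explanation.drop("Date", axis=1)

# Global importance per feature: mean absolute local attribution,
# ignoring NaNs -- the np.nanmean(np.abs(...), axis=0) idiom above.
global_explanation = dict(
    zip(features.columns, np.nanmean(np.abs(features), axis=0))
)
print(global_explanation)  # approximately {'temp': 0.6, 'promo': 0.15}
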
ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 59 additions & 34 deletions
@@ -46,6 +46,7 @@
     AUTO_SELECT,
     BACKTEST_REPORT_NAME,
     SUMMARY_METRICS_HORIZON_LIMIT,
+    ForecastOutputColumns,
     SpeedAccuracyMode,
     SupportedMetrics,
     SupportedModels,
@@ -742,43 +743,60 @@ def explain_model(self):
            include_horizon=False
        ).items():
            if s_id in self.models:
-                explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
-                data_trimmed = data_i.tail(
-                    max(int(len(data_i) * ratio), 5)
-                ).reset_index(drop=True)
-                data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(
-                    lambda x: x.timestamp()
-                )
-
-                # Explainer fails when boolean columns are passed
-
-                _, data_trimmed_encoded = _label_encode_dataframe(
-                    data_trimmed,
-                    no_encode={datetime_col_name, self.original_target_column},
-                )
-
-                kernel_explnr = PermutationExplainer(
-                    model=explain_predict_fn, masker=data_trimmed_encoded
-                )
-                kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
-                exp_end_time = time.time()
-                global_ex_time = global_ex_time + exp_end_time - exp_start_time
-                self.local_explainer(
-                    kernel_explnr, series_id=s_id, datetime_col_name=datetime_col_name
-                )
-                local_ex_time = local_ex_time + time.time() - exp_end_time
+                try:
+                    explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
+                    data_trimmed = data_i.tail(
+                        max(int(len(data_i) * ratio), 5)
+                    ).reset_index(drop=True)
+                    data_trimmed[datetime_col_name] = data_trimmed[
+                        datetime_col_name
+                    ].apply(lambda x: x.timestamp())
+
+                    # Explainer fails when boolean columns are passed
+
+                    _, data_trimmed_encoded = _label_encode_dataframe(
+                        data_trimmed,
+                        no_encode={datetime_col_name, self.original_target_column},
+                    )

-                if not len(kernel_explnr_vals):
-                    logger.warn(
-                        "No explanations generated. Ensure that additional data has been provided."
+                    kernel_explnr = PermutationExplainer(
+                        model=explain_predict_fn, masker=data_trimmed_encoded
                    )
-                else:
-                    self.global_explanation[s_id] = dict(
-                        zip(
-                            data_trimmed.columns[1:],
-                            np.average(np.absolute(kernel_explnr_vals[:, 1:]), axis=0),
-                        )
+                    kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed_encoded)
+                    exp_end_time = time.time()
+                    global_ex_time = global_ex_time + exp_end_time - exp_start_time
+                    self.local_explainer(
+                        kernel_explnr,
+                        series_id=s_id,
+                        datetime_col_name=datetime_col_name,
                    )
+                    local_ex_time = local_ex_time + time.time() - exp_end_time
+
+                    if not len(kernel_explnr_vals):
+                        logger.warn(
+                            "No explanations generated. Ensure that additional data has been provided."
+                        )
+                    else:
+                        self.global_explanation[s_id] = dict(
+                            zip(
+                                data_trimmed.columns[1:],
+                                np.average(
+                                    np.absolute(kernel_explnr_vals[:, 1:]), axis=0
+                                ),
+                            )
+                        )
+                except Exception as e:
+                    if s_id in self.errors_dict:
+                        self.errors_dict[s_id]["explainer_error"] = str(e)
+                        self.errors_dict[s_id]["explainer_error_trace"] = (
+                            traceback.format_exc()
+                        )
+                    else:
+                        self.errors_dict[s_id] = {
+                            "model_name": self.spec.model,
+                            "explainer_error": str(e),
+                            "explainer_error_trace": traceback.format_exc(),
+                        }
            else:
                logger.warn(
                    f"Skipping explanations for {s_id}, as forecast was not generated."
@@ -815,6 +833,13 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> None:
        local_kernel_explnr_df = pd.DataFrame(
            local_kernel_explnr_vals, columns=data.columns
        )
+
+        # Add date column to local explanation DataFrame
+        local_kernel_explnr_df[ForecastOutputColumns.DATE] = (
+            self.datasets.get_horizon_at_series(
+                s_id=series_id
+            )[self.spec.datetime_column.name].reset_index(drop=True)
+        )
        self.local_explanation[series_id] = local_kernel_explnr_df

    def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):
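
Note: the main behavioral change in this hunk is that the per-series explanation work now runs inside try/except, so an explainer failure for one series is recorded in errors_dict rather than aborting explanations for the remaining series. A self-contained sketch of that error-capture pattern (record_explainer_error and the series names are hypothetical, for illustration only):

import traceback

def record_explainer_error(errors_dict, s_id, model_name, exc):
    # Mirrors the except branch above: update an existing per-series
    # entry if present, otherwise create one.
    if s_id in errors_dict:
        errors_dict[s_id]["explainer_error"] = str(exc)
        errors_dict[s_id]["explainer_error_trace"] = traceback.format_exc()
    else:
        errors_dict[s_id] = {
            "model_name": model_name,
            "explainer_error": str(exc),
            "explainer_error_trace": traceback.format_exc(),
        }

errors_dict = {}
for s_id in ["series_a", "series_b"]:
    try:
        # Stand-in for the SHAP explainer work that may raise.
        raise ValueError(f"explainer failed for {s_id}")
    except Exception as e:
        record_explainer_error(errors_dict, s_id, model_name="example", exc=e)

print(sorted(errors_dict))  # ['series_a', 'series_b'] -- the loop survived both failures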

docs/source/user_guide/operators/anomaly_detection_operator/index.rst

Lines changed: 2 additions & 2 deletions
@@ -17,9 +17,9 @@ The Anomaly Detection Operator accepts a dataset with:
 * (Optionally) 1 or more series columns (such that the target is indexed by datetime and series)
 * (Optionally) An arbitrary number of additional variables

-Besides this input data, the user can also specify validation data, if available. Validation data should have all the columns of the input data plus a binary column titled "anomaly". The "anomaly" column should be -1 for anomalies and 1 for normal rows.
+Besides this input data, the user can also specify validation data, if available. Validation data should have all the columns of the input data plus a binary column titled "anomaly". The "anomaly" column should be 1 for anomalies and 0 for normal rows.

-Finally, the user can provide "test_data" in order to receive test metrics and evaluate the Operator's performance more easily. Test data should be indexed by date and (optionally) series. Test data should have a -1 for anomalous rows and 1 for normal rows.
+Finally, the user can provide "test_data" in order to receive test metrics and evaluate the Operator's performance more easily. Test data should be indexed by date and (optionally) series. Test data should have a 1 for anomalous rows and 0 for normal rows.

 **Multivariate vs. Univariate Anomaly Detection**

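Note: to make the corrected labeling concrete, a minimal sketch of a validation table under the new convention (the timestamp and sensor_value columns are illustrative; only the "anomaly" column name comes from the docs):

import pandas as pd

# Input columns plus the binary "anomaly" column:
# 1 marks an anomalous row, 0 a normal one.
validation_data = pd.DataFrame(
    {
        "timestamp": pd.date_range("2024-01-01", periods=4, freq="h"),
        "sensor_value": [10.1, 10.3, 98.7, 10.2],
        "anomaly": [0, 0, 1, 0],
    }
)
print(validation_data)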