Skip to content

Commit 6a76574

Browse files
committed
add automlx internal explainability as a mode explainaibility mode
1 parent 3d8f148 commit 6a76574

File tree

4 files changed

+94
-24
lines changed

4 files changed

+94
-24
lines changed

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
2727
HIGH_ACCURACY = "HIGH_ACCURACY"
2828
BALANCED = "BALANCED"
2929
FAST_APPROXIMATE = "FAST_APPROXIMATE"
30+
AUTOMLX = "AUTOMLX"
3031
ratio = {}
3132
ratio[HIGH_ACCURACY] = 1 # 100 % data used for generating explanations
3233
ratio[BALANCED] = 0.5 # 50 % data used for generating explanations
3334
ratio[FAST_APPROXIMATE] = 0 # constant
35+
ratio[AUTOMLX] = 0 # constant
3436

3537

3638
class SupportedMetrics(str, metaclass=ExtendedEnumMeta):

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ads.opctl.operator.lowcode.forecast.const import (
1818
AUTOMLX_METRIC_MAP,
1919
ForecastOutputColumns,
20+
SpeedAccuracyMode,
2021
SupportedModels,
2122
)
2223
from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
@@ -239,27 +240,27 @@ def _generate_report(self):
239240
# If the key is present, call the "explain_model" method
240241
self.explain_model()
241242

242-
# Convert the global explanation data to a DataFrame
243-
global_explanation_df = pd.DataFrame(self.global_explanation)
243+
global_explanation_section = None
244+
if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
245+
# Convert the global explanation data to a DataFrame
246+
global_explanation_df = pd.DataFrame(self.global_explanation)
244247

245-
self.formatted_global_explanation = (
246-
global_explanation_df / global_explanation_df.sum(axis=0) * 100
247-
)
248-
self.formatted_global_explanation = (
249-
self.formatted_global_explanation.rename(
248+
self.formatted_global_explanation = (
249+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
250+
)
251+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
250252
{self.spec.datetime_column.name: ForecastOutputColumns.DATE},
251253
axis=1,
252254
)
253-
)
254255

255-
# Create a markdown section for the global explainability
256-
global_explanation_section = rc.Block(
257-
rc.Heading("Global Explanation of Models", level=2),
258-
rc.Text(
259-
"The following tables provide the feature attribution for the global explainability."
260-
),
261-
rc.DataTable(self.formatted_global_explanation, index=True),
262-
)
256+
# Create a markdown section for the global explainability
257+
global_explanation_section = rc.Block(
258+
rc.Heading("Global Explanation of Models", level=2),
259+
rc.Text(
260+
"The following tables provide the feature attribution for the global explainability."
261+
),
262+
rc.DataTable(self.formatted_global_explanation, index=True),
263+
)
263264

264265
aggregate_local_explanations = pd.DataFrame()
265266
for s_id, local_ex_df in self.local_explanation.items():
@@ -284,8 +285,11 @@ def _generate_report(self):
284285
)
285286

286287
# Append the global explanation text and section to the "other_sections" list
288+
if global_explanation_section:
289+
other_sections.append(global_explanation_section)
290+
291+
# Append the local explanation text and section to the "other_sections" list
287292
other_sections = other_sections + [
288-
global_explanation_section,
289293
local_explanation_section,
290294
]
291295
except Exception as e:

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
SpeedAccuracyMode,
4848
SupportedMetrics,
4949
SupportedModels,
50-
BACKTEST_REPORT_NAME
50+
BACKTEST_REPORT_NAME,
5151
)
5252
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
5353
from .forecast_datasets import ForecastDatasets
@@ -259,7 +259,11 @@ def generate_report(self):
259259
output_dir = self.spec.output_directory.url
260260
file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
261261
if self.spec.model == AUTO_SELECT:
262-
backtest_sections.append(rc.Heading("Auto-Select Backtesting and Performance Metrics", level=2))
262+
backtest_sections.append(
263+
rc.Heading(
264+
"Auto-Select Backtesting and Performance Metrics", level=2
265+
)
266+
)
263267
if not os.path.exists(file_path):
264268
failure_msg = rc.Text(
265269
"auto-select could not be executed. Please check the "
@@ -268,15 +272,23 @@ def generate_report(self):
268272
backtest_sections.append(failure_msg)
269273
else:
270274
backtest_stats = pd.read_csv(file_path)
271-
model_metric_map = backtest_stats.drop(columns=['metric', 'backtest'])
272-
average_dict = {k: round(v, 4) for k, v in model_metric_map.mean().to_dict().items()}
275+
model_metric_map = backtest_stats.drop(
276+
columns=["metric", "backtest"]
277+
)
278+
average_dict = {
279+
k: round(v, 4)
280+
for k, v in model_metric_map.mean().to_dict().items()
281+
}
273282
best_model = min(average_dict, key=average_dict.get)
274283
summary_text = rc.Text(
275284
f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
276-
f" {best_model} being identified as the top-performing model during backtesting.")
285+
f" {best_model} being identified as the top-performing model during backtesting."
286+
)
277287
backtest_table = rc.DataTable(backtest_stats, index=True)
278288
liner_plot = get_auto_select_plot(backtest_stats)
279-
backtest_sections.extend([backtest_table, summary_text, liner_plot])
289+
backtest_sections.extend(
290+
[backtest_table, summary_text, liner_plot]
291+
)
280292

281293
forecast_plots = []
282294
if len(self.forecast_output.list_series_ids()) > 0:
@@ -643,6 +655,12 @@ def _save_model(self, output_dir, storage_options):
643655
"Please run `python3 -m pip install shap` to install the required dependencies for model explanation."
644656
),
645657
)
658+
@runtime_dependency(
659+
module="automlx",
660+
err_msg=(
661+
"Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
662+
),
663+
)
646664
def explain_model(self):
647665
"""
648666
Generates an explanation for the model by using the SHAP (Shapley Additive exPlanations) library.
@@ -668,7 +686,44 @@ def explain_model(self):
668686
for s_id, data_i in self.datasets.get_data_by_series(
669687
include_horizon=False
670688
).items():
671-
if s_id in self.models:
689+
if (
690+
self.spec.model == SupportedModels.AutoMLX
691+
and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX
692+
):
693+
import automlx
694+
695+
explainer = automlx.MLExplainer(
696+
self.models[s_id],
697+
self.datasets.additional_data.get_data_for_series(series_id=s_id)
698+
.drop(self.spec.datetime_column.name, axis=1)
699+
.head(-self.spec.horizon)
700+
if self.spec.additional_data
701+
else None,
702+
pd.DataFrame(data_i[self.spec.target_column]),
703+
task="forecasting",
704+
)
705+
706+
explanations = explainer.explain_prediction(
707+
X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
708+
.drop(self.spec.datetime_column.name, axis=1)
709+
.tail(self.spec.horizon)
710+
if self.spec.additional_data
711+
else None,
712+
forecast_timepoints=list(range(self.spec.horizon + 1)),
713+
)
714+
715+
explanations_df = pd.concat(
716+
[exp.to_dataframe() for exp in explanations]
717+
)
718+
explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
719+
explanations_df = explanations_df.pivot(
720+
index="row", columns="Feature", values="Attribution"
721+
)
722+
explanations_df = explanations_df.reset_index(drop=True)
723+
# explanations_df[self.spec.datetime_column.name]=self.datasets.additional_data.get_data_for_series(series_id=s_id).tail(self.spec.horizon)[self.spec.datetime_column.name].reset_index(drop=True)
724+
self.local_explanation[s_id] = explanations_df
725+
726+
elif s_id in self.models:
672727
explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
673728
data_trimmed = data_i.tail(
674729
max(int(len(data_i) * ratio), 5)
@@ -699,6 +754,14 @@ def explain_model(self):
699754
logger.warn(
700755
"No explanations generated. Ensure that additional data has been provided."
701756
)
757+
elif (
758+
self.spec.model == SupportedModels.AutoMLX
759+
and self.spec.explanations_accuracy_mode
760+
== SpeedAccuracyMode.AUTOMLX
761+
):
762+
logger.warning(
763+
"Global explanations not available for AutoMLX models with inherent explainability"
764+
)
702765
else:
703766
self.global_explanation[s_id] = dict(
704767
zip(

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ spec:
332332
- HIGH_ACCURACY
333333
- BALANCED
334334
- FAST_APPROXIMATE
335+
- AUTOMLX
335336

336337
generate_report:
337338
type: boolean

0 commit comments

Comments
 (0)