Skip to content

Commit 7d13919

Browse files
authored
ODSC 68406: AutoMLx Global Explainer (#1071)
2 parents a030882 + 1619dee commit 7d13919

File tree

2 files changed

+22
-19
lines changed

2 files changed

+22
-19
lines changed

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -249,17 +249,18 @@ def _generate_report(self):
249249
self.explain_model()
250250

251251
global_explanation_section = None
252-
if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
253-
# Convert the global explanation data to a DataFrame
254-
global_explanation_df = pd.DataFrame(self.global_explanation)
255252

256-
self.formatted_global_explanation = (
257-
global_explanation_df / global_explanation_df.sum(axis=0) * 100
258-
)
259-
self.formatted_global_explanation = self.formatted_global_explanation.rename(
260-
{self.spec.datetime_column.name: ForecastOutputColumns.DATE},
261-
axis=1,
262-
)
253+
# Convert the global explanation data to a DataFrame
254+
global_explanation_df = pd.DataFrame(self.global_explanation)
255+
256+
self.formatted_global_explanation = (
257+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
258+
)
259+
260+
self.formatted_global_explanation.rename(
261+
columns={self.spec.datetime_column.name: ForecastOutputColumns.DATE},
262+
inplace=True,
263+
)
263264

264265
aggregate_local_explanations = pd.DataFrame()
265266
for s_id, local_ex_df in self.local_explanation.items():
@@ -428,7 +429,9 @@ def explain_model(self):
428429
# Use the MLExplainer class from AutoMLx to generate explanations
429430
explainer = automlx.MLExplainer(
430431
self.models[s_id]["model"],
431-
self.datasets.additional_data.get_data_for_series(series_id=s_id)
432+
self.datasets.additional_data.get_data_for_series(
433+
series_id=s_id
434+
)
432435
.drop(self.spec.datetime_column.name, axis=1)
433436
.head(-self.spec.horizon)
434437
if self.spec.additional_data
@@ -463,6 +466,13 @@ def explain_model(self):
463466

464467
# Store the explanations in the local_explanation dictionary
465468
self.local_explanation[s_id] = explanations_df
469+
470+
self.global_explanation[s_id] = dict(
471+
zip(
472+
self.local_explanation[s_id].columns,
473+
np.nanmean((self.local_explanation[s_id]), axis=0),
474+
)
475+
)
466476
else:
467477
# Fall back to the default explanation generation method
468478
super().explain_model()

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ def _save_report(
503503
f2.write(f1.read())
504504

505505
# forecast csv report
506+
# todo: add test data into forecast.csv
506507
# if self.spec.test_data is not None:
507508
# test_data_dict = test_data.get_dict_by_series()
508509
# for series_id, test_data_values in test_data_dict.items():
@@ -772,14 +773,6 @@ def explain_model(self):
772773
logger.warn(
773774
"No explanations generated. Ensure that additional data has been provided."
774775
)
775-
elif (
776-
self.spec.model == SupportedModels.AutoMLX
777-
and self.spec.explanations_accuracy_mode
778-
== SpeedAccuracyMode.AUTOMLX
779-
):
780-
logger.warning(
781-
"Global explanations not available for AutoMLX models with inherent explainability"
782-
)
783776
else:
784777
self.global_explanation[s_id] = dict(
785778
zip(

0 commit comments

Comments
 (0)