Skip to content

Commit fa25cbf

Browse files
committed
update the column names and enable the global explainer in the reports
1 parent 5239f76 commit fa25cbf

File tree

2 files changed

+33
-26
lines changed

2 files changed

+33
-26
lines changed

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -247,17 +247,19 @@ def _generate_report(self):
247247
self.explain_model()
248248

249249
global_explanation_section = None
250-
if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
251-
# Convert the global explanation data to a DataFrame
252-
global_explanation_df = pd.DataFrame(self.global_explanation)
253250

254-
self.formatted_global_explanation = (
255-
global_explanation_df / global_explanation_df.sum(axis=0) * 100
256-
)
257-
self.formatted_global_explanation = self.formatted_global_explanation.rename(
251+
# Convert the global explanation data to a DataFrame
252+
global_explanation_df = pd.DataFrame(self.global_explanation)
253+
254+
self.formatted_global_explanation = (
255+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
256+
)
257+
self.formatted_global_explanation = (
258+
self.formatted_global_explanation.rename(
258259
{self.spec.datetime_column.name: ForecastOutputColumns.DATE},
259260
axis=1,
260261
)
262+
)
261263

262264
aggregate_local_explanations = pd.DataFrame()
263265
for s_id, local_ex_df in self.local_explanation.items():
@@ -269,11 +271,15 @@ def _generate_report(self):
269271
self.formatted_local_explanation = aggregate_local_explanations
270272

271273
if not self.target_cat_col:
272-
self.formatted_global_explanation = self.formatted_global_explanation.rename(
273-
{"Series 1": self.original_target_column},
274-
axis=1,
274+
self.formatted_global_explanation = (
275+
self.formatted_global_explanation.rename(
276+
{"Series 1": self.original_target_column},
277+
axis=1,
278+
)
279+
)
280+
self.formatted_local_explanation.drop(
281+
"Series", axis=1, inplace=True
275282
)
276-
self.formatted_local_explanation.drop("Series", axis=1, inplace=True)
277283

278284
# Create a markdown section for the global explainability
279285
global_explanation_section = rc.Block(
@@ -422,7 +428,9 @@ def explain_model(self):
422428
# Use the MLExplainer class from AutoMLx to generate explanations
423429
explainer = automlx.MLExplainer(
424430
self.models[s_id],
425-
self.datasets.additional_data.get_data_for_series(series_id=s_id)
431+
self.datasets.additional_data.get_data_for_series(
432+
series_id=s_id
433+
)
426434
.drop(self.spec.datetime_column.name, axis=1)
427435
.head(-self.spec.horizon)
428436
if self.spec.additional_data
@@ -433,7 +441,9 @@ def explain_model(self):
433441

434442
# Generate explanations for the forecast
435443
explanations = explainer.explain_prediction(
436-
X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
444+
X=self.datasets.additional_data.get_data_for_series(
445+
series_id=s_id
446+
)
437447
.drop(self.spec.datetime_column.name, axis=1)
438448
.tail(self.spec.horizon)
439449
if self.spec.additional_data
@@ -445,7 +455,9 @@ def explain_model(self):
445455
explanations_df = pd.concat(
446456
[exp.to_dataframe() for exp in explanations]
447457
)
448-
explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
458+
explanations_df["row"] = explanations_df.groupby(
459+
"Feature"
460+
).cumcount()
449461
explanations_df = explanations_df.pivot(
450462
index="row", columns="Feature", values="Attribution"
451463
)
@@ -454,14 +466,17 @@ def explain_model(self):
454466
# Store the explanations in the local_explanation dictionary
455467
self.local_explanation[s_id] = explanations_df
456468

457-
self.global_explanation[s_id] = dict(zip(
458-
data_i.columns[1:],
459-
np.average(np.absolute(explanations_df[:, 1:]), axis=0),
469+
self.global_explanation[s_id] = dict(
470+
zip(
471+
self.local_explanation[s_id].columns,
472+
np.nanmean((self.local_explanation[s_id]), axis=0),
460473
)
461474
)
462475
else:
463476
# Fall back to the default explanation generation method
464477
super().explain_model()
465478
except Exception as e:
466-
logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
479+
logger.warning(
480+
f"Failed to generate explanations for series {s_id} with error: {e}."
481+
)
467482
logger.debug(f"Full Traceback: {traceback.format_exc()}")

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -754,14 +754,6 @@ def explain_model(self):
754754
logger.warn(
755755
"No explanations generated. Ensure that additional data has been provided."
756756
)
757-
# elif (
758-
# self.spec.model == SupportedModels.AutoMLX
759-
# and self.spec.explanations_accuracy_mode
760-
# == SpeedAccuracyMode.AUTOMLX
761-
# ):
762-
# logger.warning(
763-
# "Global explanations not available for AutoMLX models with inherent explainability"
764-
# )
765757
else:
766758
self.global_explanation[s_id] = dict(
767759
zip(

0 commit comments

Comments
 (0)