Skip to content

Commit d9278a0

Browse files
committed
move automlx explaination to subclass
1 parent 4085550 commit d9278a0

File tree

2 files changed

+83
-50
lines changed

2 files changed

+83
-50
lines changed

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,75 @@ def _custom_predict_automlx(self, data):
370370
return self.models.get(self.series_id).forecast(
371371
X=data_temp, periods=data_temp.shape[0]
372372
)[self.series_id]
373+
374+
@runtime_dependency(
375+
module="automlx",
376+
err_msg=(
377+
"Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
378+
),
379+
)
380+
def explain_model(self):
381+
"""
382+
Generates explanations for the model using the AutoMLx library.
383+
384+
Parameters
385+
----------
386+
None
387+
388+
Returns
389+
-------
390+
None
391+
392+
Notes
393+
-----
394+
This function works by generating local explanations for each series in the dataset.
395+
It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions
396+
for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary.
397+
398+
If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations.
399+
Otherwise, it falls back to the default explanation generation method.
400+
"""
401+
import automlx
402+
403+
# Loop through each series in the dataset
404+
for s_id, data_i in self.datasets.get_data_by_series(
405+
include_horizon=False
406+
).items():
407+
if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
408+
# Use the MLExplainer class from AutoMLx to generate explanations
409+
explainer = automlx.MLExplainer(
410+
self.models[s_id],
411+
self.datasets.additional_data.get_data_for_series(series_id=s_id)
412+
.drop(self.spec.datetime_column.name, axis=1)
413+
.head(-self.spec.horizon)
414+
if self.spec.additional_data
415+
else None,
416+
pd.DataFrame(data_i[self.spec.target_column]),
417+
task="forecasting",
418+
)
419+
420+
# Generate explanations for the forecast
421+
explanations = explainer.explain_prediction(
422+
X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
423+
.drop(self.spec.datetime_column.name, axis=1)
424+
.tail(self.spec.horizon)
425+
if self.spec.additional_data
426+
else None,
427+
forecast_timepoints=list(range(self.spec.horizon + 1)),
428+
)
429+
430+
# Convert the explanations to a DataFrame
431+
explanations_df = pd.concat(
432+
[exp.to_dataframe() for exp in explanations]
433+
)
434+
explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
435+
explanations_df = explanations_df.pivot(
436+
index="row", columns="Feature", values="Attribution"
437+
)
438+
explanations_df = explanations_df.reset_index(drop=True)
439+
440+
# Store the explanations in the local_explanation dictionary
441+
self.local_explanation[s_id] = explanations_df
442+
else:
443+
# Fall back to the default explanation generation method
444+
super().explain_model()

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 11 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -649,18 +649,19 @@ def _save_model(self, output_dir, storage_options):
649649
storage_options=storage_options,
650650
)
651651

652+
def _validate_automlx_explanation_mode(self):
653+
if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
654+
raise ValueError(
655+
"AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
656+
"Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
657+
)
658+
652659
@runtime_dependency(
653660
module="shap",
654661
err_msg=(
655662
"Please run `python3 -m pip install shap` to install the required dependencies for model explanation."
656663
),
657664
)
658-
@runtime_dependency(
659-
module="automlx",
660-
err_msg=(
661-
"Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
662-
),
663-
)
664665
def explain_model(self):
665666
"""
666667
Generates an explanation for the model by using the SHAP (Shapley Additive exPlanations) library.
@@ -683,53 +684,13 @@ def explain_model(self):
683684
)
684685
ratio = SpeedAccuracyMode.ratio[self.spec.explanations_accuracy_mode]
685686

687+
# validate the automlx mode is use for automlx model
688+
self._validate_automlx_explanation_mode()
689+
686690
for s_id, data_i in self.datasets.get_data_by_series(
687691
include_horizon=False
688692
).items():
689-
if (
690-
self.spec.model == SupportedModels.AutoMLX
691-
and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX
692-
):
693-
import automlx
694-
695-
explainer = automlx.MLExplainer(
696-
self.models[s_id],
697-
self.datasets.additional_data.get_data_for_series(series_id=s_id)
698-
.drop(self.spec.datetime_column.name, axis=1)
699-
.head(-self.spec.horizon)
700-
if self.spec.additional_data
701-
else None,
702-
pd.DataFrame(data_i[self.spec.target_column]),
703-
task="forecasting",
704-
)
705-
706-
explanations = explainer.explain_prediction(
707-
X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
708-
.drop(self.spec.datetime_column.name, axis=1)
709-
.tail(self.spec.horizon)
710-
if self.spec.additional_data
711-
else None,
712-
forecast_timepoints=list(range(self.spec.horizon + 1)),
713-
)
714-
715-
explanations_df = pd.concat(
716-
[exp.to_dataframe() for exp in explanations]
717-
)
718-
explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
719-
explanations_df = explanations_df.pivot(
720-
index="row", columns="Feature", values="Attribution"
721-
)
722-
explanations_df = explanations_df.reset_index(drop=True)
723-
self.local_explanation[s_id] = explanations_df
724-
elif (
725-
self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX
726-
and self.spec.model != SupportedModels.AutoMLX
727-
):
728-
raise ValueError(
729-
"AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
730-
"Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
731-
)
732-
elif s_id in self.models:
693+
if s_id in self.models:
733694
explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
734695
data_trimmed = data_i.tail(
735696
max(int(len(data_i) * ratio), 5)

0 commit comments

Comments
 (0)