Skip to content

Commit 1bf2bfe

Browse files
committed
add exception for other models in automlx mode and enable test for automlx
1 parent 6a76574 commit 1bf2bfe

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -720,9 +720,15 @@ def explain_model(self):
720720
index="row", columns="Feature", values="Attribution"
721721
)
722722
explanations_df = explanations_df.reset_index(drop=True)
723-
# explanations_df[self.spec.datetime_column.name]=self.datasets.additional_data.get_data_for_series(series_id=s_id).tail(self.spec.horizon)[self.spec.datetime_column.name].reset_index(drop=True)
724723
self.local_explanation[s_id] = explanations_df
725-
724+
elif (
725+
self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX
726+
and self.spec.model != SupportedModels.AutoMLX
727+
):
728+
raise ValueError(
729+
"AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
730+
"Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
731+
)
726732
elif s_id in self.models:
727733
explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
728734
data_trimmed = data_i.tail(

tests/operators/forecast/test_errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,7 @@ def test_arima_automlx_errors(operator_setup, model):
659659
yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
660660
if model == "automlx":
661661
yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}
662+
yaml_i["spec"]["explanations_accuracy_mode"] = "AUTOMLX"
662663

663664
run_yaml(
664665
tmpdirname=tmpdirname,

0 commit comments

Comments
 (0)