Skip to content

Commit 2d34167

Browse files
committed
skip check for global explanations, exception handling
1 parent fe011f7 commit 2d34167

File tree

2 files changed

+48
-48
lines changed

2 files changed

+48
-48
lines changed

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -413,41 +413,45 @@ def explain_model(self):
413413
for s_id, data_i in self.datasets.get_data_by_series(
414414
include_horizon=False
415415
).items():
416-
if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
417-
# Use the MLExplainer class from AutoMLx to generate explanations
418-
explainer = automlx.MLExplainer(
419-
self.models[s_id],
420-
self.datasets.additional_data.get_data_for_series(series_id=s_id)
421-
.drop(self.spec.datetime_column.name, axis=1)
422-
.head(-self.spec.horizon)
423-
if self.spec.additional_data
424-
else None,
425-
pd.DataFrame(data_i[self.spec.target_column]),
426-
task="forecasting",
427-
)
416+
try:
417+
if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
418+
# Use the MLExplainer class from AutoMLx to generate explanations
419+
explainer = automlx.MLExplainer(
420+
self.models[s_id],
421+
self.datasets.additional_data.get_data_for_series(series_id=s_id)
422+
.drop(self.spec.datetime_column.name, axis=1)
423+
.head(-self.spec.horizon)
424+
if self.spec.additional_data
425+
else None,
426+
pd.DataFrame(data_i[self.spec.target_column]),
427+
task="forecasting",
428+
)
428429

429-
# Generate explanations for the forecast
430-
explanations = explainer.explain_prediction(
431-
X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
432-
.drop(self.spec.datetime_column.name, axis=1)
433-
.tail(self.spec.horizon)
434-
if self.spec.additional_data
435-
else None,
436-
forecast_timepoints=list(range(self.spec.horizon + 1)),
437-
)
430+
# Generate explanations for the forecast
431+
explanations = explainer.explain_prediction(
432+
X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
433+
.drop(self.spec.datetime_column.name, axis=1)
434+
.tail(self.spec.horizon)
435+
if self.spec.additional_data
436+
else None,
437+
forecast_timepoints=list(range(self.spec.horizon + 1)),
438+
)
438439

439-
# Convert the explanations to a DataFrame
440-
explanations_df = pd.concat(
441-
[exp.to_dataframe() for exp in explanations]
442-
)
443-
explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
444-
explanations_df = explanations_df.pivot(
445-
index="row", columns="Feature", values="Attribution"
446-
)
447-
explanations_df = explanations_df.reset_index(drop=True)
440+
# Convert the explanations to a DataFrame
441+
explanations_df = pd.concat(
442+
[exp.to_dataframe() for exp in explanations]
443+
)
444+
explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
445+
explanations_df = explanations_df.pivot(
446+
index="row", columns="Feature", values="Attribution"
447+
)
448+
explanations_df = explanations_df.reset_index(drop=True)
448449

449-
# Store the explanations in the local_explanation dictionary
450-
self.local_explanation[s_id] = explanations_df
451-
else:
452-
# Fall back to the default explanation generation method
453-
super().explain_model()
450+
# Store the explanations in the local_explanation dictionary
451+
self.local_explanation[s_id] = explanations_df
452+
else:
453+
# Fall back to the default explanation generation method
454+
super().explain_model()
455+
except Exception as e:
456+
logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
457+
logger.debug(f"Full Traceback: {traceback.format_exc()}")

tests/operators/forecast/test_errors.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,8 @@ def test_all_series_failure(model):
591591
yaml_i["spec"]["preprocessing"] = {"enabled": True, "steps": preprocessing_steps}
592592
if yaml_i["spec"].get("additional_data") is not None and model != "autots":
593593
yaml_i["spec"]["generate_explanations"] = True
594+
else:
595+
yaml_i["spec"]["generate_explanations"] = False
594596
if model == "autots":
595597
yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
596598
if model == "automlx":
@@ -700,21 +702,15 @@ def test_arima_automlx_errors(operator_setup, model):
700702
in error_content["13"]["error"]
701703
), "Error message mismatch"
702704

703-
if model not in ["autots", "automlx"]: # , "lgbforecast"
704-
global_fn = f"{tmpdirname}/results/global_explanation.csv"
705-
assert os.path.exists(
706-
global_fn
707-
), f"Global explanation file not found at {report_path}"
705+
if model not in ["autots"]: # , "lgbforecast"
706+
if yaml_i["spec"].get("explanations_accuracy_mode") != "AUTOMLX":
707+
global_fn = f"{tmpdirname}/results/global_explanation.csv"
708+
assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
709+
assert not pd.read_csv(global_fn, index_col=0).empty
708710

709711
local_fn = f"{tmpdirname}/results/local_explanation.csv"
710-
assert os.path.exists(
711-
local_fn
712-
), f"Local explanation file not found at {report_path}"
713-
714-
glb_expl = pd.read_csv(global_fn, index_col=0)
715-
loc_expl = pd.read_csv(local_fn)
716-
assert not glb_expl.empty
717-
assert not loc_expl.empty
712+
assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
713+
assert not pd.read_csv(local_fn).empty
718714

719715

720716
def test_smape_error():

0 commit comments

Comments
 (0)