Skip to content

Commit 3bd0b3d

Browse files
authored
Merge branch 'main' into dataflow_changes
2 parents 7a3d109 + 3e3a2a0 commit 3bd0b3d

File tree

5 files changed

+121
-22
lines changed

5 files changed

+121
-22
lines changed

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
2727
HIGH_ACCURACY = "HIGH_ACCURACY"
2828
BALANCED = "BALANCED"
2929
FAST_APPROXIMATE = "FAST_APPROXIMATE"
30+
AUTOMLX = "AUTOMLX"
3031
ratio = {}
3132
ratio[HIGH_ACCURACY] = 1 # 100 % data used for generating explanations
3233
ratio[BALANCED] = 0.5 # 50 % data used for generating explanations
3334
ratio[FAST_APPROXIMATE] = 0 # constant
35+
ratio[AUTOMLX] = 0 # constant
3436

3537

3638
class SupportedMetrics(str, metaclass=ExtendedEnumMeta):

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ads.opctl.operator.lowcode.forecast.const import (
1818
AUTOMLX_METRIC_MAP,
1919
ForecastOutputColumns,
20+
SpeedAccuracyMode,
2021
SupportedModels,
2122
)
2223
from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
@@ -245,18 +246,18 @@ def _generate_report(self):
245246
# If the key is present, call the "explain_model" method
246247
self.explain_model()
247248

248-
# Convert the global explanation data to a DataFrame
249-
global_explanation_df = pd.DataFrame(self.global_explanation)
249+
global_explanation_section = None
250+
if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
251+
# Convert the global explanation data to a DataFrame
252+
global_explanation_df = pd.DataFrame(self.global_explanation)
250253

251-
self.formatted_global_explanation = (
252-
global_explanation_df / global_explanation_df.sum(axis=0) * 100
253-
)
254-
self.formatted_global_explanation = (
255-
self.formatted_global_explanation.rename(
254+
self.formatted_global_explanation = (
255+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
256+
)
257+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
256258
{self.spec.datetime_column.name: ForecastOutputColumns.DATE},
257259
axis=1,
258260
)
259-
)
260261

261262
aggregate_local_explanations = pd.DataFrame()
262263
for s_id, local_ex_df in self.local_explanation.items():
@@ -297,8 +298,11 @@ def _generate_report(self):
297298
)
298299

299300
# Append the global explanation text and section to the "other_sections" list
301+
if global_explanation_section:
302+
other_sections.append(global_explanation_section)
303+
304+
# Append the local explanation text and section to the "other_sections" list
300305
other_sections = other_sections + [
301-
global_explanation_section,
302306
local_explanation_section,
303307
]
304308
except Exception as e:
@@ -379,3 +383,79 @@ def _custom_predict_automlx(self, data):
379383
return self.models.get(self.series_id).forecast(
380384
X=data_temp, periods=data_temp.shape[0]
381385
)[self.series_id]
386+
387+
@runtime_dependency(
388+
module="automlx",
389+
err_msg=(
390+
"Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
391+
),
392+
)
393+
def explain_model(self):
394+
"""
395+
Generates explanations for the model using the AutoMLx library.
396+
397+
Parameters
398+
----------
399+
None
400+
401+
Returns
402+
-------
403+
None
404+
405+
Notes
406+
-----
407+
This function works by generating local explanations for each series in the dataset.
408+
It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions
409+
for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary.
410+
411+
If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations.
412+
Otherwise, it falls back to the default explanation generation method.
413+
"""
414+
import automlx
415+
416+
# Loop through each series in the dataset
417+
for s_id, data_i in self.datasets.get_data_by_series(
418+
include_horizon=False
419+
).items():
420+
try:
421+
if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
422+
# Use the MLExplainer class from AutoMLx to generate explanations
423+
explainer = automlx.MLExplainer(
424+
self.models[s_id],
425+
self.datasets.additional_data.get_data_for_series(series_id=s_id)
426+
.drop(self.spec.datetime_column.name, axis=1)
427+
.head(-self.spec.horizon)
428+
if self.spec.additional_data
429+
else None,
430+
pd.DataFrame(data_i[self.spec.target_column]),
431+
task="forecasting",
432+
)
433+
434+
# Generate explanations for the forecast
435+
explanations = explainer.explain_prediction(
436+
X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
437+
.drop(self.spec.datetime_column.name, axis=1)
438+
.tail(self.spec.horizon)
439+
if self.spec.additional_data
440+
else None,
441+
forecast_timepoints=list(range(self.spec.horizon + 1)),
442+
)
443+
444+
# Convert the explanations to a DataFrame
445+
explanations_df = pd.concat(
446+
[exp.to_dataframe() for exp in explanations]
447+
)
448+
explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
449+
explanations_df = explanations_df.pivot(
450+
index="row", columns="Feature", values="Attribution"
451+
)
452+
explanations_df = explanations_df.reset_index(drop=True)
453+
454+
# Store the explanations in the local_explanation dictionary
455+
self.local_explanation[s_id] = explanations_df
456+
else:
457+
# Fall back to the default explanation generation method
458+
super().explain_model()
459+
except Exception as e:
460+
logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
461+
logger.debug(f"Full Traceback: {traceback.format_exc()}")

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
SpeedAccuracyMode,
5050
SupportedMetrics,
5151
SupportedModels,
52+
BACKTEST_REPORT_NAME,
5253
)
5354
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
5455
from .forecast_datasets import ForecastDatasets
@@ -665,6 +666,13 @@ def _save_model(self, output_dir, storage_options):
665666
storage_options=storage_options,
666667
)
667668

669+
def _validate_automlx_explanation_mode(self):
670+
if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
671+
raise ValueError(
672+
"AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
673+
"Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
674+
)
675+
668676
@runtime_dependency(
669677
module="shap",
670678
err_msg=(
@@ -693,6 +701,9 @@ def explain_model(self):
693701
)
694702
ratio = SpeedAccuracyMode.ratio[self.spec.explanations_accuracy_mode]
695703

704+
# validate the automlx mode is use for automlx model
705+
self._validate_automlx_explanation_mode()
706+
696707
for s_id, data_i in self.datasets.get_data_by_series(
697708
include_horizon=False
698709
).items():
@@ -727,6 +738,14 @@ def explain_model(self):
727738
logger.warn(
728739
"No explanations generated. Ensure that additional data has been provided."
729740
)
741+
elif (
742+
self.spec.model == SupportedModels.AutoMLX
743+
and self.spec.explanations_accuracy_mode
744+
== SpeedAccuracyMode.AUTOMLX
745+
):
746+
logger.warning(
747+
"Global explanations not available for AutoMLX models with inherent explainability"
748+
)
730749
else:
731750
self.global_explanation[s_id] = dict(
732751
zip(

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ spec:
344344
- HIGH_ACCURACY
345345
- BALANCED
346346
- FAST_APPROXIMATE
347+
- AUTOMLX
347348

348349
generate_report:
349350
type: boolean

tests/operators/forecast/test_errors.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,8 @@ def test_all_series_failure(model):
591591
yaml_i["spec"]["preprocessing"] = {"enabled": True, "steps": preprocessing_steps}
592592
if yaml_i["spec"].get("additional_data") is not None and model != "autots":
593593
yaml_i["spec"]["generate_explanations"] = True
594+
else:
595+
yaml_i["spec"]["generate_explanations"] = False
594596
if model == "autots":
595597
yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
596598
if model == "automlx":
@@ -672,6 +674,7 @@ def test_arima_automlx_errors(operator_setup, model):
672674
yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
673675
if model == "automlx":
674676
yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}
677+
yaml_i["spec"]["explanations_accuracy_mode"] = "AUTOMLX"
675678

676679
run_yaml(
677680
tmpdirname=tmpdirname,
@@ -699,21 +702,15 @@ def test_arima_automlx_errors(operator_setup, model):
699702
in error_content["13"]["error"]
700703
), "Error message mismatch"
701704

702-
if model not in ["autots", "automlx"]: # , "lgbforecast"
703-
global_fn = f"{tmpdirname}/results/global_explanation.csv"
704-
assert os.path.exists(
705-
global_fn
706-
), f"Global explanation file not found at {report_path}"
705+
if model not in ["autots"]: # , "lgbforecast"
706+
if yaml_i["spec"].get("explanations_accuracy_mode") != "AUTOMLX":
707+
global_fn = f"{tmpdirname}/results/global_explanation.csv"
708+
assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
709+
assert not pd.read_csv(global_fn, index_col=0).empty
707710

708711
local_fn = f"{tmpdirname}/results/local_explanation.csv"
709-
assert os.path.exists(
710-
local_fn
711-
), f"Local explanation file not found at {report_path}"
712-
713-
glb_expl = pd.read_csv(global_fn, index_col=0)
714-
loc_expl = pd.read_csv(local_fn)
715-
assert not glb_expl.empty
716-
assert not loc_expl.empty
712+
assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
713+
assert not pd.read_csv(local_fn).empty
717714

718715

719716
def test_smape_error():

0 commit comments

Comments
 (0)