Commit 5b4d6dd

automlx global explainer
1 parent 14b4b6c commit 5b4d6dd

2 files changed: +42 −20 lines changed


ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 6 additions & 0 deletions
@@ -453,6 +453,12 @@ def explain_model(self):
 
                 # Store the explanations in the local_explanation dictionary
                 self.local_explanation[s_id] = explanations_df
+
+                self.global_explanation[s_id] = dict(zip(
+                    data_i.columns[1:],
+                    np.average(np.absolute(explanations_df[:, 1:]), axis=0),
+                    )
+                )
         else:
             # Fall back to the default explanation generation method
             super().explain_model()
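The added lines aggregate per-row local attributions into one global importance value per feature by taking the mean absolute attribution. A minimal standalone sketch of that computation, assuming `explanations_df` is a 2-D array whose first column is the timestamp/base column and whose remaining columns are per-feature attributions (the feature names and values below are hypothetical):

```python
import numpy as np

# Hypothetical local attributions for one series: rows are forecast timestamps,
# column 0 is the timestamp/base column, columns 1+ are per-feature attributions.
explanations = np.array(
    [
        [0.0,  0.8, -0.2,  0.1],
        [0.0, -0.5,  0.4,  0.0],
        [0.0,  0.6, -0.1,  0.3],
    ]
)
feature_names = ["timestamp", "temp", "promo", "holiday"]  # stand-in for data_i.columns

# Global importance per feature = mean of the absolute local attributions,
# mirroring dict(zip(data_i.columns[1:], np.average(np.absolute(...[:, 1:]), axis=0)))
global_explanation = dict(
    zip(feature_names[1:], np.average(np.absolute(explanations[:, 1:]), axis=0))
)
print(global_explanation)
# {'temp': 0.633..., 'promo': 0.233..., 'holiday': 0.133...}
```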

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 36 additions & 20 deletions
@@ -19,6 +19,7 @@
 from ads.common.decorator.runtime_dependency import runtime_dependency
 from ads.common.object_storage_details import ObjectStorageDetails
 from ads.opctl import logger
+from ads.opctl.operator.lowcode.common.const import DataColumns
 from ads.opctl.operator.lowcode.common.utils import (
     datetime_to_seconds,
     disable_print,
@@ -28,7 +29,6 @@
     seconds_to_datetime,
     write_data,
 )
-from ads.opctl.operator.lowcode.common.const import DataColumns
 from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
 from ads.opctl.operator.lowcode.forecast.utils import (
     _build_metrics_df,
@@ -49,7 +49,6 @@
     SpeedAccuracyMode,
     SupportedMetrics,
     SupportedModels,
-    BACKTEST_REPORT_NAME,
 )
 from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
 from .forecast_datasets import ForecastDatasets
@@ -127,8 +126,9 @@ def generate_report(self):
         if self.spec.generate_report or self.spec.generate_metrics:
             self.eval_metrics = self.generate_train_metrics()
             if not self.target_cat_col:
-                self.eval_metrics.rename({"Series 1": self.original_target_column},
-                                         axis=1, inplace=True)
+                self.eval_metrics.rename(
+                    {"Series 1": self.original_target_column}, axis=1, inplace=True
+                )
 
         if self.spec.test_data:
             try:
@@ -140,8 +140,11 @@ def generate_report(self):
                     elapsed_time=elapsed_time,
                 )
                 if not self.target_cat_col:
-                    self.test_eval_metrics.rename({"Series 1": self.original_target_column},
-                                                  axis=1, inplace=True)
+                    self.test_eval_metrics.rename(
+                        {"Series 1": self.original_target_column},
+                        axis=1,
+                        inplace=True,
+                    )
             except Exception:
                 logger.warn("Unable to generate Test Metrics.")
                 logger.debug(f"Full Traceback: {traceback.format_exc()}")
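Both rename calls in the two hunks above do the same thing: when no target category column is configured, the operator works with a single synthetic series labelled "Series 1", and the metrics column is relabelled with the real target column name before reporting. A small illustrative sketch with made-up metric values (the column and metric names are hypothetical):

```python
import pandas as pd

# Hypothetical train metrics indexed by metric name, one column per series.
eval_metrics = pd.DataFrame({"Series 1": [12.3, 0.08]}, index=["RMSE", "sMAPE"])

original_target_column = "sales"  # hypothetical original target column name
eval_metrics.rename({"Series 1": original_target_column}, axis=1, inplace=True)

print(eval_metrics.columns.tolist())  # ['sales']
```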
@@ -223,17 +226,23 @@ def generate_report(self):
                 rc.Block(
                     first_10_title,
                     # series_subtext,
-                    rc.Select(blocks=first_5_rows_blocks) if self.target_cat_col else first_5_rows_blocks[0],
+                    rc.Select(blocks=first_5_rows_blocks)
+                    if self.target_cat_col
+                    else first_5_rows_blocks[0],
                 ),
                 rc.Block(
                     last_10_title,
                     # series_subtext,
-                    rc.Select(blocks=last_5_rows_blocks) if self.target_cat_col else last_5_rows_blocks[0],
+                    rc.Select(blocks=last_5_rows_blocks)
+                    if self.target_cat_col
+                    else last_5_rows_blocks[0],
                 ),
                 rc.Block(
                     summary_title,
                     # series_subtext,
-                    rc.Select(blocks=data_summary_blocks) if self.target_cat_col else data_summary_blocks[0],
+                    rc.Select(blocks=data_summary_blocks)
+                    if self.target_cat_col
+                    else data_summary_blocks[0],
                 ),
                 rc.Separator(),
             )
@@ -308,7 +317,7 @@ def generate_report(self):
                 horizon=self.spec.horizon,
                 test_data=test_data,
                 ci_interval_width=self.spec.confidence_interval_width,
-                target_category_column=self.target_cat_col
+                target_category_column=self.target_cat_col,
             )
             if (
                 series_name is not None
@@ -491,7 +500,11 @@ def _save_report(
                 f2.write(f1.read())
 
         # forecast csv report
-        result_df = result_df if self.target_cat_col else result_df.drop(DataColumns.Series, axis=1)
+        result_df = (
+            result_df
+            if self.target_cat_col
+            else result_df.drop(DataColumns.Series, axis=1)
+        )
         write_data(
             data=result_df,
             filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
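In the same single-series case, the synthetic series column is dropped from the forecast frame before the CSV is written. A rough sketch of that branch, assuming `DataColumns.Series` resolves to a column named "Series" (the sample frame is hypothetical):

```python
import pandas as pd

SERIES_COL = "Series"  # assumed value of DataColumns.Series

result_df = pd.DataFrame(
    {
        "Date": ["2024-01-01", "2024-01-02"],
        SERIES_COL: ["Series 1", "Series 1"],
        "forecast_value": [10.2, 11.0],
    }
)

target_cat_col = None  # no user-provided target category column
result_df = result_df if target_cat_col else result_df.drop(SERIES_COL, axis=1)
print(result_df.columns.tolist())  # ['Date', 'forecast_value']
```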
@@ -667,7 +680,10 @@ def _save_model(self, output_dir, storage_options):
         )
 
     def _validate_automlx_explanation_mode(self):
-        if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
+        if (
+            self.spec.model != SupportedModels.AutoMLX
+            and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX
+        ):
             raise ValueError(
                 "AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
                 "Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
@@ -738,14 +754,14 @@ def explain_model(self):
                 logger.warn(
                     "No explanations generated. Ensure that additional data has been provided."
                 )
-            elif (
-                self.spec.model == SupportedModels.AutoMLX
-                and self.spec.explanations_accuracy_mode
-                == SpeedAccuracyMode.AUTOMLX
-            ):
-                logger.warning(
-                    "Global explanations not available for AutoMLX models with inherent explainability"
-                )
+            # elif (
+            #     self.spec.model == SupportedModels.AutoMLX
+            #     and self.spec.explanations_accuracy_mode
+            #     == SpeedAccuracyMode.AUTOMLX
+            # ):
+            #     logger.warning(
+            #         "Global explanations not available for AutoMLX models with inherent explainability"
+            #     )
             else:
                 self.global_explanation[s_id] = dict(
                     zip(
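With the elif guard commented out, AutoMLX models running in AUTOMLX accuracy mode no longer short-circuit with a "not available" warning: the AutoMLX subclass now fills in global_explanation itself (see the automlx.py hunk above), while other models continue through the else branch shown here. Downstream, the per-series dictionaries can be tabulated; a hedged sketch of one possible way to do that, with hypothetical series ids, feature names, and values:

```python
import pandas as pd

# Hypothetical per-series global explanations, as produced by explain_model().
global_explanation = {
    "store_1": {"temp": 0.63, "promo": 0.23, "holiday": 0.13},
    "store_2": {"temp": 0.41, "promo": 0.35, "holiday": 0.02},
}

# One column per series, one row per feature.
global_df = pd.DataFrame(global_explanation)
print(global_df)
```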
