Skip to content

Commit e24dbf7

Browse files
committed
Fix errors in explanation generation
1 parent bea07dd commit e24dbf7

File tree

4 files changed

+39
-37
lines changed

4 files changed

+39
-37
lines changed

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ def _custom_predict(
245245
"""
246246
data: ForecastDatasets.get_data_at_series(s_id)
247247
"""
248+
if series_id in self.constant_cols:
249+
data = data.drop(columns=self.constant_cols[series_id])
250+
248251
data = data.drop([target_col], axis=1)
249252
data[dt_column_name] = seconds_to_datetime(
250253
data[dt_column_name], dt_format=self.spec.datetime_column.format

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
seconds_to_datetime,
2323
datetime_to_seconds,
2424
)
25+
from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
2526

2627
AUTOMLX_N_ALGOS_TUNED = 4
2728
AUTOMLX_DEFAULT_SCORE_METRIC = "neg_sym_mean_abs_percent_error"
@@ -51,8 +52,13 @@ def set_kwargs(self):
5152
] = self.spec.preprocessing or model_kwargs_cleaned.get("preprocessing", True)
5253
return model_kwargs_cleaned, time_budget
5354

54-
def preprocess(self, data, series_id=None):
55-
return data.set_index(self.spec.datetime_column.name)
55+
56+
def preprocess(self, data, series_id=None): # TODO: re-use self.le for explanations
57+
_, df_encoded = _label_encode_dataframe(
58+
data,
59+
no_encode={self.spec.datetime_column.name, self.original_target_column},
60+
)
61+
return df_encoded.set_index(self.spec.datetime_column.name)
5662

5763
@runtime_dependency(
5864
module="automlx",

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
8888
self.formatted_local_explanation = None
8989

9090
self.forecast_col_name = "yhat"
91-
self.perform_tuning = self.spec.tuning != None
91+
self.perform_tuning = (self.spec.tuning != None) and (self.spec.tuning.n_trials != None)
9292

9393
def generate_report(self):
9494
"""Generates the forecasting report."""
@@ -657,20 +657,18 @@ def explain_model(self):
657657
if s_id in self.models:
658658

659659
explain_predict_fn = self.get_explain_predict_fn(series_id=s_id)
660-
if self.spec.model == SupportedModels.Arima and s_id in self.constant_cols:
661-
data_i = data_i.drop(columns=self.constant_cols[s_id])
662660
data_trimmed = data_i.tail(max(int(len(data_i) * ratio), 5)).reset_index(
663661
drop=True
664662
)
665663
data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(
666664
lambda x: x.timestamp()
667665
)
668666

669-
# Explainer fails when boolean columns are passed for arima
670-
if self.spec.model == SupportedModels.Arima:
671-
_, data_trimmed_encoded = _label_encode_dataframe(
672-
data_trimmed, no_encode={datetime_col_name, self.original_target_column}
673-
)
667+
# Explainer fails when boolean columns are passed
668+
669+
_, data_trimmed_encoded = _label_encode_dataframe(
670+
data_trimmed, no_encode={datetime_col_name, self.original_target_column}
671+
)
674672

675673
kernel_explnr = PermutationExplainer(
676674
model=explain_predict_fn, masker=data_trimmed_encoded
@@ -716,16 +714,13 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
716714
"""
717715
data = self.datasets.get_horizon_at_series(s_id=series_id)
718716
# columns that were dropped in train_model in arima, should be dropped here as well
719-
if self.spec.model == SupportedModels.Arima and series_id in self.constant_cols:
720-
data = data.drop(columns=self.constant_cols[series_id])
721717
data[datetime_col_name] = datetime_to_seconds(data[datetime_col_name])
722718
data = data.reset_index(drop=True)
723719

724-
# Explainer fails when boolean columns are passed for arima
725-
if self.spec.model == SupportedModels.Arima:
726-
_, data = _label_encode_dataframe(
727-
data, no_encode={datetime_col_name, self.original_target_column}
728-
)
720+
# Explainer fails when boolean columns are passed
721+
_, data = _label_encode_dataframe(
722+
data, no_encode={datetime_col_name, self.original_target_column}
723+
)
729724
# Generate local SHAP values using the kernel explainer
730725
local_kernel_explnr_vals = kernel_explainer.shap_values(data)
731726

tests/operators/forecast/test_errors.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def test_all_series_failure(model):
524524
module_to_patch = {
525525
"arima": 'pmdarima.auto_arima',
526526
"autots": 'autots.AutoTS',
527-
"automlx": 'automl.Pipeline',
527+
"automlx": 'automlx.Pipeline',
528528
"prophet": 'prophet.Prophet',
529529
"neuralprophet": 'neuralprophet.NeuralProphet'
530530
}
@@ -551,7 +551,7 @@ def test_all_series_failure(model):
551551
local_fn = f"{tmpdirname}/results/local_explanation.csv"
552552
assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
553553

554-
@pytest.mark.parametrize("model", ["arima", "automlx"])
554+
@pytest.mark.parametrize("model", MODELS)
555555
def test_arima_automlx_errors(operator_setup, model):
556556
tmpdirname = operator_setup
557557
historical_data_path, additional_data_path = setup_faulty_rossman()
@@ -572,14 +572,15 @@ def test_arima_automlx_errors(operator_setup, model):
572572
outputs get generated and that error is shown in errors.json
573573
"""
574574

575+
"""
576+
Explanation generation fails when boolean columns are passed, so we added label encoding before passing data to the
577+
explainer
578+
"""
579+
575580
yaml_i['spec']['horizon'] = 10
576581
yaml_i['spec']['preprocessing'] = True
577582
yaml_i['spec']['generate_explanations'] = True
578583
yaml_i['spec']['model'] = model
579-
if model == "automlx":
580-
yaml_i['spec']['model_kwargs'] = {
581-
'model_list': ['ProphetForecaster']
582-
}
583584

584585
run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path, test_metrics_check=False)
585586

@@ -594,23 +595,24 @@ def test_arima_automlx_errors(operator_setup, model):
594595
error_path = f"{tmpdirname}/results/errors.json"
595596
if model == "arima":
596597
assert not os.path.exists(error_path), f"Error file not found at {error_path}"
597-
else:
598+
elif model == "automlx":
598599
assert os.path.exists(error_path), f"Error file not found at {error_path}"
599600
with open(error_path, 'r') as error_file:
600601
error_content = json.load(error_file)
601602
assert "Input data does not have a consistent (in terms of diff) DatetimeIndex." in error_content["13"][
602603
"error"], "Error message mismatch"
603604

604-
global_fn = f"{tmpdirname}/results/global_explanation.csv"
605-
assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
605+
if model != "autots":
606+
global_fn = f"{tmpdirname}/results/global_explanation.csv"
607+
assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
606608

607-
local_fn = f"{tmpdirname}/results/local_explanation.csv"
608-
assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
609+
local_fn = f"{tmpdirname}/results/local_explanation.csv"
610+
assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
609611

610-
glb_expl = pd.read_csv(global_fn, index_col=0)
611-
loc_expl = pd.read_csv(local_fn)
612-
assert not glb_expl.empty
613-
assert not loc_expl.empty
612+
glb_expl = pd.read_csv(global_fn, index_col=0)
613+
loc_expl = pd.read_csv(local_fn)
614+
assert not glb_expl.empty
615+
assert not loc_expl.empty
614616

615617

616618
def test_smape_error():
@@ -631,11 +633,7 @@ def test_date_format(operator_setup, model):
631633
yaml_i["spec"]["model"] = model
632634
if model == "autots":
633635
yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
634-
if model == "automlx":
635-
yaml_i['spec']['model_kwargs'] = {
636-
'model_list': ['ProphetForecaster'],
637-
"time_budget": 1
638-
}
636+
639637
run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path, test_metrics_check=False)
640638
assert pd.read_csv(additional_data_path)['Date'].equals(pd.read_csv(f"{tmpdirname}/results/forecast.csv")['Date'])
641639

0 commit comments

Comments
 (0)