Skip to content

Commit 8a942c3

Browse files
committed
remove extra cols from explainability
1 parent 48b27b3 commit 8a942c3

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,18 @@ def _train_model(self, i, series_id, df, model_kwargs):
127127
~forecast.columns.str.endswith("_lower")
128128
& ~forecast.columns.str.endswith("_upper")
129129
][:-1]
130+
core_columns = set(core_columns) - set(
131+
"additive_terms",
132+
"extra_regressors_additive",
133+
"multiplicative_terms",
134+
"extra_regressors_multiplicative",
135+
"cap",
136+
"floor",
137+
)
130138
self.explanations_info[series_id] = (
131-
forecast[core_columns].rename({"ds": "Date"}, axis=1).set_index("Date")
139+
forecast[list(core_columns)]
140+
.rename({"ds": "Date"}, axis=1)
141+
.set_index("Date")
132142
)
133143

134144
self.models[series_id] = {}

tests/operators/forecast/test_errors.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -969,9 +969,11 @@ def test_generate_files(operator_setup, model):
969969
yaml_i["spec"]["additional_data"]["data"] = df_add
970970
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
971971
results = operate(operator_config)
972-
files = os.listdir(operator_setup)
972+
files = os.listdir(yaml_i["spec"]["output_directory"]["url"])
973973
if "errors.json" in files:
974-
with open(os.path.join(operator_setup, "errors.json")) as f:
974+
with open(
975+
os.path.join(yaml_i["spec"]["output_directory"]["url"], "errors.json")
976+
) as f:
975977
assert False, f"Failed due to errors.json being created: {f.read()}"
976978
assert "report.html" in files, "Failed to generate report"
977979
assert (
@@ -996,9 +998,11 @@ def test_generate_files(operator_setup, model):
996998
yaml_i["spec"].pop("generate_metrics_file")
997999
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
9981000
results = operate(operator_config)
999-
files = os.listdir(operator_setup)
1001+
files = os.listdir(yaml_i["spec"]["output_directory"]["url"])
10001002
if "errors.json" in files:
1001-
with open(os.path.join(operator_setup, "errors.json")) as f:
1003+
with open(
1004+
os.path.join(yaml_i["spec"]["output_directory"]["url"], "errors.json")
1005+
) as f:
10021006
print(f"Errors in build! {f.read()}")
10031007
assert False, "Failed due to errors.json being created"
10041008
assert "report.html" in files, "Failed to generate report"

0 commit comments

Comments
 (0)