Skip to content

Commit 0af84eb

Browse files
committed
combine columns to match old strucutre
1 parent 909f7ae commit 0af84eb

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,24 +122,41 @@ def _train_model(self, i, series_id, df, model_kwargs):
122122
upper_bound=self.get_horizon(forecast["yhat_upper"]).values,
123123
lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
124124
)
125-
# Get all features that make up the forecast. Exclude CI (upper/lower) and drop yhat ([:-1])
125+
# Get all features that make up the forecast. Exclude CI (upper/lower)
126126
core_columns = forecast.columns[
127127
~forecast.columns.str.endswith("_lower")
128128
& ~forecast.columns.str.endswith("_upper")
129-
][:-1]
130-
core_columns = set(core_columns) - set(
129+
]
130+
core_columns = set(core_columns) - {
131131
"additive_terms",
132132
"extra_regressors_additive",
133133
"multiplicative_terms",
134134
"extra_regressors_multiplicative",
135135
"cap",
136136
"floor",
137+
"yhat",
138+
}
139+
combine_terms = list(
140+
core_columns.intersection(
141+
{
142+
"trend",
143+
"daily",
144+
"weekly",
145+
"yearly",
146+
"monthly",
147+
"holidays",
148+
"zeros",
149+
}
150+
)
137151
)
138-
self.explanations_info[series_id] = (
152+
153+
temp_df = (
139154
forecast[list(core_columns)]
140155
.rename({"ds": "Date"}, axis=1)
141156
.set_index("Date")
142157
)
158+
temp_df[self.spec.target_column] = temp_df[combine_terms].sum(axis=1)
159+
self.explanations_info[series_id] = temp_df.drop(combine_terms, axis=1)
143160

144161
self.models[series_id] = {}
145162
self.models[series_id]["model"] = model

tests/operators/forecast/test_explainers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ def test_explanations_filenames(model, num_series):
221221
operator_config.spec.local_explanation_filename = local_explanation_filename
222222

223223
results = forecast_operate(operator_config)
224+
assert (
225+
not results.get_global_explanations().empty
226+
), "Error generating Global Expl"
227+
assert not results.get_local_explanations().empty, "Error generating Local Expl"
224228

225229
global_explanation_path = os.path.join(
226230
output_directory, global_explanation_filename

0 commit comments

Comments
 (0)