Skip to content

Commit 3720e7d

Browse files
committed
update NP too
1 parent 0af84eb commit 3720e7d

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,22 @@ def _train_model(self, i, s_id, df, model_kwargs):
150150
logger.debug(forecast.tail())
151151

152152
# TODO; could also extract trend and seasonality?
153-
cols_to_read = filter(
154-
lambda x: x.startswith("future_regressor"), forecast.columns
153+
cols_to_read = set(
154+
forecast.columns[forecast.columns.str.startswith("future_regressor")]
155+
+ ["ds", "trend"]
155156
)
156-
self.explanations_info[s_id] = (
157-
forecast[cols_to_read].rename({"ds": "Date"}, axis=1).set_index("Date")
157+
cols_to_read = cols_to_read - {
158+
"future_regressors_additive",
159+
"future_regressors_multiplicative",
160+
}
161+
combine_terms = cols_to_read - set(self.accepted_regressors[s_id])
162+
temp_df = (
163+
forecast[list(cols_to_read)]
164+
.rename({"ds": "Date"}, axis=1)
165+
.set_index("Date")
158166
)
167+
temp_df[self.spec.target_column] = temp_df[combine_terms].sum(axis=1)
168+
self.explanations_info[s_id] = temp_df.drop(combine_terms, axis=1)
159169

160170
self.outputs[s_id] = forecast
161171
self.forecast_output.populate_series_output(
@@ -457,19 +467,14 @@ def explain_model(self):
457467
for s_id, expl_df in self.explanations_info.items():
458468
expl_df = expl_df.rename(rename_cols, axis=1)
459469
# Local Expl
460-
self.local_explanation[s_id] = self.get_horizon(expl_df).drop(
461-
["future_regressors_additive"], axis=1
462-
)
470+
self.local_explanation[s_id] = self.get_horizon(expl_df)
463471
self.local_explanation[s_id]["Series"] = s_id
464472
self.local_explanation[s_id].index.rename(self.dt_column_name, inplace=True)
465473
# Global Expl
466474
g_expl = self.drop_horizon(expl_df).mean()
467475
g_expl.name = s_id
468476
global_expl.append(g_expl)
469477
self.global_explanation = pd.concat(global_expl, axis=1)
470-
self.global_explanation = self.global_explanation.drop(
471-
index=["future_regressors_additive"], axis=0
472-
)
473478
self.formatted_global_explanation = (
474479
self.global_explanation / self.global_explanation.sum(axis=0) * 100
475480
)

0 commit comments

Comments
 (0)