Skip to content

Commit ebf73e7

Browse files
committed
polish NP code
1 parent 3720e7d commit ebf73e7

File tree

2 files changed

+31
-25
lines changed

2 files changed

+31
-25
lines changed

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -149,36 +149,42 @@ def _train_model(self, i, s_id, df, model_kwargs):
149149
logger.debug(f"-----------------Model {i}----------------------")
150150
logger.debug(forecast.tail())
151151

152-
# TODO; could also extract trend and seasonality?
153-
cols_to_read = set(
154-
forecast.columns[forecast.columns.str.startswith("future_regressor")]
155-
+ ["ds", "trend"]
156-
)
157-
cols_to_read = cols_to_read - {
158-
"future_regressors_additive",
159-
"future_regressors_multiplicative",
160-
}
161-
combine_terms = cols_to_read - set(self.accepted_regressors[s_id])
162-
temp_df = (
163-
forecast[list(cols_to_read)]
164-
.rename({"ds": "Date"}, axis=1)
165-
.set_index("Date")
166-
)
167-
temp_df[self.spec.target_column] = temp_df[combine_terms].sum(axis=1)
168-
self.explanations_info[s_id] = temp_df.drop(combine_terms, axis=1)
169-
170152
self.outputs[s_id] = forecast
153+
upper_bound_col_name = f"yhat1 {model_kwargs['quantiles'][1]*100}%"
154+
lower_bound_col_name = f"yhat1 {model_kwargs['quantiles'][0]*100}%"
171155
self.forecast_output.populate_series_output(
172156
series_id=s_id,
173157
fit_val=self.drop_horizon(forecast["yhat1"]).values,
174158
forecast_val=self.get_horizon(forecast["yhat1"]).values,
175-
upper_bound=self.get_horizon(
176-
forecast[f"yhat1 {model_kwargs['quantiles'][1]*100}%"]
177-
).values,
178-
lower_bound=self.get_horizon(
179-
forecast[f"yhat1 {model_kwargs['quantiles'][0]*100}%"]
180-
).values,
159+
upper_bound=self.get_horizon(forecast[upper_bound_col_name]).values,
160+
lower_bound=self.get_horizon(forecast[lower_bound_col_name]).values,
181161
)
162+
core_columns = set(forecast.columns) - set(
163+
[
164+
"y",
165+
"yhat1",
166+
upper_bound_col_name,
167+
lower_bound_col_name,
168+
"future_regressors_additive",
169+
"future_regressors_multiplicative",
170+
]
171+
)
172+
exog_variables = set(
173+
filter(lambda x: x.startswith("future_regressor_"), list(core_columns))
174+
)
175+
combine_terms = list(core_columns - exog_variables - set(["ds"]))
176+
temp_df = (
177+
forecast[list(core_columns)]
178+
.rename({"ds": "Date"}, axis=1)
179+
.set_index("Date")
180+
)
181+
if combine_terms:
182+
temp_df[self.spec.target_column] = temp_df[combine_terms].sum(axis=1)
183+
temp_df = temp_df.drop(combine_terms, axis=1)
184+
else:
185+
temp_df[self.spec.target_column] = 0
186+
# Todo: check for columns that were dropped, and set them to 0
187+
self.explanations_info[s_id] = temp_df
182188

183189
self.trainers[s_id] = model.trainer
184190
self.models[s_id] = {}

tests/operators/forecast/test_explainers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def test_explanations_output_and_columns(model, freq, num_series):
197197
), f"Column {column} missing in local explanations"
198198

199199

200-
@pytest.mark.parametrize("model", MODELS)
200+
@pytest.mark.parametrize("model", MODELS) # MODELS
201201
@pytest.mark.parametrize("num_series", [1])
202202
def test_explanations_filenames(model, num_series):
203203
"""

0 commit comments

Comments
 (0)