Skip to content

Commit 89a1f9c

Browse files
authored
Merge branch 'main' into ahosler-patch-1
2 parents 24d6fbe + 1e99804 commit 89a1f9c

File tree

7 files changed

+56
-45
lines changed

7 files changed

+56
-45
lines changed

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,10 @@ def _train_model(self, i, s_id, df, model_kwargs):
116116
lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
117117
)
118118

119-
self.models[s_id] = model
119+
self.models[s_id] = {}
120+
self.models[s_id]["model"] = model
121+
self.models[s_id]["le"] = self.le[s_id]
122+
self.models[s_id]["predict_component_cols"] = X_pred.columns
120123

121124
params = vars(model).copy()
122125
for param in ["arima_res_", "endog_index_"]:
@@ -163,7 +166,7 @@ def _generate_report(self):
163166
sec5_text = rc.Heading("ARIMA Model Parameters", level=2)
164167
blocks = [
165168
rc.Html(
166-
m.summary().as_html(),
169+
m['model'].summary().as_html(),
167170
label=s_id if self.target_cat_col else None,
168171
)
169172
for i, (s_id, m) in enumerate(self.models.items())
@@ -251,7 +254,7 @@ def _generate_report(self):
251254
def get_explain_predict_fn(self, series_id):
252255
def _custom_predict(
253256
data,
254-
model=self.models[series_id],
257+
model=self.models[series_id]["model"],
255258
dt_column_name=self.datasets._datetime_column_name,
256259
target_col=self.original_target_column,
257260
):

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def set_kwargs(self):
5656
)
5757
return model_kwargs_cleaned, time_budget
5858

59-
def preprocess(self, data): # TODO: re-use self.le for explanations
60-
_, df_encoded = _label_encode_dataframe(
59+
def preprocess(self, data, series_id): # TODO: re-use self.le for explanations
60+
self.le[series_id], df_encoded = _label_encode_dataframe(
6161
data,
6262
no_encode={self.spec.datetime_column.name, self.original_target_column},
6363
)
@@ -124,7 +124,7 @@ def _build_model(self) -> pd.DataFrame:
124124
self.forecast_output.init_series_output(
125125
series_id=s_id, data_at_series=df
126126
)
127-
data = self.preprocess(df)
127+
data = self.preprocess(df, s_id)
128128
data_i = self.drop_horizon(data)
129129
X_pred = self.get_horizon(data).drop(target, axis=1)
130130

@@ -156,7 +156,9 @@ def _build_model(self) -> pd.DataFrame:
156156
target
157157
].values
158158

159-
self.models[s_id] = model
159+
self.models[s_id] = {}
160+
self.models[s_id]["model"] = model
161+
self.models[s_id]["le"] = self.le[s_id]
160162

161163
# In case of Naive model, model.forecast function call does not return confidence intervals.
162164
if f"{target}_ci_upper" not in summary_frame:
@@ -217,7 +219,8 @@ def _generate_report(self):
217219
other_sections = []
218220

219221
if len(self.models) > 0:
220-
for s_id, m in models.items():
222+
for s_id, artifacts in models.items():
223+
m = artifacts["model"]
221224
selected_models[s_id] = {
222225
"series_id": s_id,
223226
"selected_model": m.selected_model_,
@@ -323,7 +326,7 @@ def _generate_report(self):
323326
)
324327

325328
def get_explain_predict_fn(self, series_id):
326-
selected_model = self.models[series_id]
329+
selected_model = self.models[series_id]["model"]
327330

328331
# If training date, use method below. If future date, use forecast!
329332
def _custom_predict_fn(
@@ -341,12 +344,12 @@ def _custom_predict_fn(
341344
data[dt_column_name] = seconds_to_datetime(
342345
data[dt_column_name], dt_format=self.spec.datetime_column.format
343346
)
344-
data = self.preprocess(data)
347+
data = self.preprocess(data, series_id)
345348
horizon_data = horizon_data.drop(target_col, axis=1)
346349
horizon_data[dt_column_name] = seconds_to_datetime(
347350
horizon_data[dt_column_name], dt_format=self.spec.datetime_column.format
348351
)
349-
horizon_data = self.preprocess(horizon_data)
352+
horizon_data = self.preprocess(horizon_data, series_id)
350353

351354
rows = []
352355
for i in range(data.shape[0]):
@@ -424,10 +427,8 @@ def explain_model(self):
424427
if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
425428
# Use the MLExplainer class from AutoMLx to generate explanations
426429
explainer = automlx.MLExplainer(
427-
self.models[s_id],
428-
self.datasets.additional_data.get_data_for_series(
429-
series_id=s_id
430-
)
430+
self.models[s_id]["model"],
431+
self.datasets.additional_data.get_data_for_series(series_id=s_id)
431432
.drop(self.spec.datetime_column.name, axis=1)
432433
.head(-self.spec.horizon)
433434
if self.spec.additional_data

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def local_explainer(self, kernel_explainer, series_id, datetime_col_name) -> Non
828828
def get_explain_predict_fn(self, series_id, fcst_col_name="yhat"):
829829
def _custom_predict(
830830
data,
831-
model=self.models[series_id],
831+
model=self.models[series_id]["model"],
832832
dt_column_name=self.datasets._datetime_column_name,
833833
):
834834
"""

ads/opctl/operator/lowcode/forecast/model/neuralprophet.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,10 @@ def _train_model(self, i, s_id, df, model_kwargs):
172172
).values,
173173
)
174174

175-
self.models[s_id] = model
176175
self.trainers[s_id] = model.trainer
176+
self.models[s_id] = {}
177+
self.models[s_id]["model"] = model
178+
self.models[s_id]["le"] = self.le[s_id]
177179

178180
self.model_parameters[s_id] = {
179181
"framework": SupportedModels.NeuralProphet,
@@ -355,7 +357,8 @@ def _generate_report(self):
355357

356358
sec5_text = rc.Heading("Neural Prophet Model Parameters", level=2)
357359
model_states = []
358-
for s_id, m in self.models.items():
360+
for s_id, artifacts in self.models.items():
361+
m = artifacts["model"]
359362
model_states.append(
360363
pd.Series(
361364
m.state_dict(),

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,10 @@ def _train_model(self, i, series_id, df, model_kwargs):
112112
upper_bound=self.get_horizon(forecast["yhat_upper"]).values,
113113
lower_bound=self.get_horizon(forecast["yhat_lower"]).values,
114114
)
115-
self.models[series_id] = model
115+
116+
self.models[series_id] = {}
117+
self.models[series_id]["model"] = model
118+
self.models[series_id]["le"] = self.le[series_id]
116119

117120
params = vars(model).copy()
118121
for param in ["history", "history_dates", "stan_fit"]:
@@ -256,7 +259,7 @@ def _generate_report(self):
256259
all_sections = []
257260
if len(series_ids) > 0:
258261
sec1 = _select_plot_list(
259-
lambda s_id: self.models[s_id].plot(
262+
lambda s_id: self.models[s_id]["model"].plot(
260263
self.outputs[s_id], include_legend=True
261264
),
262265
series_ids=series_ids,
@@ -271,7 +274,7 @@ def _generate_report(self):
271274
)
272275

273276
sec2 = _select_plot_list(
274-
lambda s_id: self.models[s_id].plot_components(self.outputs[s_id]),
277+
lambda s_id: self.models[s_id]["model"].plot_components(self.outputs[s_id]),
275278
series_ids=series_ids,
276279
target_category_column=self.target_cat_col,
277280
)
@@ -280,11 +283,11 @@ def _generate_report(self):
280283
)
281284

282285
sec3_figs = {
283-
s_id: self.models[s_id].plot(self.outputs[s_id]) for s_id in series_ids
286+
s_id: self.models[s_id]["model"].plot(self.outputs[s_id]) for s_id in series_ids
284287
}
285288
for s_id in series_ids:
286289
add_changepoints_to_plot(
287-
sec3_figs[s_id].gca(), self.models[s_id], self.outputs[s_id]
290+
sec3_figs[s_id].gca(), self.models[s_id]["model"], self.outputs[s_id]
288291
)
289292
sec3 = _select_plot_list(
290293
lambda s_id: sec3_figs[s_id],
@@ -298,7 +301,7 @@ def _generate_report(self):
298301
sec5_text = rc.Heading("Prophet Model Seasonality Components", level=2)
299302
model_states = []
300303
for s_id in series_ids:
301-
m = self.models[s_id]
304+
m = self.models[s_id]["model"]
302305
model_states.append(
303306
pd.Series(
304307
m.seasonalities,

ads/opctl/operator/lowcode/forecast/whatifserve/score.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,34 +151,42 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
151151
pred_obj = model_object.predict(future_regressor=future_reg)
152152
return pred_obj.forecast[series_id].tolist()
153153
elif model_name == SupportedModels.Prophet and series_id in model_object:
154-
model = model_object[series_id]
154+
model = model_object[series_id]['model']
155+
label_encoder = model_object[series_id]['le']
155156
processed = future_df.rename(columns={date_col_name: 'ds', target_column: 'y'})
156-
forecast = model.predict(processed)
157+
encoded_df = label_encoder.transform(processed)
158+
forecast = model.predict(encoded_df)
157159
return forecast['yhat'].tolist()
158160
elif model_name == SupportedModels.NeuralProphet and series_id in model_object:
159-
model = model_object[series_id]
161+
model = model_object[series_id]['model']
162+
label_encoder = model_object[series_id]['le']
160163
model.restore_trainer()
161164
accepted_regressors = list(model.config_regressors.regressors.keys())
162165
data = future_df.rename(columns={date_col_name: 'ds', target_column: 'y'})
163-
future = data[accepted_regressors + ["ds"]].reset_index(drop=True)
166+
encoded_df = label_encoder.transform(data)
167+
future = encoded_df[accepted_regressors + ["ds"]].reset_index(drop=True)
164168
future["y"] = None
165169
forecast = model.predict(future)
166170
return forecast['yhat1'].tolist()
167171
elif model_name == SupportedModels.Arima and series_id in model_object:
168-
model = model_object[series_id]
169-
future_df = future_df.set_index(date_col_name)
170-
x_pred = future_df.drop(target_cat_col, axis=1)
172+
model = model_object[series_id]['model']
173+
label_encoder = model_object[series_id]['le']
174+
predict_cols = model_object[series_id]["predict_component_cols"]
175+
encoded_df = label_encoder.transform(future_df)
176+
x_pred = encoded_df.set_index(date_col_name)
177+
x_pred = x_pred.drop(target_cat_col, axis=1)
171178
yhat, conf_int = model.predict(
172179
n_periods=horizon,
173-
X=x_pred,
180+
X=x_pred[predict_cols],
174181
return_conf_int=True
175182
)
176183
yhat_clean = pd.DataFrame(yhat, index=yhat.index, columns=["yhat"])
177184
return yhat_clean['yhat'].tolist()
178185
elif model_name == SupportedModels.AutoMLX and series_id in model_object:
179-
# automlx model
180-
model = model_object[series_id]
181-
x_pred = future_df.drop(target_cat_col, axis=1)
186+
model = model_object[series_id]['model']
187+
label_encoder = model_object[series_id]['le']
188+
encoded_df = label_encoder.transform(future_df)
189+
x_pred = encoded_df.drop(target_cat_col, axis=1)
182190
x_pred = x_pred.set_index(date_col_name)
183191
forecast = model.forecast(
184192
X=x_pred,

tests/operators/forecast/test_errors.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -835,16 +835,9 @@ def test_what_if_analysis(operator_setup, model):
835835
historical_data = pd.read_csv(historical_data_path, parse_dates=["Date"])
836836
historical_filtered = historical_data[historical_data["Date"] > "2013-03-01"]
837837
additional_data = pd.read_csv(additional_data_path, parse_dates=["Date"])
838-
add_filtered = additional_data[additional_data["Date"] > "2013-03-01"]
839-
numeric_columns = add_filtered.select_dtypes(
840-
include=["number", "object", "datetime64"]
841-
)
842-
non_constant_columns = numeric_columns.columns[
843-
(numeric_columns != numeric_columns.iloc[0]).any()
844-
]
845-
df_non_constant = numeric_columns[non_constant_columns.union(["Store"])]
846-
df_non_constant.to_csv(f"{additional_test_path}", index=False)
847-
historical_filtered.to_csv(f"{historical_test_path}", index=False)
838+
add_filtered = additional_data[additional_data['Date'] > "2013-03-01"]
839+
add_filtered.to_csv(f'{additional_test_path}', index=False)
840+
historical_filtered.to_csv(f'{historical_test_path}', index=False)
848841

849842
yaml_i, output_data_path = populate_yaml(
850843
tmpdirname=tmpdirname,

0 commit comments

Comments
 (0)