Skip to content

Commit 8380862

Browse files
Merge branch 'main' into ODSC-59433/incorrect_override_from_config
2 parents 2190848 + 9bbabf6 commit 8380862

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ def _train_model(self, data_train, data_test, model_kwargs):
6161
"verbosity": -1,
6262
"num_leaves": 512,
6363
}
64+
additional_data_params = {}
65+
if len(self.datasets.get_additional_data_column_names()) > 0:
66+
additional_data_params = {
67+
"target_transforms": [Differences([12])],
68+
"lags": model_kwargs.get("lags", [1, 6, 12]),
69+
"lag_transforms": (
70+
{
71+
1: [ExpandingMean()],
72+
12: [RollingMean(window_size=24)],
73+
}
74+
),
75+
}
6476

6577
fcst = MLForecast(
6678
models={
@@ -80,24 +92,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
8092
},
8193
freq=pd.infer_freq(data_train[self.date_col].drop_duplicates())
8294
or pd.infer_freq(data_train[self.date_col].drop_duplicates()[-5:]),
83-
target_transforms=[Differences([12])],
84-
lags=model_kwargs.get(
85-
"lags",
86-
(
87-
[1, 6, 12]
88-
if len(self.datasets.get_additional_data_column_names()) > 0
89-
else []
90-
),
91-
),
92-
lag_transforms=(
93-
{
94-
1: [ExpandingMean()],
95-
12: [RollingMean(window_size=24)],
96-
}
97-
if len(self.datasets.get_additional_data_column_names()) > 0
98-
else {}
99-
),
100-
# date_features=[hour_index],
95+
**additional_data_params,
10196
)
10297

10398
num_models = model_kwargs.get("recursive_models", False)
@@ -164,6 +159,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
164159
"error": str(e),
165160
}
166161
logger.debug(f"Encountered Error: {e}. Skipping.")
162+
raise e
167163

168164
def _build_model(self) -> pd.DataFrame:
169165
data_train = self.datasets.get_all_data_long(include_horizon=False)

0 commit comments

Comments
 (0)