Skip to content

Commit d410494

Browse files
codeloop authored and ahosler committed
fix the unit tests
1 parent 40cdd3f commit d410494

File tree

4 files changed

+49
-20
lines changed

4 files changed

+49
-20
lines changed

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,10 @@ def create_horizon(self, spec, historical_data):
8686
pd.date_range(
8787
start=historical_data.get_max_time(),
8888
periods=spec.horizon + 1,
89-
freq=historical_data.freq,
89+
freq=historical_data.freq
90+
or pd.infer_freq(
91+
historical_data.data.reset_index()[spec.datetime_column.name][-5:]
92+
),
9093
),
9194
name=spec.datetime_column.name,
9295
)

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 40 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -73,20 +73,35 @@ def _train_model(self, data_train, data_test, model_kwargs):
7373
alpha=model_kwargs["lower_quantile"],
7474
),
7575
},
76-
freq=pd.infer_freq(data_train.Date.drop_duplicates()),
76+
freq=pd.infer_freq(data_train["Date"].drop_duplicates())
77+
or pd.infer_freq(data_train["Date"].drop_duplicates()[-5:]),
7778
target_transforms=[Differences([12])],
78-
lags=model_kwargs.get("lags", [1, 6, 12]),
79-
lag_transforms={
80-
1: [ExpandingMean()],
81-
12: [RollingMean(window_size=24)],
82-
},
79+
lags=model_kwargs.get(
80+
"lags",
81+
(
82+
[1, 6, 12]
83+
if len(self.datasets.get_additional_data_column_names()) > 0
84+
else []
85+
),
86+
),
87+
lag_transforms=(
88+
{
89+
1: [ExpandingMean()],
90+
12: [RollingMean(window_size=24)],
91+
}
92+
if len(self.datasets.get_additional_data_column_names()) > 0
93+
else {}
94+
),
8395
# date_features=[hour_index],
8496
)
8597

8698
num_models = model_kwargs.get("recursive_models", False)
8799

100+
self.model_columns = [
101+
ForecastOutputColumns.SERIES
102+
] + data_train.select_dtypes(exclude=["object"]).columns.to_list()
88103
fcst.fit(
89-
data_train,
104+
data_train[self.model_columns],
90105
static_features=model_kwargs.get("static_features", []),
91106
id_col=ForecastOutputColumns.SERIES,
92107
time_col=self.spec.datetime_column.name,
@@ -99,8 +114,10 @@ def _train_model(self, data_train, data_test, model_kwargs):
99114
h=self.spec.horizon,
100115
X_df=pd.concat(
101116
[
102-
data_test,
103-
fcst.get_missing_future(h=self.spec.horizon, X_df=data_test),
117+
data_test[self.model_columns],
118+
fcst.get_missing_future(
119+
h=self.spec.horizon, X_df=data_test[self.model_columns]
120+
),
104121
],
105122
axis=0,
106123
ignore_index=True,
@@ -166,12 +183,16 @@ def _generate_report(self):
166183
# Section 1: Forecast Overview
167184
sec1_text = rc.Block(
168185
rc.Heading("Forecast Overview", level=2),
169-
rc.Text("These plots show your forecast in the context of historical data.")
186+
rc.Text(
187+
"These plots show your forecast in the context of historical data."
188+
),
170189
)
171190
sec_1 = _select_plot_list(
172191
lambda s_id: plot_series(
173192
self.datasets.get_all_data_long(include_horizon=False),
174-
pd.concat([self.fitted_values,self.outputs], axis=0, ignore_index=True),
193+
pd.concat(
194+
[self.fitted_values, self.outputs], axis=0, ignore_index=True
195+
),
175196
id_col=ForecastOutputColumns.SERIES,
176197
time_col=self.spec.datetime_column.name,
177198
target_col=self.original_target_column,
@@ -184,7 +205,7 @@ def _generate_report(self):
184205
# Section 2: MlForecast Model Parameters
185206
sec2_text = rc.Block(
186207
rc.Heading("MlForecast Model Parameters", level=2),
187-
rc.Text("These are the parameters used for the MlForecast model.")
208+
rc.Text("These are the parameters used for the MlForecast model."),
188209
)
189210

190211
blocks = [
@@ -197,9 +218,11 @@ def _generate_report(self):
197218
sec_2 = rc.Select(blocks=blocks)
198219

199220
all_sections = [sec1_text, sec_1, sec2_text, sec_2]
200-
model_description = rc.Text("mlforecast is a framework to perform time series forecasting using machine learning models"
201-
"with the option to scale to massive amounts of data using remote clusters."
202-
"Fastest implementations of feature engineering for time series forecasting in Python."
203-
"Support for exogenous variables and static covariates.")
221+
model_description = rc.Text(
222+
"mlforecast is a framework to perform time series forecasting using machine learning models"
223+
"with the option to scale to massive amounts of data using remote clusters."
224+
"Fastest implementations of feature engineering for time series forecasting in Python."
225+
"Support for exogenous variables and static covariates."
226+
)
204227

205-
return model_description, all_sections
228+
return model_description, all_sections

tests/operators/forecast/test_datasets.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,10 @@ def test_load_datasets(model, data_details):
137137

138138
run(yaml_i, backend="operator.local", debug=False)
139139
subprocess.run(f"ls -a {output_data_path}", shell=True)
140-
if yaml_i["spec"]["generate_explanations"] and model != "automlx":
140+
if yaml_i["spec"]["generate_explanations"] and model not in [
141+
"automlx",
142+
"mlforecast",
143+
]:
141144
verify_explanations(
142145
tmpdirname=tmpdirname,
143146
additional_cols=additional_cols,

tests/operators/forecast/test_errors.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -687,7 +687,7 @@ def test_arima_automlx_errors(operator_setup, model):
687687
in error_content["13"]["error"]
688688
), "Error message mismatch"
689689

690-
if model not in ["autots", "automlx"]:
690+
if model not in ["autots", "automlx", "mlforecast"]:
691691
global_fn = f"{tmpdirname}/results/global_explanation.csv"
692692
assert os.path.exists(
693693
global_fn

0 commit comments

Comments
 (0)