Skip to content

Commit 990bdd6

Browse files
codeloop authored and ahosler committed
add support for mlforecast in the forecasting operator
1 parent 4b3775c commit 990bdd6

File tree

6 files changed

+175
-0
lines changed

6 files changed

+175
-0
lines changed

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
1414
Prophet = "prophet"
1515
Arima = "arima"
1616
NeuralProphet = "neuralprophet"
17+
MLForecast = "mlforecast"
1718
AutoMLX = "automlx"
1819
AutoTS = "autots"
1920
Auto = "auto"

ads/opctl/operator/lowcode/forecast/environment.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies:
88
- oracle-ads>=2.9.0
99
- prophet
1010
- neuralprophet
11+
- mlforecast
1112
- pmdarima
1213
- statsmodels
1314
- report-creator

ads/opctl/operator/lowcode/forecast/model/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .base_model import ForecastOperatorBaseModel
1313
from .neuralprophet import NeuralProphetOperatorModel
1414
from .prophet import ProphetOperatorModel
15+
from .ml_forecast import MLForecastOperatorModel
1516
from ..utils import select_auto_model
1617
from .forecast_datasets import ForecastDatasets
1718

@@ -32,6 +33,7 @@ class ForecastOperatorModelFactory:
3233
SupportedModels.Prophet: ProphetOperatorModel,
3334
SupportedModels.Arima: ArimaOperatorModel,
3435
SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
36+
SupportedModels.MLForecast: MLForecastOperatorModel,
3537
SupportedModels.AutoMLX: AutoMLXOperatorModel,
3638
SupportedModels.AutoTS: AutoTSOperatorModel
3739
}

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __init__(self, config: ForecastOperatorConfig):
135135

136136
self._horizon = config.spec.horizon
137137
self._datetime_column_name = config.spec.datetime_column.name
138+
self._target_col = config.spec.target_column
138139
self._load_data(config.spec)
139140

140141
def _load_data(self, spec):
@@ -158,6 +159,15 @@ def get_all_data_long(self, include_horizon=True):
158159
on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
159160
).reset_index()
160161

162+
def get_all_data_long_test(self):
    """Return the merged rows that have no target value.

    The historical and additional data frames are outer-joined on the
    datetime and series columns, the index is reset, and only the rows
    whose target column is null are kept (presumably the future-horizon
    rows supplied by the additional data -- confirm against callers).

    Returns:
        pd.DataFrame: the filtered rows, re-indexed from 0.
    """
    historical = self.historical_data.data
    additional = self.additional_data.data
    join_keys = [self._datetime_column_name, ForecastOutputColumns.SERIES]

    merged = pd.merge(historical, additional, how="outer", on=join_keys)
    merged = merged.reset_index()

    # Rows without a target are the ones still to be predicted.
    target_missing = merged[self._target_col].isnull()
    return merged[target_missing].reset_index(drop=True)
170+
161171
def get_data_multi_indexed(self):
162172
return pd.concat(
163173
[
ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import pandas as pd
2+
import numpy as np
3+
4+
from ads.opctl import logger
5+
from ads.common.decorator import runtime_dependency
6+
from .base_model import ForecastOperatorBaseModel
7+
from .forecast_datasets import ForecastDatasets, ForecastOutput
8+
from ..operator_config import ForecastOperatorConfig
9+
from ..const import ForecastOutputColumns, SupportedModels
10+
11+
12+
class MLForecastOperatorModel(ForecastOperatorBaseModel):
    """Forecast operator model built on MLForecast with LightGBM regressors.

    Three LightGBM models are fitted together: a point forecast plus an
    upper and a lower quantile regressor whose quantiles are derived from
    ``spec.confidence_interval_width``.
    """

    def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
        """Initialise the base model and empty explanation placeholders."""
        super().__init__(config=config, datasets=datasets)
        self.global_explanation = {}
        self.local_explanation = {}
        self.formatted_global_explanation = None
        self.formatted_local_explanation = None

    def set_kwargs(self):
        """Return the model kwargs with the interval quantiles added.

        The quantiles are ``0.5 +/- confidence_interval_width / 2`` rounded
        to two decimals.

        NOTE: the returned dict is ``self.spec.model_kwargs`` itself, so the
        two quantile entries are written into the spec as well. The
        ``"uppper_quantile"`` key (three p's) is kept as-is because it is
        part of that stored state and is read back in ``_train_model``.

        Returns:
            dict: ``self.spec.model_kwargs`` including ``"lower_quantile"``
            and ``"uppper_quantile"``.
        """
        model_kwargs = self.spec.model_kwargs

        upper_quantile = round(0.5 + self.spec.confidence_interval_width / 2, 2)
        lower_quantile = round(0.5 - self.spec.confidence_interval_width / 2, 2)

        model_kwargs["lower_quantile"] = lower_quantile
        model_kwargs["uppper_quantile"] = upper_quantile
        return model_kwargs

    def preprocess(self, df, series_id):
        """No-op: this model performs no per-series preprocessing."""
        pass

    @runtime_dependency(
        module="mlforecast",
        err_msg="MLForecast is not installed, please install it with 'pip install mlforecast'",
    )
    @runtime_dependency(
        module="lightgbm",
        err_msg="lightgbm is not installed, please install it with 'pip install lightgbm'",
    )
    def _train_model(self, data_train, data_test, model_kwargs):
        """Fit the MLForecast models and populate ``self.forecast_output``.

        Args:
            data_train: long-format training frame containing the series id,
                datetime and target columns named in the spec.
            data_test: long-format frame of future rows passed to
                ``predict`` as ``X_df``.
            model_kwargs: dict produced by ``set_kwargs``; optional keys
                ``"lags"``, ``"recursive_models"`` and ``"static_features"``
                are honoured.

        Returns:
            The result of ``self.forecast_output.get_forecast_long()`` on
            success. On any exception the error is recorded in
            ``self.errors_dict``, logged, and ``None`` is returned.
        """
        try:
            import lightgbm as lgb
            from mlforecast import MLForecast
            from mlforecast.lag_transforms import ExpandingMean, RollingMean
            from mlforecast.target_transforms import Differences

            dt_col = self.spec.datetime_column.name

            lgb_params = {
                "verbosity": -1,
                "num_leaves": 512,
            }

            fcst = MLForecast(
                # The dict keys become the prediction column names that are
                # read back below ("forecast", "upper", "lower").
                models={
                    "forecast": lgb.LGBMRegressor(**lgb_params),
                    "upper": lgb.LGBMRegressor(
                        **lgb_params,
                        objective="quantile",
                        alpha=model_kwargs["uppper_quantile"],
                    ),
                    "lower": lgb.LGBMRegressor(
                        **lgb_params,
                        objective="quantile",
                        alpha=model_kwargs["lower_quantile"],
                    ),
                },
                # Infer the frequency from the configured datetime column,
                # the same column that is handed to ``fit`` as ``time_col``.
                freq=pd.infer_freq(data_train[dt_col].drop_duplicates()),
                target_transforms=[Differences([12])],
                lags=model_kwargs.get("lags", [1, 6, 12]),
                lag_transforms={
                    1: [ExpandingMean()],
                    12: [RollingMean(window_size=24)],
                },
            )

            # Only the literal ``False`` keeps ``max_horizon`` unset; any
            # other value makes ``fit`` receive the full horizon.
            num_models = model_kwargs.get("recursive_models", False)

            fcst.fit(
                data_train,
                static_features=model_kwargs.get("static_features", []),
                id_col=ForecastOutputColumns.SERIES,
                time_col=dt_col,
                target_col=self.spec.target_column,
                fitted=True,
                max_horizon=None if num_models is False else self.spec.horizon,
            )

            # Append whatever future rows ``get_missing_future`` reports for
            # ``data_test`` and zero-fill the resulting gaps before predicting.
            future_features = pd.concat(
                [
                    data_test,
                    fcst.get_missing_future(h=self.spec.horizon, X_df=data_test),
                ],
                axis=0,
                ignore_index=True,
            ).fillna(0)

            self.outputs = fcst.predict(h=self.spec.horizon, X_df=future_features)
            fitted_values = fcst.forecast_fitted_values()

            for s_id in self.datasets.list_series_ids():
                self.forecast_output.init_series_output(
                    series_id=s_id,
                    data_at_series=self.datasets.get_data_at_series(s_id),
                )

                # Filter each frame once per series instead of once per column.
                series_fitted = fitted_values[
                    fitted_values[ForecastOutputColumns.SERIES] == s_id
                ]
                series_forecast = self.outputs[
                    self.outputs[ForecastOutputColumns.SERIES] == s_id
                ]

                self.forecast_output.populate_series_output(
                    series_id=s_id,
                    fit_val=series_fitted["forecast"].values,
                    forecast_val=series_forecast["forecast"].values,
                    upper_bound=series_forecast["upper"].values,
                    lower_bound=series_forecast["lower"].values,
                )

                self.model_parameters[s_id] = {
                    "framework": SupportedModels.MLForecast,
                    **lgb_params,
                }

            logger.debug("===========Done===========")

            return self.forecast_output.get_forecast_long()
        except Exception as e:
            # Best-effort: record the failure for the caller instead of
            # raising, but make it visible in the logs as well.
            self.errors_dict[self.spec.model] = {
                "model_name": self.spec.model,
                "error": str(e),
            }
            logger.warning(
                "Encountered error while training %s: %s", self.spec.model, e
            )
            logger.debug("MLForecast training traceback", exc_info=True)

    def _build_model(self) -> pd.DataFrame:
        """Prepare the datasets, train the model and return the forecast.

        Returns:
            The long-format forecast returned by ``_train_model``, or
            ``None`` when training failed (see ``self.errors_dict``).
        """
        data_train = self.datasets.get_all_data_long(include_horizon=False)
        data_test = self.datasets.get_all_data_long_test()
        self.models = dict()
        model_kwargs = self.set_kwargs()
        self.forecast_output = ForecastOutput(
            confidence_interval_width=self.spec.confidence_interval_width,
            horizon=self.spec.horizon,
            target_column=self.original_target_column,
            dt_column=self.spec.datetime_column.name,
        )
        # Propagate the forecast frame; it was previously discarded.
        return self._train_model(data_train, data_test, model_kwargs)

    def _generate_report(self):
        """No-op: no model-specific report content is produced yet."""
        pass

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ spec:
370370
- prophet
371371
- arima
372372
- neuralprophet
373+
- mlforecast
373374
- automlx
374375
- autots
375376
- auto

0 commit comments

Comments
 (0)