Skip to content

Commit 4f189ef

Browse files
authored
ODSC-60854: pass model params via kwargs, ruff formatting (#922)
2 parents 1eecd92 + 0f105dc commit 4f189ef

File tree

6 files changed

+29
-22
lines changed

6 files changed

+29
-22
lines changed

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
1414
Prophet = "prophet"
1515
Arima = "arima"
1616
NeuralProphet = "neuralprophet"
17-
MLForecast = "mlforecast"
17+
LGBForecast = "lgbforecast"
1818
AutoMLX = "automlx"
1919
AutoTS = "autots"
2020
Auto = "auto"

ads/opctl/operator/lowcode/forecast/model/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ForecastOperatorModelFactory:
3333
SupportedModels.Prophet: ProphetOperatorModel,
3434
SupportedModels.Arima: ArimaOperatorModel,
3535
SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
36-
SupportedModels.MLForecast: MLForecastOperatorModel,
36+
SupportedModels.LGBForecast: MLForecastOperatorModel,
3737
SupportedModels.AutoMLX: AutoMLXOperatorModel,
3838
SupportedModels.AutoTS: AutoTSOperatorModel
3939
}

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

43
# Copyright (c) 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6-
import pandas as pd
75
import numpy as np
6+
import pandas as pd
87

9-
from ads.opctl import logger
108
from ads.common.decorator import runtime_dependency
9+
from ads.opctl import logger
1110
from ads.opctl.operator.lowcode.forecast.utils import _select_plot_list
11+
12+
from ..const import ForecastOutputColumns, SupportedModels
13+
from ..operator_config import ForecastOperatorConfig
1214
from .base_model import ForecastOperatorBaseModel
1315
from .forecast_datasets import ForecastDatasets, ForecastOutput
14-
from ..operator_config import ForecastOperatorConfig
15-
from ..const import ForecastOutputColumns, SupportedModels
1616

1717

1818
class MLForecastOperatorModel(ForecastOperatorBaseModel):
@@ -58,18 +58,25 @@ def _train_model(self, data_train, data_test, model_kwargs):
5858
from mlforecast.target_transforms import Differences
5959

6060
lgb_params = {
61-
"verbosity": -1,
62-
"num_leaves": 512,
61+
"verbosity": model_kwargs.get("verbosity", -1),
62+
"num_leaves": model_kwargs.get("num_leaves", 512),
6363
}
6464
additional_data_params = {}
6565
if len(self.datasets.get_additional_data_column_names()) > 0:
6666
additional_data_params = {
67-
"target_transforms": [Differences([12])],
67+
"target_transforms": [
68+
Differences([model_kwargs.get("Differences", 12)])
69+
],
6870
"lags": model_kwargs.get("lags", [1, 6, 12]),
6971
"lag_transforms": (
7072
{
7173
1: [ExpandingMean()],
72-
12: [RollingMean(window_size=24)],
74+
12: [
75+
RollingMean(
76+
window_size=model_kwargs.get("RollingMean", 24),
77+
min_samples=1,
78+
)
79+
],
7380
}
7481
),
7582
}
@@ -147,7 +154,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
147154
)
148155

149156
self.model_parameters[s_id] = {
150-
"framework": SupportedModels.MLForecast,
157+
"framework": SupportedModels.LGBForecast,
151158
**lgb_params,
152159
}
153160

@@ -204,10 +211,10 @@ def _generate_report(self):
204211
self.datasets.list_series_ids(),
205212
)
206213

207-
# Section 2: MlForecast Model Parameters
214+
# Section 2: LGBForecast Model Parameters
208215
sec2_text = rc.Block(
209-
rc.Heading("MlForecast Model Parameters", level=2),
210-
rc.Text("These are the parameters used for the MlForecast model."),
216+
rc.Heading("LGBForecast Model Parameters", level=2),
217+
rc.Text("These are the parameters used for the LGBForecast model."),
211218
)
212219

213220
blocks = [
@@ -221,7 +228,7 @@ def _generate_report(self):
221228

222229
all_sections = [sec1_text, sec_1, sec2_text, sec_2]
223230
model_description = rc.Text(
224-
"mlforecast is a framework to perform time series forecasting using machine learning models"
231+
"LGBForecast uses mlforecast framework to perform time series forecasting using machine learning models"
225232
"with the option to scale to massive amounts of data using remote clusters."
226233
"Fastest implementations of feature engineering for time series forecasting in Python."
227234
"Support for exogenous variables and static covariates."

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ spec:
379379
- prophet
380380
- arima
381381
- neuralprophet
382-
- mlforecast
382+
- lgbforecast
383383
- automlx
384384
- autots
385385
- auto-select

tests/operators/forecast/test_datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"prophet",
3333
"neuralprophet",
3434
"autots",
35-
"mlforecast",
35+
"lgbforecast",
3636
"auto-select",
3737
]
3838

@@ -135,15 +135,15 @@ def test_load_datasets(model, data_details):
135135
if model == "automlx":
136136
yaml_i["spec"]["model_kwargs"] = {"time_budget": 2}
137137
if model == "auto-select":
138-
yaml_i["spec"]["model_kwargs"] = {"model_list": ['prophet', 'arima', 'mlforecast']}
138+
yaml_i["spec"]["model_kwargs"] = {"model_list": ['prophet', 'arima', 'lgbforecast']}
139139
if dataset_name == f'{DATASET_PREFIX}dataset4.csv':
140140
pytest.skip("Skipping dataset4 with auto-select") # todo:// ODSC-58584
141141

142142
run(yaml_i, backend="operator.local", debug=False)
143143
subprocess.run(f"ls -a {output_data_path}", shell=True)
144144
if yaml_i["spec"]["generate_explanations"] and model not in [
145145
"automlx",
146-
"mlforecast",
146+
"lgbforecast",
147147
"auto-select"
148148
]:
149149
verify_explanations(

tests/operators/forecast/test_errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@
141141
"prophet",
142142
"neuralprophet",
143143
"autots",
144-
"mlforecast",
144+
"lgbforecast",
145145
# "auto",
146146
]
147147

@@ -687,7 +687,7 @@ def test_arima_automlx_errors(operator_setup, model):
687687
in error_content["13"]["error"]
688688
), "Error message mismatch"
689689

690-
if model not in ["autots", "automlx", "mlforecast"]:
690+
if model not in ["autots", "automlx", "lgbforecast"]:
691691
global_fn = f"{tmpdirname}/results/global_explanation.csv"
692692
assert os.path.exists(
693693
global_fn

0 commit comments

Comments
 (0)