Skip to content

Commit 02b9809

Browse files
committed
ODSC-60854: pass model params via kwargs, ruff formatting
1 parent 7f8b70d commit 02b9809

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

43
# Copyright (c) 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6-
import pandas as pd
75
import numpy as np
6+
import pandas as pd
87

9-
from ads.opctl import logger
108
from ads.common.decorator import runtime_dependency
9+
from ads.opctl import logger
1110
from ads.opctl.operator.lowcode.forecast.utils import _select_plot_list
11+
12+
from ..const import ForecastOutputColumns, SupportedModels
13+
from ..operator_config import ForecastOperatorConfig
1214
from .base_model import ForecastOperatorBaseModel
1315
from .forecast_datasets import ForecastDatasets, ForecastOutput
14-
from ..operator_config import ForecastOperatorConfig
15-
from ..const import ForecastOutputColumns, SupportedModels
1616

1717

1818
class MLForecastOperatorModel(ForecastOperatorBaseModel):
@@ -58,18 +58,25 @@ def _train_model(self, data_train, data_test, model_kwargs):
5858
from mlforecast.target_transforms import Differences
5959

6060
lgb_params = {
61-
"verbosity": -1,
62-
"num_leaves": 512,
61+
"verbosity": model_kwargs.get("verbosity", -1),
62+
"num_leaves": model_kwargs.get("num_leaves", 512),
6363
}
6464
additional_data_params = {}
6565
if len(self.datasets.get_additional_data_column_names()) > 0:
6666
additional_data_params = {
67-
"target_transforms": [Differences([12])],
67+
"target_transforms": [
68+
Differences([model_kwargs.get("Differences", 12)])
69+
],
6870
"lags": model_kwargs.get("lags", [1, 6, 12]),
6971
"lag_transforms": (
7072
{
7173
1: [ExpandingMean()],
72-
12: [RollingMean(window_size=24)],
74+
12: [
75+
RollingMean(
76+
window_size=model_kwargs.get("RollingMean", 24),
77+
min_samples=1,
78+
)
79+
],
7380
}
7481
),
7582
}

0 commit comments

Comments
 (0)