|
1 | 1 | #!/usr/bin/env python
|
2 |
| -# -*- coding: utf-8 -*-- |
3 | 2 |
|
4 | 3 | # Copyright (c) 2024 Oracle and/or its affiliates.
|
5 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6 |
| -import pandas as pd |
7 | 5 | import numpy as np
|
| 6 | +import pandas as pd |
8 | 7 |
|
9 |
| -from ads.opctl import logger |
10 | 8 | from ads.common.decorator import runtime_dependency
|
| 9 | +from ads.opctl import logger |
11 | 10 | from ads.opctl.operator.lowcode.forecast.utils import _select_plot_list
|
| 11 | + |
| 12 | +from ..const import ForecastOutputColumns, SupportedModels |
| 13 | +from ..operator_config import ForecastOperatorConfig |
12 | 14 | from .base_model import ForecastOperatorBaseModel
|
13 | 15 | from .forecast_datasets import ForecastDatasets, ForecastOutput
|
14 |
| -from ..operator_config import ForecastOperatorConfig |
15 |
| -from ..const import ForecastOutputColumns, SupportedModels |
16 | 16 |
|
17 | 17 |
|
18 | 18 | class MLForecastOperatorModel(ForecastOperatorBaseModel):
|
@@ -58,18 +58,25 @@ def _train_model(self, data_train, data_test, model_kwargs):
|
58 | 58 | from mlforecast.target_transforms import Differences
|
59 | 59 |
|
60 | 60 | lgb_params = {
|
61 |
| - "verbosity": -1, |
62 |
| - "num_leaves": 512, |
| 61 | + "verbosity": model_kwargs.get("verbosity", -1), |
| 62 | + "num_leaves": model_kwargs.get("num_leaves", 512), |
63 | 63 | }
|
64 | 64 | additional_data_params = {}
|
65 | 65 | if len(self.datasets.get_additional_data_column_names()) > 0:
|
66 | 66 | additional_data_params = {
|
67 |
| - "target_transforms": [Differences([12])], |
| 67 | + "target_transforms": [ |
| 68 | + Differences([model_kwargs.get("Differences", 12)]) |
| 69 | + ], |
68 | 70 | "lags": model_kwargs.get("lags", [1, 6, 12]),
|
69 | 71 | "lag_transforms": (
|
70 | 72 | {
|
71 | 73 | 1: [ExpandingMean()],
|
72 |
| - 12: [RollingMean(window_size=24)], |
| 74 | + 12: [ |
| 75 | + RollingMean( |
| 76 | + window_size=model_kwargs.get("RollingMean", 24), |
| 77 | + min_samples=1, |
| 78 | + ) |
| 79 | + ], |
73 | 80 | }
|
74 | 81 | ),
|
75 | 82 | }
|
|
0 commit comments