Skip to content

Commit 60d48f2

Browse files
authored
Bump automlx version (#991)
2 parents 5ff0d48 + 31549f7 commit 60d48f2

File tree

10 files changed

+53
-51
lines changed

10 files changed

+53
-51
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[classic]>=24.2.0"
60-
pip install "oracle-automlx[forecasting]>=24.2.0"
59+
pip install "oracle-automlx[forecasting]>=24.4.0"
6160
pip install pandas>=2.2.0
6261
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/opctl/operator/lowcode/forecast/cmd.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
from typing import Dict
87

9-
import click
10-
11-
from ads.opctl import logger
12-
from ads.opctl.operator.common.utils import _load_yaml_from_uri
138
from ads.opctl.operator.common.operator_yaml_generator import YamlGenerator
14-
15-
from .const import SupportedModels
9+
from ads.opctl.operator.common.utils import _load_yaml_from_uri
1610

1711

1812
def init(**kwargs: Dict) -> str:
@@ -39,7 +33,7 @@ def init(**kwargs: Dict) -> str:
3933
# type=click.Choice(SupportedModels.values()),
4034
# default=SupportedModels.Auto,
4135
# )
42-
model_type = "auto"
36+
model_type = "prophet"
4337

4438
return YamlGenerator(
4539
schema=_load_yaml_from_uri(__file__.replace("cmd.py", "schema.yaml"))

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
from ads.common.extended_enum import ExtendedEnumMeta
@@ -17,7 +16,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
1716
LGBForecast = "lgbforecast"
1817
AutoMLX = "automlx"
1918
AutoTS = "autots"
20-
Auto = "auto"
19+
# Auto = "auto"
2120

2221

2322
class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
@@ -28,7 +27,7 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
2827
HIGH_ACCURACY = "HIGH_ACCURACY"
2928
BALANCED = "BALANCED"
3029
FAST_APPROXIMATE = "FAST_APPROXIMATE"
31-
ratio = dict()
30+
ratio = {}
3231
ratio[HIGH_ACCURACY] = 1 # 100 % data used for generating explanations
3332
ratio[BALANCED] = 0.5 # 50 % data used for generating explanations
3433
ratio[FAST_APPROXIMATE] = 0 # constant

ads/opctl/operator/lowcode/forecast/model/factory.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

7-
from ..const import SupportedModels, AUTO_SELECT
6+
from ..const import AUTO_SELECT, SupportedModels
7+
from ..model_evaluator import ModelEvaluator
88
from ..operator_config import ForecastOperatorConfig
99
from .arima import ArimaOperatorModel
1010
from .automlx import AutoMLXOperatorModel
1111
from .autots import AutoTSOperatorModel
1212
from .base_model import ForecastOperatorBaseModel
13+
from .forecast_datasets import ForecastDatasets
1314
from .neuralprophet import NeuralProphetOperatorModel
1415
from .prophet import ProphetOperatorModel
15-
from .forecast_datasets import ForecastDatasets
16-
from .ml_forecast import MLForecastOperatorModel
17-
from ..model_evaluator import ModelEvaluator
16+
1817

1918
class UnSupportedModelError(Exception):
2019
def __init__(self, model_type: str):
@@ -33,9 +32,9 @@ class ForecastOperatorModelFactory:
3332
SupportedModels.Prophet: ProphetOperatorModel,
3433
SupportedModels.Arima: ArimaOperatorModel,
3534
SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
36-
SupportedModels.LGBForecast: MLForecastOperatorModel,
35+
# SupportedModels.LGBForecast: MLForecastOperatorModel,
3736
SupportedModels.AutoMLX: AutoMLXOperatorModel,
38-
SupportedModels.AutoTS: AutoTSOperatorModel
37+
SupportedModels.AutoTS: AutoTSOperatorModel,
3938
}
4039

4140
@classmethod
@@ -65,14 +64,14 @@ def get_model(
6564
model_type = operator_config.spec.model
6665
if model_type == AUTO_SELECT:
6766
model_type = cls.auto_select_model(datasets, operator_config)
68-
operator_config.spec.model_kwargs = dict()
67+
operator_config.spec.model_kwargs = {}
6968
if model_type not in cls._MAP:
7069
raise UnSupportedModelError(model_type)
7170
return cls._MAP[model_type](config=operator_config, datasets=datasets)
7271

7372
@classmethod
7473
def auto_select_model(
75-
cls, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig
74+
cls, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig
7675
) -> str:
7776
"""
7877
Selects AutoMLX or Arima model based on column count.
@@ -90,8 +89,10 @@ def auto_select_model(
9089
str
9190
The type of the model.
9291
"""
93-
all_models = operator_config.spec.model_kwargs.get("model_list", cls._MAP.keys())
92+
all_models = operator_config.spec.model_kwargs.get(
93+
"model_list", cls._MAP.keys()
94+
)
9495
num_backtests = operator_config.spec.model_kwargs.get("num_backtests", 5)
9596
sample_ratio = operator_config.spec.model_kwargs.get("sample_ratio", 0.20)
9697
model_evaluator = ModelEvaluator(all_models, num_backtests, sample_ratio)
97-
return model_evaluator.find_best_model(datasets, operator_config)
98+
return model_evaluator.find_best_model(datasets, operator_config)

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
# Copyright (c) 2024 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5-
import numpy as np
5+
import traceback
6+
67
import pandas as pd
78

89
from ads.common.decorator import runtime_dependency
@@ -164,7 +165,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
164165
self.errors_dict[self.spec.model] = {
165166
"model_name": self.spec.model,
166167
"error": str(e),
167-
"error_trace": traceback.format_exc()
168+
"error_trace": traceback.format_exc(),
168169
}
169170
logger.warn(f"Encountered Error: {e}. Skipping.")
170171
logger.warn(traceback.format_exc())
@@ -173,7 +174,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
173174
def _build_model(self) -> pd.DataFrame:
174175
data_train = self.datasets.get_all_data_long(include_horizon=False)
175176
data_test = self.datasets.get_all_data_long_forecast_horizon()
176-
self.models = dict()
177+
self.models = {}
177178
model_kwargs = self.set_kwargs()
178179
self.forecast_output = ForecastOutput(
179180
confidence_interval_width=self.spec.confidence_interval_width,

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import os
87
from dataclasses import dataclass, field
98
from typing import Dict, List
109

1110
from ads.common.serializer import DataClassSerializable
11+
from ads.opctl.operator.common.operator_config import (
12+
InputData,
13+
OperatorConfig,
14+
OutputDirectory,
15+
)
1216
from ads.opctl.operator.common.utils import _load_yaml_from_uri
13-
from ads.opctl.operator.common.operator_config import OperatorConfig, OutputDirectory, InputData
14-
15-
from .const import SupportedMetrics, SpeedAccuracyMode
16-
from .const import SupportedModels
1717
from ads.opctl.operator.lowcode.common.utils import find_output_dirname
1818

19+
from .const import SpeedAccuracyMode, SupportedMetrics, SupportedModels
20+
21+
1922
@dataclass(repr=True)
2023
class TestData(InputData):
2124
"""Class representing operator specification test data details."""
@@ -90,13 +93,17 @@ class ForecastOperatorSpec(DataClassSerializable):
9093

9194
def __post_init__(self):
9295
"""Adjusts the specification details."""
93-
self.output_directory = self.output_directory or OutputDirectory(url=find_output_dirname(self.output_directory))
96+
self.output_directory = self.output_directory or OutputDirectory(
97+
url=find_output_dirname(self.output_directory)
98+
)
9499
self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
95-
self.model = self.model or SupportedModels.Auto
100+
self.model = self.model or SupportedModels.Prophet
96101
self.confidence_interval_width = self.confidence_interval_width or 0.80
97102
self.report_filename = self.report_filename or "report.html"
98103
self.preprocessing = (
99-
self.preprocessing if self.preprocessing is not None else DataPreprocessor(enabled=True)
104+
self.preprocessing
105+
if self.preprocessing is not None
106+
else DataPreprocessor(enabled=True)
100107
)
101108
# For Report Generation. When user doesn't specify defaults to True
102109
self.generate_report = (
@@ -138,7 +145,7 @@ def __post_init__(self):
138145
)
139146
self.target_column = self.target_column or "Sales"
140147
self.errors_dict_filename = "errors.json"
141-
self.model_kwargs = self.model_kwargs or dict()
148+
self.model_kwargs = self.model_kwargs or {}
142149

143150

144151
@dataclass(repr=True)

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,12 @@ spec:
374374
model:
375375
type: string
376376
required: false
377-
default: auto-select
377+
default: prophet
378378
allowed:
379379
- prophet
380380
- arima
381381
- neuralprophet
382-
- lgbforecast
382+
# - lgbforecast
383383
- automlx
384384
- autots
385385
- auto-select

docs/source/user_guide/operators/forecast_operator/yaml_schema.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ Below is an example of a ``forecast.yaml`` file with every parameter specified:
5151
* **format**: (Optional) Specify the format for output data (e.g., ``csv``, ``json``, ``excel``).
5252
* **options**: (Optional) Include any additional arguments, such as connection parameters for storage.
5353

54-
* **model**: (Optional) The name of the model framework to use. Defaults to ``auto-select``. Available options include ``arima``, ``lgbforecaster``, ``prophet``, ``neuralprophet``, ``autots``, and ``auto-select``.
54+
* **model**: (Optional) The name of the model framework to use. Defaults to ``auto-select``. Available options include ``arima``, ``prophet``, ``neuralprophet``, ``autots``, and ``auto-select``.
5555

5656
* **model_kwargs**: (Optional) A dictionary of arguments to pass directly to the model framework, allowing for detailed control over modeling.
5757

tests/operators/forecast/test_datasets.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"prophet",
3333
"neuralprophet",
3434
"autots",
35-
"lgbforecast",
35+
# "lgbforecast",
3636
"auto-select",
3737
]
3838

@@ -135,16 +135,18 @@ def test_load_datasets(model, data_details):
135135
if model == "automlx":
136136
yaml_i["spec"]["model_kwargs"] = {"time_budget": 2}
137137
if model == "auto-select":
138-
yaml_i["spec"]["model_kwargs"] = {"model_list": ['prophet', 'arima', 'lgbforecast']}
139-
if dataset_name == f'{DATASET_PREFIX}dataset4.csv':
138+
yaml_i["spec"]["model_kwargs"] = {
139+
"model_list": ["prophet", "arima"]
140+
} # 'lgbforecast'
141+
if dataset_name == f"{DATASET_PREFIX}dataset4.csv":
140142
pytest.skip("Skipping dataset4 with auto-select") # todo:// ODSC-58584
141143

142144
run(yaml_i, backend="operator.local", debug=False)
143145
subprocess.run(f"ls -a {output_data_path}", shell=True)
144146
if yaml_i["spec"]["generate_explanations"] and model not in [
145147
"automlx",
146-
"lgbforecast",
147-
"auto-select"
148+
# "lgbforecast",
149+
"auto-select",
148150
]:
149151
verify_explanations(
150152
tmpdirname=tmpdirname,

tests/operators/forecast/test_errors.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,7 @@
141141
"prophet",
142142
"neuralprophet",
143143
"autots",
144-
"lgbforecast",
145-
# "auto",
144+
# "lgbforecast",
146145
]
147146

148147
TEMPLATE_YAML = {
@@ -687,7 +686,7 @@ def test_arima_automlx_errors(operator_setup, model):
687686
in error_content["13"]["error"]
688687
), "Error message mismatch"
689688

690-
if model not in ["autots", "automlx", "lgbforecast"]:
689+
if model not in ["autots", "automlx"]: # , "lgbforecast"
691690
global_fn = f"{tmpdirname}/results/global_explanation.csv"
692691
assert os.path.exists(
693692
global_fn

0 commit comments

Comments
 (0)