Skip to content

Commit 06e99c3

Browse files
committed
changes to support new version
1 parent d8acef4 commit 06e99c3

File tree

1 file changed

+6
-23
lines changed
  • ads/opctl/operator/lowcode/forecast/model

1 file changed

+6
-23
lines changed

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def preprocess(self, data, series_id=None):
5555
return data.set_index(self.spec.datetime_column.name)
5656

5757
@runtime_dependency(
58-
module="automl",
58+
module="automlx",
5959
err_msg=(
6060
"Please run `pip3 install oracle-automlx==23.4.1` to install the required dependencies for automlx."
6161
),
@@ -67,15 +67,10 @@ def preprocess(self, data, series_id=None):
6767
),
6868
)
6969
def _build_model(self) -> pd.DataFrame:
70-
from automl import init
70+
from automlx import init
7171
from sktime.forecasting.model_selection import temporal_train_test_split
7272

73-
init(
74-
engine="local",
75-
engine_opts={"n_jobs": -1, "model_n_jobs": -1},
76-
check_deprecation_warnings=False,
77-
logger=50,
78-
)
73+
init(engine="ray", engine_opts={"ray_setup": {"_temp_dir": "/tmp/ray-temp"}})
7974

8075
full_data_dict = self.datasets.get_data_by_series()
8176

@@ -95,7 +90,7 @@ def _build_model(self) -> pd.DataFrame:
9590

9691
for i, (s_id, df) in enumerate(full_data_dict.items()):
9792
try:
98-
logger.debug(f"Running automl on series {s_id}")
93+
logger.debug(f"Running automlx on series {s_id}")
9994
model_kwargs = model_kwargs_cleaned.copy()
10095
target = self.original_target_column
10196
self.forecast_output.init_series_output(
@@ -110,14 +105,13 @@ def _build_model(self) -> pd.DataFrame:
110105
if self.loaded_models is not None:
111106
model = self.loaded_models[s_id]
112107
else:
113-
model = automl.Pipeline(
108+
model = automlx.Pipeline(
114109
task="forecasting",
115110
**model_kwargs,
116111
)
117112
model.fit(
118113
X=data_i.drop(target, axis=1),
119-
y=data_i[[target]],
120-
time_budget=time_budget,
114+
y=data_i[[target]]
121115
)
122116
logger.debug(f"Selected model: {model.selected_model_}")
123117
logger.debug(f"Selected model params: {model.selected_model_params_}")
@@ -149,18 +143,7 @@ def _build_model(self) -> pd.DataFrame:
149143

150144
self.model_parameters[s_id] = {
151145
"framework": SupportedModels.AutoMLX,
152-
"score_metric": model.score_metric,
153-
"random_state": model.random_state,
154-
"model_list": model.model_list,
155-
"n_algos_tuned": model.n_algos_tuned,
156-
"adaptive_sampling": model.adaptive_sampling,
157-
"min_features": model.min_features,
158-
"optimization": model.optimization,
159-
"preprocessing": model.preprocessing,
160-
"search_space": model.search_space,
161146
"time_series_period": model.time_series_period,
162-
"min_class_instances": model.min_class_instances,
163-
"max_tuning_trials": model.max_tuning_trials,
164147
"selected_model": model.selected_model_,
165148
"selected_model_params": model.selected_model_params_,
166149
}

0 commit comments

Comments
 (0)