Skip to content

Commit 4d323cc

Browse files
authored
Dataflow changes (#1018)
2 parents beef7b1 + 723a763 commit 4d323cc

File tree

14 files changed

+524
-275
lines changed

14 files changed

+524
-275
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[forecasting]>=24.4.0"
59+
pip install "oracle-automlx[forecasting]>=24.4.1"
6060
pip install pandas>=2.2.0
6161
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/opctl/config/merger.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,33 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

4-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import os
87
from string import Template
98
from typing import Dict
10-
import json
119

1210
import yaml
1311

1412
from ads.common.auth import AuthType, ResourcePrincipal
1513
from ads.opctl import logger
1614
from ads.opctl.config.base import ConfigProcessor
17-
from ads.opctl.config.utils import read_from_ini, _DefaultNoneDict
18-
from ads.opctl.utils import is_in_notebook_session, get_service_pack_prefix
15+
from ads.opctl.config.utils import _DefaultNoneDict, read_from_ini
1916
from ads.opctl.constants import (
20-
DEFAULT_PROFILE,
21-
DEFAULT_OCI_CONFIG_FILE,
22-
DEFAULT_CONDA_PACK_FOLDER,
23-
DEFAULT_ADS_CONFIG_FOLDER,
24-
ADS_JOBS_CONFIG_FILE_NAME,
2517
ADS_CONFIG_FILE_NAME,
26-
ADS_ML_PIPELINE_CONFIG_FILE_NAME,
2718
ADS_DATAFLOW_CONFIG_FILE_NAME,
19+
ADS_JOBS_CONFIG_FILE_NAME,
2820
ADS_LOCAL_BACKEND_CONFIG_FILE_NAME,
21+
ADS_ML_PIPELINE_CONFIG_FILE_NAME,
2922
ADS_MODEL_DEPLOYMENT_CONFIG_FILE_NAME,
30-
DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
3123
BACKEND_NAME,
24+
DEFAULT_ADS_CONFIG_FOLDER,
25+
DEFAULT_CONDA_PACK_FOLDER,
26+
DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
27+
DEFAULT_OCI_CONFIG_FILE,
28+
DEFAULT_PROFILE,
3229
)
30+
from ads.opctl.utils import get_service_pack_prefix, is_in_notebook_session
3331

3432

3533
class ConfigMerger(ConfigProcessor):
@@ -41,8 +39,9 @@ class ConfigMerger(ConfigProcessor):
4139
"""
4240

4341
def process(self, **kwargs) -> None:
44-
config_string = Template(json.dumps(self.config)).safe_substitute(os.environ)
45-
self.config = json.loads(config_string)
42+
for key, value in self.config.items():
43+
if isinstance(value, str): # Substitute only if the value is a string
44+
self.config[key] = Template(value).safe_substitute(os.environ)
4645

4746
if "runtime" not in self.config:
4847
self.config["runtime"] = {}

ads/opctl/operator/common/operator_config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76

@@ -11,15 +10,16 @@
1110
from typing import Any, Dict, List
1211

1312
from ads.common.serializer import DataClassSerializable
14-
15-
from ads.opctl.operator.common.utils import OperatorValidator
1613
from ads.opctl.operator.common.errors import InvalidParameterError
14+
from ads.opctl.operator.common.utils import OperatorValidator
15+
1716

1817
@dataclass(repr=True)
1918
class InputData(DataClassSerializable):
2019
"""Class representing operator specification input data details."""
2120

2221
connect_args: Dict = None
22+
data: Dict = None
2323
format: str = None
2424
columns: List[str] = None
2525
url: str = None

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
from abc import ABC
77

8+
import numpy as np
89
import pandas as pd
910

1011
from ads.opctl import logger
@@ -209,18 +210,24 @@ def _outlier_treatment(self, df):
209210
-------
210211
A new Pandas DataFrame with treated outliears.
211212
"""
212-
df["z_score"] = (
213+
return df
214+
df["__z_score"] = (
213215
df[self.target_column_name]
214216
.groupby(DataColumns.Series)
215217
.transform(lambda x: (x - x.mean()) / x.std())
216218
)
217-
outliers_mask = df["z_score"].abs() > 3
219+
outliers_mask = df["__z_score"].abs() > 3
220+
221+
if df[self.target_column_name].dtype == np.int:
222+
df[self.target_column_name].astype(np.float)
223+
218224
df.loc[outliers_mask, self.target_column_name] = (
219225
df[self.target_column_name]
220226
.groupby(DataColumns.Series)
221-
.transform(lambda x: x.mean())
227+
.transform(lambda x: np.median(x))
222228
)
223-
return df.drop("z_score", axis=1)
229+
df_ret = df.drop("__z_score", axis=1)
230+
return df_ret
224231

225232
def _check_historical_dataset(self, df):
226233
expected_names = [self.target_column_name, self.dt_column_name] + (

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import logging
@@ -40,6 +40,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
4040
if data_spec is None:
4141
raise InvalidParameterError("No details provided for this data source.")
4242
filename = data_spec.url
43+
data = data_spec.data
4344
format = data_spec.format
4445
columns = data_spec.columns
4546
connect_args = data_spec.connect_args
@@ -51,9 +52,12 @@ def load_data(data_spec, storage_options=None, **kwargs):
5152
default_signer() if ObjectStorageDetails.is_oci_path(filename) else {}
5253
)
5354
if vault_secret_id is not None and connect_args is None:
54-
connect_args = dict()
55+
connect_args = {}
5556

56-
if filename is not None:
57+
if data is not None:
58+
if format == "spark":
59+
data = data.toPandas()
60+
elif filename is not None:
5761
if not format:
5862
_, format = os.path.splitext(filename)
5963
format = format[1:]
@@ -98,7 +102,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
98102
except Exception as e:
99103
raise Exception(
100104
f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
101-
)
105+
) from e
102106

103107
con = oracledb.connect(**connect_args)
104108
if table_name is not None:
@@ -122,6 +126,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
122126

123127

124128
def write_data(data, filename, format, storage_options, index=False, **kwargs):
129+
disable_print()
125130
if not format:
126131
_, format = os.path.splitext(filename)
127132
format = format[1:]
@@ -130,7 +135,8 @@ def write_data(data, filename, format, storage_options, index=False, **kwargs):
130135
return call_pandas_fsspec(
131136
write_fn, filename, index=index, storage_options=storage_options, **kwargs
132137
)
133-
raise OperatorYamlContentError(
138+
enable_print()
139+
raise InvalidParameterError(
134140
f"The format {format} is not currently supported for writing data. Please change the format parameter for the data output: {filename} ."
135141
)
136142

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,22 +82,6 @@ def _build_model(self) -> pd.DataFrame:
8282

8383
from automlx import Pipeline, init
8484

85-
cpu_count = os.cpu_count()
86-
try:
87-
if cpu_count < 4:
88-
engine = "local"
89-
engine_opts = None
90-
else:
91-
engine = "ray"
92-
engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
93-
init(
94-
engine=engine,
95-
engine_opts=engine_opts,
96-
loglevel=logging.CRITICAL,
97-
)
98-
except Exception as e:
99-
logger.info(f"Error. Has Ray already been initialized? Skipping. {e}")
100-
10185
full_data_dict = self.datasets.get_data_by_series()
10286

10387
self.models = {}
@@ -113,6 +97,26 @@ def _build_model(self) -> pd.DataFrame:
11397
# Clean up kwargs for pass through
11498
model_kwargs_cleaned, time_budget = self.set_kwargs()
11599

100+
cpu_count = os.cpu_count()
101+
try:
102+
engine_type = model_kwargs_cleaned.pop(
103+
"engine", "local" if cpu_count <= 4 else "ray"
104+
)
105+
engine_opts = (
106+
None
107+
if engine_type == "local"
108+
else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
109+
)
110+
init(
111+
engine=engine_type,
112+
engine_opts=engine_opts,
113+
loglevel=logging.CRITICAL,
114+
)
115+
except Exception as e:
116+
logger.info(
117+
f"Error initializing automlx. Has Ray already been initialized? Skipping. {e}"
118+
)
119+
116120
for s_id, df in full_data_dict.items():
117121
try:
118122
logger.debug(f"Running automlx on series {s_id}")

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
from ..const import (
4646
AUTO_SELECT,
47+
BACKTEST_REPORT_NAME,
4748
SUMMARY_METRICS_HORIZON_LIMIT,
4849
SpeedAccuracyMode,
4950
SupportedMetrics,
@@ -321,7 +322,14 @@ def generate_report(self):
321322
forecast_plots = [forecast_text, forecast_sec]
322323

323324
yaml_appendix_title = rc.Heading("Reference: YAML File", level=2)
324-
yaml_appendix = rc.Yaml(self.config.to_dict())
325+
config_dict = self.config.to_dict()
326+
# pop the data incase it isn't json serializable
327+
config_dict["spec"]["historical_data"].pop("data")
328+
if config_dict["spec"].get("additional_data"):
329+
config_dict["spec"]["additional_data"].pop("data")
330+
if config_dict["spec"].get("test_data"):
331+
config_dict["spec"]["test_data"].pop("data")
332+
yaml_appendix = rc.Yaml(config_dict)
325333
report_sections = (
326334
[summary]
327335
+ backtest_sections

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,7 @@ def _generate_report(self):
369369
logger.debug(f"Full Traceback: {traceback.format_exc()}")
370370

371371
model_description = rc.Text(
372-
"Prophet is a procedure for forecasting time series data based on an additive "
373-
"model where non-linear trends are fit with yearly, weekly, and daily seasonality, "
374-
"plus holiday effects. It works best with time series that have strong seasonal "
375-
"effects and several seasons of historical data. Prophet is robust to missing "
376-
"data and shifts in the trend, and typically handles outliers well."
372+
"""Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends are fit with yearly, weekly, and daily seasonality, plus holiday effects. It works best with time series that have strong seasonal effects and several seasons of historical data. Prophet is robust to missing data and shifts in the trend, and typically handles outliers well."""
377373
)
378374
other_sections = all_sections
379375

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ spec:
3737
nullable: true
3838
required: false
3939
type: dict
40+
data:
41+
nullable: true
42+
required: false
4043
format:
4144
allowed:
4245
- csv
@@ -48,6 +51,7 @@ spec:
4851
- sql_query
4952
- hdf
5053
- tsv
54+
- pandas
5155
required: false
5256
type: string
5357
columns:
@@ -92,6 +96,9 @@ spec:
9296
nullable: true
9397
required: false
9498
type: dict
99+
data:
100+
nullable: true
101+
required: false
95102
format:
96103
allowed:
97104
- csv
@@ -103,6 +110,7 @@ spec:
103110
- sql_query
104111
- hdf
105112
- tsv
113+
- pandas
106114
required: false
107115
type: string
108116
columns:
@@ -146,6 +154,9 @@ spec:
146154
nullable: true
147155
required: false
148156
type: dict
157+
data:
158+
nullable: true
159+
required: false
149160
format:
150161
allowed:
151162
- csv
@@ -157,6 +168,7 @@ spec:
157168
- sql_query
158169
- hdf
159170
- tsv
171+
- pandas
160172
required: false
161173
type: string
162174
columns:

pyproject.toml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,27 +157,26 @@ forecast = [
157157
"oci-cli",
158158
"py-cpuinfo",
159159
"rich",
160-
"autots[additional]",
160+
"autots",
161161
"mlforecast",
162162
"neuralprophet>=0.7.0",
163163
"numpy<2.0.0",
164164
"oci-cli",
165165
"optuna",
166-
"oracle-ads",
167166
"pmdarima",
168167
"prophet",
169168
"shap",
170169
"sktime",
171170
"statsmodels",
172171
"plotly",
173172
"oracledb",
174-
"report-creator==1.0.28",
173+
"report-creator==1.0.32",
175174
]
176175
anomaly = [
177176
"oracle_ads[opctl]",
178177
"autots",
179178
"oracledb",
180-
"report-creator==1.0.28",
179+
"report-creator==1.0.32",
181180
"rrcf==0.4.4",
182181
"scikit-learn<1.6.0",
183182
"salesforce-merlion[all]==2.0.4"
@@ -186,7 +185,7 @@ recommender = [
186185
"oracle_ads[opctl]",
187186
"scikit-surprise",
188187
"plotly",
189-
"report-creator==1.0.28",
188+
"report-creator==1.0.32",
190189
]
191190
feature-store-marketplace = [
192191
"oracle-ads[opctl]",
@@ -202,7 +201,7 @@ pii = [
202201
"scrubadub_spacy",
203202
"spacy-transformers==1.2.5",
204203
"spacy==3.6.1",
205-
"report-creator==1.0.28",
204+
"report-creator==1.0.32",
206205
]
207206
llm = ["langchain>=0.2", "langchain-community", "langchain_openai", "pydantic>=2,<3", "evaluate>=0.4.0"]
208207
aqua = ["jupyter_server"]

0 commit comments

Comments
 (0)