Commit 320aeeb

add data attribute and unit testing

1 parent 5cf2b6d

File tree

6 files changed: +170 −26 lines changed

ads/opctl/config/merger.py

Lines changed: 13 additions & 14 deletions
@@ -1,35 +1,33 @@
 #!/usr/bin/env python
-# -*- coding: utf-8; -*-
 
-# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
+# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 import os
 from string import Template
 from typing import Dict
-import json
 
 import yaml
 
 from ads.common.auth import AuthType, ResourcePrincipal
 from ads.opctl import logger
 from ads.opctl.config.base import ConfigProcessor
-from ads.opctl.config.utils import read_from_ini, _DefaultNoneDict
-from ads.opctl.utils import is_in_notebook_session, get_service_pack_prefix
+from ads.opctl.config.utils import _DefaultNoneDict, read_from_ini
 from ads.opctl.constants import (
-    DEFAULT_PROFILE,
-    DEFAULT_OCI_CONFIG_FILE,
-    DEFAULT_CONDA_PACK_FOLDER,
-    DEFAULT_ADS_CONFIG_FOLDER,
-    ADS_JOBS_CONFIG_FILE_NAME,
     ADS_CONFIG_FILE_NAME,
-    ADS_ML_PIPELINE_CONFIG_FILE_NAME,
     ADS_DATAFLOW_CONFIG_FILE_NAME,
+    ADS_JOBS_CONFIG_FILE_NAME,
     ADS_LOCAL_BACKEND_CONFIG_FILE_NAME,
+    ADS_ML_PIPELINE_CONFIG_FILE_NAME,
     ADS_MODEL_DEPLOYMENT_CONFIG_FILE_NAME,
-    DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
     BACKEND_NAME,
+    DEFAULT_ADS_CONFIG_FOLDER,
+    DEFAULT_CONDA_PACK_FOLDER,
+    DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
+    DEFAULT_OCI_CONFIG_FILE,
+    DEFAULT_PROFILE,
 )
+from ads.opctl.utils import get_service_pack_prefix, is_in_notebook_session
 
 
 class ConfigMerger(ConfigProcessor):
@@ -41,8 +39,9 @@ class ConfigMerger(ConfigProcessor):
     """
 
     def process(self, **kwargs) -> None:
-        config_string = Template(json.dumps(self.config)).safe_substitute(os.environ)
-        self.config = json.loads(config_string)
+        for key, value in self.config.items():
+            if isinstance(value, str):  # Substitute only if the value is a string
+                self.config[key] = Template(value).safe_substitute(os.environ)
 
         if "runtime" not in self.config:
             self.config["runtime"] = {}

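The rewritten process() substitutes environment variables value by value instead of round-tripping the whole config through json.dumps/json.loads, which would fail once the config carries a non-serializable object such as a pandas DataFrame. A minimal standalone sketch of the new behavior (not the actual ConfigMerger; note it only walks top-level keys, so nested strings are no longer substituted):

import os
from string import Template

os.environ["OCI_PROFILE"] = "DEFAULT"
config = {"profile": "$OCI_PROFILE", "horizon": 10, "nested": {"x": "$OCI_PROFILE"}}

for key, value in config.items():
    if isinstance(value, str):  # substitute only string values
        config[key] = Template(value).safe_substitute(os.environ)

print(config["profile"])      # "DEFAULT"
print(config["nested"]["x"])  # still "$OCI_PROFILE": nested values are skipped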
ads/opctl/operator/common/operator_config.py

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8; -*-
 
-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
 
@@ -11,15 +10,16 @@
 from typing import Any, Dict, List
 
 from ads.common.serializer import DataClassSerializable
-
-from ads.opctl.operator.common.utils import OperatorValidator
 from ads.opctl.operator.common.errors import InvalidParameterError
+from ads.opctl.operator.common.utils import OperatorValidator
+
 
 @dataclass(repr=True)
 class InputData(DataClassSerializable):
     """Class representing operator specification input data details."""
 
     connect_args: Dict = None
+    data: Dict = None
     format: str = None
     columns: List[str] = None
     url: str = None

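With the new data field, an operator spec can carry an in-memory object rather than a URL. A hedged usage sketch (the field is annotated Dict, but the unit tests below assign a pandas DataFrame directly):

import pandas as pd

from ads.opctl.operator.common.operator_config import InputData

df = pd.DataFrame({"Date": ["2024-01-01"], "Sales": [100]})

by_url = InputData(url="oci://bucket@namespace/data.csv", format="csv")  # as before
by_data = InputData(data=df, format="pandas")  # new in this commit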
ads/opctl/operator/lowcode/common/utils.py

Lines changed: 5 additions & 1 deletion
@@ -40,6 +40,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
     if data_spec is None:
         raise InvalidParameterError("No details provided for this data source.")
     filename = data_spec.url
+    data = data_spec.data
     format = data_spec.format
     columns = data_spec.columns
     connect_args = data_spec.connect_args
@@ -53,7 +54,10 @@ def load_data(data_spec, storage_options=None, **kwargs):
     if vault_secret_id is not None and connect_args is None:
         connect_args = dict()
 
-    if filename is not None:
+    if data is not None:
+        if format == "spark":
+            data = data.toPandas()
+    elif filename is not None:
         if not format:
             _, format = os.path.splitext(filename)
             format = format[1:]

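The new branch gives in-memory data precedence over the url path and eagerly converts a Spark DataFrame to pandas, so downstream code only ever sees pandas. The dispatch order, reduced to a sketch (resolve_source is a hypothetical stand-in for the relevant part of load_data):

def resolve_source(data, filename, format):
    if data is not None:
        if format == "spark":
            # pyspark's DataFrame.toPandas() collects the frame to the driver
            data = data.toPandas()
        return data  # a "pandas" source falls through unchanged
    elif filename is not None:
        ...  # existing url-based loaders (csv, tsv, sql_query, hdf, ...)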
ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 26 additions & 7 deletions
@@ -43,11 +43,11 @@
 
 from ..const import (
     AUTO_SELECT,
+    BACKTEST_REPORT_NAME,
     SUMMARY_METRICS_HORIZON_LIMIT,
     SpeedAccuracyMode,
     SupportedMetrics,
     SupportedModels,
-    BACKTEST_REPORT_NAME
 )
 from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
 from .forecast_datasets import ForecastDatasets
@@ -259,7 +259,11 @@ def generate_report(self):
         output_dir = self.spec.output_directory.url
         file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
         if self.spec.model == AUTO_SELECT:
-            backtest_sections.append(rc.Heading("Auto-Select Backtesting and Performance Metrics", level=2))
+            backtest_sections.append(
+                rc.Heading(
+                    "Auto-Select Backtesting and Performance Metrics", level=2
+                )
+            )
             if not os.path.exists(file_path):
                 failure_msg = rc.Text(
                     "auto-select could not be executed. Please check the "
@@ -268,15 +272,23 @@ def generate_report(self):
                 backtest_sections.append(failure_msg)
             else:
                 backtest_stats = pd.read_csv(file_path)
-                model_metric_map = backtest_stats.drop(columns=['metric', 'backtest'])
-                average_dict = {k: round(v, 4) for k, v in model_metric_map.mean().to_dict().items()}
+                model_metric_map = backtest_stats.drop(
+                    columns=["metric", "backtest"]
+                )
+                average_dict = {
+                    k: round(v, 4)
+                    for k, v in model_metric_map.mean().to_dict().items()
+                }
                 best_model = min(average_dict, key=average_dict.get)
                 summary_text = rc.Text(
                     f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
-                    f" {best_model} being identified as the top-performing model during backtesting.")
+                    f" {best_model} being identified as the top-performing model during backtesting."
+                )
                 backtest_table = rc.DataTable(backtest_stats, index=True)
                 liner_plot = get_auto_select_plot(backtest_stats)
-                backtest_sections.extend([backtest_table, summary_text, liner_plot])
+                backtest_sections.extend(
+                    [backtest_table, summary_text, liner_plot]
+                )
 
         forecast_plots = []
         if len(self.forecast_output.list_series_ids()) > 0:
@@ -301,7 +313,14 @@ def generate_report(self):
             forecast_plots = [forecast_text, forecast_sec]
 
         yaml_appendix_title = rc.Heading("Reference: YAML File", level=2)
-        yaml_appendix = rc.Yaml(self.config.to_dict())
+        config_dict = self.config.to_dict()
+        # pop the data incase it isn't json serializable
+        config_dict["spec"]["historical_data"].pop("data")
+        if config_dict["spec"].get("additional_data"):
+            config_dict["spec"]["additional_data"].pop("data")
+        if config_dict["spec"].get("test_data"):
+            config_dict["spec"]["test_data"].pop("data")
+        yaml_appendix = rc.Yaml(config_dict)
         report_sections = (
             [summary]
             + backtest_sections

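The report appendix renders the config back as YAML, and a live DataFrame stored under spec.*.data would break that serialization, so the data keys are stripped first. The same defensive pattern in isolation (a sketch, not the operator code; pop with a default avoids a KeyError when the key is absent):

import pandas as pd
import yaml

config_dict = {
    "spec": {
        "historical_data": {"format": "pandas", "data": pd.DataFrame({"a": [1]})},
        "additional_data": None,
    }
}

# drop in-memory payloads before serializing
config_dict["spec"]["historical_data"].pop("data", None)
for section in ("additional_data", "test_data"):
    if config_dict["spec"].get(section):
        config_dict["spec"][section].pop("data", None)

print(yaml.safe_dump(config_dict))  # serializes cleanly once "data" is gone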
ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 12 additions & 0 deletions
@@ -37,6 +37,9 @@ spec:
         nullable: true
         required: false
         type: dict
+      data:
+        nullable: true
+        required: false
       format:
         allowed:
           - csv
@@ -48,6 +51,7 @@ spec:
           - sql_query
           - hdf
           - tsv
+          - pandas
         required: false
         type: string
       columns:
@@ -92,6 +96,9 @@ spec:
         nullable: true
         required: false
         type: dict
+      data:
+        nullable: true
+        required: false
       format:
         allowed:
           - csv
@@ -103,6 +110,7 @@ spec:
           - sql_query
           - hdf
           - tsv
+          - pandas
         required: false
         type: string
       columns:
@@ -146,6 +154,9 @@ spec:
         nullable: true
         required: false
         type: dict
+      data:
+        nullable: true
+        required: false
       format:
         allowed:
           - csv
@@ -157,6 +168,7 @@ spec:
           - sql_query
           - hdf
           - tsv
+          - pandas
         required: false
         type: string
       columns:

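Each of the three data sections (historical_data, additional_data, test_data) gains an optional data key and a pandas entry in the allowed formats, so a spec can now pass validation without a url. In dict form (mirroring the unit tests below; the file name is illustrative):

import pandas as pd

historical_data = {
    "data": pd.read_csv("rossman_sample.csv"),  # hypothetical local file
    "format": "pandas",
    # "url" is no longer needed when "data" is supplied
}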
tests/operators/forecast/test_errors.py

Lines changed: 110 additions & 0 deletions
@@ -708,6 +708,116 @@ def test_smape_error():
     assert result == 0
 
 
+@pytest.mark.parametrize("model", MODELS)
+def test_pandas_historical_input(operator_setup, model):
+    from ads.opctl.operator.lowcode.forecast.__main__ import operate
+    from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import (
+        ForecastDatasets,
+    )
+    from ads.opctl.operator.lowcode.forecast.operator_config import (
+        ForecastOperatorConfig,
+    )
+
+    tmpdirname = operator_setup
+    historical_data_path, additional_data_path = setup_small_rossman()
+    yaml_i, output_data_path = populate_yaml(
+        tmpdirname=tmpdirname,
+        historical_data_path=historical_data_path,
+        additional_data_path=additional_data_path,
+    )
+    yaml_i["spec"]["horizon"] = 10
+    yaml_i["spec"]["model"] = model
+    df = pd.read_csv(historical_data_path)
+    yaml_i["spec"]["historical_data"].pop("url")
+    yaml_i["spec"]["historical_data"]["data"] = df
+    yaml_i["spec"]["historical_data"]["format"] = "pandas"
+
+    operator_config = ForecastOperatorConfig.from_dict(yaml_i)
+    operate(operator_config)
+    assert pd.read_csv(additional_data_path)["Date"].equals(
+        pd.read_csv(f"{tmpdirname}/results/forecast.csv")["Date"]
+    )
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_pandas_additional_input(operator_setup, model):
+    from ads.opctl.operator.lowcode.forecast.__main__ import operate
+    from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import (
+        ForecastDatasets,
+    )
+    from ads.opctl.operator.lowcode.forecast.operator_config import (
+        ForecastOperatorConfig,
+    )
+
+    tmpdirname = operator_setup
+    historical_data_path, additional_data_path = setup_small_rossman()
+    yaml_i, output_data_path = populate_yaml(
+        tmpdirname=tmpdirname,
+        historical_data_path=historical_data_path,
+        additional_data_path=additional_data_path,
+    )
+    yaml_i["spec"]["horizon"] = 10
+    yaml_i["spec"]["model"] = model
+    df = pd.read_csv(historical_data_path)
+    yaml_i["spec"]["historical_data"].pop("url")
+    yaml_i["spec"]["historical_data"]["data"] = df
+    yaml_i["spec"]["historical_data"]["format"] = "pandas"
+
+    df_add = pd.read_csv(additional_data_path)
+    yaml_i["spec"]["additional_data"].pop("url")
+    yaml_i["spec"]["additional_data"]["data"] = df_add
+    yaml_i["spec"]["additional_data"]["format"] = "pandas"
+
+    operator_config = ForecastOperatorConfig.from_dict(yaml_i)
+    operate(operator_config)
+    assert pd.read_csv(additional_data_path)["Date"].equals(
+        pd.read_csv(f"{tmpdirname}/results/forecast.csv")["Date"]
+    )
+
+
+# def test_spark_additional_input(operator_setup):
+#     from ads.opctl.operator.lowcode.forecast.__main__ import operate
+#     from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import ForecastDatasets
+#     from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorConfig
+#     from pyspark.sql import SparkSession
+#     from pyspark import SparkContext
+
+#     spark = SparkSession.builder.getOrCreate()
+
+#     tmpdirname = operator_setup
+#     historical_data_path, additional_data_path = setup_small_rossman()
+#     yaml_i, output_data_path = populate_yaml(
+#         tmpdirname=tmpdirname,
+#         historical_data_path=historical_data_path,
+#         additional_data_path=additional_data_path,
+#     )
+#     yaml_i["spec"]["horizon"] = 10
+#     yaml_i["spec"]["model"] = "prophet"

+#     df = pd.read_csv(historical_data_path)
+#     spark_df = spark.createDataFrame(df)

+#     def _run_operator(df):
+#         yaml_i["spec"]["historical_data"].pop("url")
+#         yaml_i["spec"]["historical_data"]["data"] = spark_df
+#         yaml_i["spec"]["historical_data"]["format"] = "spark"
+#         operator_config = ForecastOperatorConfig.from_dict(yaml_i)
+#         operate(operator_config)

+#     # df_add = pd.read_csv(additional_data_path)
+#     # spark_df_add = spark.createDataFrame(df_add)
+#     # yaml_i["spec"]["additional_data"].pop("url")
+#     # yaml_i["spec"]["additional_data"]["data"] = spark_df_add
+#     # yaml_i["spec"]["additional_data"]["format"] = "spark"

+#     rdd_processed = spark_df.rdd.map(lambda x: _run_operator(x, broadcast_yaml_i))
+#     print(rdd_processed.collect())

+#     assert pd.read_csv(additional_data_path)["Date"].equals(
+#         pd.read_csv(f"{tmpdirname}/results/forecast.csv")["Date"]
+#     )
+
+
 @pytest.mark.parametrize("model", MODELS)
 def test_date_format(operator_setup, model):
     tmpdirname = operator_setup

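The pandas path is covered by the two tests above; the Spark path is only gestured at by the commented-out test. A minimal sketch of what the enabled flow would look like, assuming a local pyspark installation (spec stands in for the tests' yaml_i):

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark_df = spark.createDataFrame(pd.DataFrame({"Date": ["2024-01-01"], "Sales": [100]}))

spec = {"historical_data": {}}
spec["historical_data"]["data"] = spark_df
spec["historical_data"]["format"] = "spark"  # load_data converts this via toPandas()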