Skip to content

Commit 737c713

Browse files
authored
Optimize auto-select data loading process and update dataset classes (#1110)
2 parents 35c03c6 + ed67206 commit 737c713

File tree

5 files changed: 43 additions and 38 deletions (+43 / -38 lines changed)

ads/opctl/operator/lowcode/common/data.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,16 @@
1919

2020

2121
class AbstractData(ABC):
22-
def __init__(self, spec: dict, name="input_data"):
22+
def __init__(self, spec, name="input_data", data=None):
2323
self.Transformations = Transformations
2424
self.data = None
2525
self._data_dict = dict()
2626
self.name = name
2727
self.spec = spec
28-
self.load_transform_ingest_data(spec)
28+
if data is not None:
29+
self.data = data
30+
else:
31+
self.load_transform_ingest_data(spec)
2932

3033
def get_raw_data_by_cat(self, category):
3134
mapping = self._data_transformer.get_target_category_columns_map()

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@ def __init__(self, dataset_info, name="historical_data"):
3131
dataset_info : ForecastOperatorConfig
3232
"""
3333
self.name = name
34-
self.has_artificial_series = False
3534
self.dataset_info = dataset_info
3635
self.target_category_columns = dataset_info.target_category_columns
3736
self.target_column_name = dataset_info.target_column
@@ -136,7 +135,6 @@ def _set_series_id_column(self, df):
136135
self._target_category_columns_map = {}
137136
if not self.target_category_columns:
138137
df[DataColumns.Series] = "Series 1"
139-
self.has_artificial_series = True
140138
else:
141139
df[DataColumns.Series] = merge_category_columns(
142140
df, self.target_category_columns

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def generate_report(self):
120120

121121
# Generate metrics
122122
summary_metrics = None
123-
test_data = None
123+
test_data = self.datasets.test_data
124124
self.eval_metrics = None
125125

126126
if self.spec.generate_report or self.spec.generate_metrics:
@@ -130,12 +130,11 @@ def generate_report(self):
130130
{"Series 1": self.original_target_column}, axis=1, inplace=True
131131
)
132132

133-
if self.spec.test_data:
133+
if self.datasets.test_data is not None:
134134
try:
135135
(
136136
self.test_eval_metrics,
137-
summary_metrics,
138-
test_data,
137+
summary_metrics
139138
) = self._test_evaluate_metrics(
140139
elapsed_time=elapsed_time,
141140
)
@@ -361,7 +360,7 @@ def generate_report(self):
361360
def _test_evaluate_metrics(self, elapsed_time=0):
362361
total_metrics = pd.DataFrame()
363362
summary_metrics = pd.DataFrame()
364-
data = TestData(self.spec)
363+
data = self.datasets.test_data
365364

366365
# Generate y_pred and y_true for each series
367366
for s_id in self.forecast_output.list_series_ids():
@@ -398,7 +397,7 @@ def _test_evaluate_metrics(self, elapsed_time=0):
398397
total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
399398

400399
if total_metrics.empty:
401-
return total_metrics, summary_metrics, data
400+
return total_metrics, summary_metrics
402401

403402
summary_metrics = pd.DataFrame(
404403
{
@@ -464,7 +463,7 @@ def _test_evaluate_metrics(self, elapsed_time=0):
464463
]
465464
summary_metrics = summary_metrics[new_column_order]
466465

467-
return total_metrics, summary_metrics, data
466+
return total_metrics, summary_metrics
468467

469468
def _save_report(
470469
self,
@@ -548,7 +547,7 @@ def _save_report(
548547
)
549548

550549
# test_metrics csv report
551-
if self.spec.test_data is not None:
550+
if self.datasets.test_data is not None:
552551
if test_metrics_df is not None:
553552
test_metrics_df_formatted = test_metrics_df.reset_index().rename(
554553
{"index": "metrics", "Series 1": metrics_col_name}, axis=1

ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,8 @@
2323

2424

2525
class HistoricalData(AbstractData):
26-
def __init__(self, spec: dict):
27-
super().__init__(spec=spec, name="historical_data")
26+
def __init__(self, spec, historical_data = None):
27+
super().__init__(spec=spec, name="historical_data", data=historical_data)
2828

2929
def _ingest_data(self, spec):
3030
try:
@@ -52,8 +52,11 @@ def _verify_dt_col(self, spec):
5252

5353

5454
class AdditionalData(AbstractData):
55-
def __init__(self, spec, historical_data):
56-
if spec.additional_data is not None:
55+
def __init__(self, spec, historical_data, additional_data=None):
56+
if additional_data is not None:
57+
super().__init__(spec=spec, name="additional_data", data=additional_data)
58+
self.additional_regressors = list(self.data.columns)
59+
elif spec.additional_data is not None:
5760
super().__init__(spec=spec, name="additional_data")
5861
add_dates = self.data.index.get_level_values(0).unique().tolist()
5962
add_dates.sort()
@@ -110,14 +113,15 @@ def _ingest_data(self, spec):
110113

111114

112115
class TestData(AbstractData):
113-
def __init__(self, spec):
114-
super().__init__(spec=spec, name="test_data")
116+
def __init__(self, spec, test_data):
117+
if test_data is not None or spec.test_data is not None:
118+
super().__init__(spec=spec, name="test_data", data=test_data)
115119
self.dt_column_name = spec.datetime_column.name
116120
self.target_name = spec.target_column
117121

118122

119123
class ForecastDatasets:
120-
def __init__(self, config: ForecastOperatorConfig):
124+
def __init__(self, config: ForecastOperatorConfig, historical_data=None, additional_data=None, test_data=None):
121125
"""Instantiates the DataIO instance.
122126
123127
Properties
@@ -127,11 +131,15 @@ def __init__(self, config: ForecastOperatorConfig):
127131
"""
128132
self.historical_data: HistoricalData = None
129133
self.additional_data: AdditionalData = None
130-
131134
self._horizon = config.spec.horizon
132135
self._datetime_column_name = config.spec.datetime_column.name
133136
self._target_col = config.spec.target_column
134-
self._load_data(config.spec)
137+
if historical_data is not None:
138+
self.historical_data = HistoricalData(config.spec, historical_data)
139+
self.additional_data = AdditionalData(config.spec, self.historical_data, additional_data)
140+
else:
141+
self._load_data(config.spec)
142+
self.test_data = TestData(config.spec, test_data)
135143

136144
def _load_data(self, spec):
137145
"""Loads forecasting input data."""
@@ -200,7 +208,7 @@ def get_horizon_at_series(self, s_id):
200208
return self.get_data_at_series(s_id)[-self._horizon :]
201209

202210
def has_artificial_series(self):
203-
return self.historical_data._data_transformer.has_artificial_series
211+
return bool(self.historical_data.spec.target_category_columns)
204212

205213
def get_earliest_timestamp(self):
206214
return self.historical_data.get_min_time()
@@ -251,7 +259,7 @@ def __init__(
251259
target_column: str,
252260
dt_column: str,
253261
):
254-
"""Forecast Output contains all of the details required to generate the forecast.csv output file.
262+
"""Forecast Output contains all the details required to generate the forecast.csv output file.
255263
256264
init
257265
-------

ads/opctl/operator/lowcode/forecast/model_evaluator.py

Lines changed: 12 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -91,20 +91,13 @@ def create_operator_config(self, operator_config, backtest, model, historical_da
9191
output_dir = operator_config.spec.output_directory.url
9292
output_file_path = f'{output_dir}/back_testing/{model}/{backtest}'
9393
Path(output_file_path).mkdir(parents=True, exist_ok=True)
94-
historical_data_url = f'{output_file_path}/historical.csv'
95-
additional_data_url = f'{output_file_path}/additional.csv'
96-
test_data_url = f'{output_file_path}/test.csv'
97-
historical_data.to_csv(historical_data_url, index=False)
98-
additional_data.to_csv(additional_data_url, index=False)
99-
test_data.to_csv(test_data_url, index=False)
10094
backtest_op_config_draft = operator_config.to_dict()
10195
backtest_spec = backtest_op_config_draft["spec"]
102-
backtest_spec["historical_data"]["url"] = historical_data_url
10396
backtest_spec["datetime_column"]["format"] = None
104-
if backtest_spec["additional_data"]:
105-
backtest_spec["additional_data"]["url"] = additional_data_url
106-
backtest_spec["test_data"] = {}
107-
backtest_spec["test_data"]["url"] = test_data_url
97+
backtest_spec.pop("test_data")
98+
backtest_spec.pop("additional_data")
99+
backtest_spec.pop("historical_data")
100+
backtest_spec["generate_report"] = False
108101
backtest_spec["model"] = model
109102
backtest_spec['model_kwargs'] = None
110103
backtest_spec["output_directory"] = {"url": output_file_path}
@@ -119,19 +112,23 @@ def create_operator_config(self, operator_config, backtest, model, historical_da
119112
def run_all_models(self, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig):
120113
cut_offs, train_sets, additional_data, test_sets = self.generate_k_fold_data(datasets, operator_config)
121114
metrics = {}
115+
date_col = operator_config.spec.datetime_column.name
122116
for model in self.models:
123117
from .model.factory import ForecastOperatorModelFactory
124118
metrics[model] = {}
125119
for i in range(len(cut_offs)):
126120
try:
127-
backtest_historical_data = train_sets[i]
128-
backtest_additional_data = additional_data[i]
129-
backtest_test_data = test_sets[i]
121+
backtest_historical_data = train_sets[i].set_index([date_col, DataColumns.Series])
122+
backtest_additional_data = additional_data[i].set_index([date_col, DataColumns.Series])
123+
backtest_test_data = test_sets[i].set_index([date_col, DataColumns.Series])
130124
backtest_operator_config = self.create_operator_config(operator_config, i, model,
131125
backtest_historical_data,
132126
backtest_additional_data,
133127
backtest_test_data)
134-
datasets = ForecastDatasets(backtest_operator_config)
128+
datasets = ForecastDatasets(backtest_operator_config,
129+
backtest_historical_data,
130+
backtest_additional_data,
131+
backtest_test_data)
135132
ForecastOperatorModelFactory.get_model(
136133
backtest_operator_config, datasets
137134
).generate_report()

0 commit comments

Comments (0)