Skip to content

Commit 07f0280

Browse files
committed
auto model evaluator added for forecasting tasks
1 parent e5eceef commit 07f0280

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# -*- coding: utf-8; -*-
2+
3+
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
7+
import numpy as np
8+
from ads.opctl import logger
9+
from ads.opctl.operator.lowcode.common.utils import (
10+
find_output_dirname,
11+
)
12+
from .const import ForecastOutputColumns
13+
from .model.forecast_datasets import ForecastDatasets
14+
from .operator_config import ForecastOperatorConfig
15+
16+
17+
class ModelEvaluator:
    """Back-test a list of forecasting models on rolling-origin folds.

    The evaluator carves up to ``k`` train/test folds out of the historical
    data (each test fold is one ``horizon`` long, walking backwards from the
    most recent date) and runs every model in ``models`` on every fold.

    Args:
        models: Iterable of model identifiers; each one is written into the
            ``spec.model`` field of the per-fold operator config.
        k: Maximum number of back-test folds to create.
        subsample_ratio: Fraction of the series (by count of distinct series)
            kept for back-testing. At least 5 series are always requested.
    """

    def __init__(self, models, k=5, subsample_ratio=0.20):
        self.models = models
        self.k = k
        self.subsample_ratio = subsample_ratio

    def generate_k_fold_data(self, datasets: "ForecastDatasets", date_col: str, horizon: int):
        """Split the historical data into rolling train/test folds.

        Args:
            datasets: Object exposing ``historical_data.data`` as a DataFrame.
            date_col: Name of the datetime column in the historical data.
            horizon: Number of periods in each test fold.

        Returns:
            A ``(cut_offs, training_datasets, test_datasets)`` tuple. The
            three sequences have the same length and are ordered from the
            most recent cut-off to the oldest. ``training_datasets[i]`` holds
            the rows dated on or before ``cut_offs[i]``; ``test_datasets[i]``
            holds the rows after ``cut_offs[i]`` up to (and including) the
            previous, more recent cut-off. The first test fold has no upper
            bound. When no fold is feasible, both lists are empty.
        """
        historical_data = datasets.historical_data.data.reset_index()
        series_col = ForecastOutputColumns.SERIES
        group_counts = historical_data[series_col].value_counts()

        # value_counts() orders by descending count, so head() keeps the
        # series with the most rows.
        sample_count = max(5, int(len(group_counts) * self.subsample_ratio))
        sampled_groups = group_counts.head(sample_count)
        sampled_historical_data = historical_data[
            historical_data[series_col].isin(sampled_groups.index)
        ]

        # The shortest series dictates how many folds can be cut.
        min_group = group_counts.idxmin()
        min_series_data = historical_data[historical_data[series_col] == min_group]
        sorted_dates = np.sort(min_series_data[date_col].unique())

        # Fold i trains on everything except the last (i + 1) horizons; keep
        # only folds whose training window is at least three horizons long.
        train_window_size = [
            len(sorted_dates) - (i + 1) * horizon for i in range(self.k)
        ]
        valid_train_window_size = [
            ws for ws in train_window_size if ws >= horizon * 3
        ]
        if len(valid_train_window_size) < self.k:
            logger.warning(
                "Only %d backtests can be created", len(valid_train_window_size)
            )

        # Walk backwards from the last date in steps of `horizon`.
        cut_offs = sorted_dates[
            -horizon - 1:-horizon * (self.k + 1):-horizon
        ][:len(valid_train_window_size)]
        if len(cut_offs) == 0:
            # Not enough history for even a single fold.
            return cut_offs, [], []

        dates = sampled_historical_data[date_col]
        training_datasets = [
            sampled_historical_data[dates <= cut_off_date]
            for cut_off_date in cut_offs
        ]

        test_datasets = []
        upper_bound = None  # the first (most recent) fold is open-ended
        for cut_off_date in cut_offs:
            in_fold = dates > cut_off_date
            if upper_bound is not None:
                in_fold = in_fold & (dates <= upper_bound)
            test_datasets.append(sampled_historical_data[in_fold])
            upper_bound = cut_off_date
        return cut_offs, training_datasets, test_datasets

    def remove_none_values(self, obj):
        """Return ``obj`` with ``None`` keys and ``None`` values removed.

        Only dictionaries are cleaned (recursively). Any other value,
        including a list, is returned unchanged.
        """
        if isinstance(obj, dict):
            return {
                k: self.remove_none_values(v)
                for k, v in obj.items()
                if k is not None and v is not None
            }
        return obj

    @staticmethod
    def _section_with_url(section, url):
        """Return a copy of a config section (or a new one) with ``url`` set."""
        updated = dict(section) if isinstance(section, dict) else {}
        updated["url"] = url
        return updated

    def run_all_models(self, datasets: "ForecastDatasets", operator_config: "ForecastOperatorConfig"):
        """Run every configured model on every back-test fold.

        For each fold the train and test frames are written as CSV files
        under ``<output_dir>/back_test/<fold index>/``. A per-fold operator
        config is derived from ``operator_config`` with the data URLs and the
        model overridden, and the resulting model's ``generate_report()`` is
        invoked.

        Args:
            datasets: Historical datasets used to build the folds.
            operator_config: Base operator configuration to derive from.
        """
        import os
        from pathlib import Path

        # Imported lazily, as in the original code.
        from .model.factory import ForecastOperatorModelFactory

        date_col = operator_config.spec.datetime_column.name
        horizon = operator_config.spec.horizon
        cut_offs, train_sets, test_sets = self.generate_k_fold_data(
            datasets, date_col, horizon
        )

        # Resolved once so every fold shares the same root directory.
        # NOTE(review): assumes find_output_dirname returns a str path and
        # that calling it once per run is sufficient - confirm.
        output_dir = find_output_dirname(operator_config.spec.output_directory)

        for model in self.models:
            for i, (backtest_historical_data, backtest_test_data) in enumerate(
                zip(train_sets, test_sets)
            ):
                # os.path.join supplies the separator the original f-string
                # omitted between the output directory and "back_test".
                output_file_path = os.path.join(output_dir, "back_test", str(i))
                Path(output_file_path).mkdir(parents=True, exist_ok=True)
                historical_data_url = os.path.join(output_file_path, "historical.csv")
                test_data_url = os.path.join(output_file_path, "test.csv")
                backtest_historical_data.to_csv(historical_data_url, index=False)
                backtest_test_data.to_csv(test_data_url, index=False)

                backtest_op_config_draft = operator_config.to_dict()
                backtest_spec = backtest_op_config_draft["spec"]
                # Optional sections may be absent or None in the base config,
                # so build them instead of indexing into them.
                backtest_spec["historical_data"] = self._section_with_url(
                    backtest_spec.get("historical_data"), historical_data_url
                )
                backtest_spec["test_data"] = self._section_with_url(
                    backtest_spec.get("test_data"), test_data_url
                )
                backtest_spec["model"] = model
                # NOTE(review): every fold/model reports into the same
                # directory, so later runs may overwrite earlier output.
                backtest_spec["output_directory"] = self._section_with_url(
                    backtest_spec.get("output_directory"), output_dir
                )

                cleaned_config = self.remove_none_values(backtest_op_config_draft)
                backtest_op_config = ForecastOperatorConfig.from_dict(
                    obj_dict=cleaned_config
                )
                backtest_datasets = ForecastDatasets(backtest_op_config)

                # Use the per-fold config; the original passed the base
                # operator_config, silently ignoring the model override.
                ForecastOperatorModelFactory.get_model(
                    backtest_op_config, backtest_datasets
                ).generate_report()

0 commit comments

Comments
 (0)