|
| 1 | +# -*- coding: utf-8; -*- |
| 2 | + |
| 3 | +# Copyright (c) 2023 Oracle and/or its affiliates. |
| 4 | +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
| 5 | + |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from ads.opctl import logger |
| 9 | +from ads.opctl.operator.lowcode.common.utils import ( |
| 10 | + find_output_dirname, |
| 11 | +) |
| 12 | +from .const import ForecastOutputColumns |
| 13 | +from .model.forecast_datasets import ForecastDatasets |
| 14 | +from .operator_config import ForecastOperatorConfig |
| 15 | + |
| 16 | + |
| 17 | +class ModelEvaluator: |
| 18 | + def __init__(self, models, k=5, subsample_ratio=0.20): |
| 19 | + self.models = models |
| 20 | + self.k = k |
| 21 | + self.subsample_ratio = subsample_ratio |
| 22 | + |
| 23 | + def generate_k_fold_data(self, datasets: ForecastDatasets, date_col: str, horizon: int): |
| 24 | + historical_data = datasets.historical_data.data.reset_index() |
| 25 | + series_col = ForecastOutputColumns.SERIES |
| 26 | + group_counts = historical_data[series_col].value_counts() |
| 27 | + |
| 28 | + sample_count = max(5, int(len(group_counts) * self.subsample_ratio)) |
| 29 | + sampled_groups = group_counts.head(sample_count) |
| 30 | + sampled_historical_data = historical_data[historical_data[series_col].isin(sampled_groups.index)] |
| 31 | + |
| 32 | + min_group = group_counts.idxmin() |
| 33 | + min_series_data = historical_data[historical_data[series_col] == min_group] |
| 34 | + unique_dates = min_series_data[date_col].unique() |
| 35 | + |
| 36 | + sorted_dates = np.sort(unique_dates) |
| 37 | + train_window_size = [len(sorted_dates) - (i + 1) * horizon for i in range(self.k)] |
| 38 | + valid_train_window_size = [ws for ws in train_window_size if ws >= horizon * 3] |
| 39 | + if len(valid_train_window_size) < self.k: |
| 40 | + logger.warn(f"Only ${valid_train_window_size} backtests can be created") |
| 41 | + |
| 42 | + cut_offs = sorted_dates[-horizon - 1:-horizon * (self.k + 1):-horizon][:len(valid_train_window_size)] |
| 43 | + training_datasets = [sampled_historical_data[sampled_historical_data[date_col] <= cut_off_date] for cut_off_date |
| 44 | + in cut_offs] |
| 45 | + test_datasets = [sampled_historical_data[sampled_historical_data[date_col] > cut_offs[0]]] |
| 46 | + for i, current in enumerate(cut_offs[1:]): |
| 47 | + test_datasets.append(sampled_historical_data[(current < sampled_historical_data[date_col]) & ( |
| 48 | + sampled_historical_data[date_col] <= cut_offs[i])]) |
| 49 | + return cut_offs, training_datasets, test_datasets |
| 50 | + |
| 51 | + def remove_none_values(self, obj): |
| 52 | + if isinstance(obj, dict): |
| 53 | + return {k: self.remove_none_values(v) for k, v in obj.items() if k is not None and v is not None} |
| 54 | + else: |
| 55 | + return obj |
| 56 | + |
| 57 | + def run_all_models(self, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig): |
| 58 | + date_col = operator_config.spec.datetime_column.name |
| 59 | + horizon = operator_config.spec.horizon |
| 60 | + cut_offs, train_sets, test_sets = self.generate_k_fold_data(datasets, date_col, horizon) |
| 61 | + |
| 62 | + for model in self.models: |
| 63 | + from .model.factory import ForecastOperatorModelFactory |
| 64 | + for i in range(len(cut_offs)): |
| 65 | + backtest_historical_data = train_sets[i] |
| 66 | + backtest_test_data = test_sets[i] |
| 67 | + output_dir = find_output_dirname(operator_config.spec.output_directory) |
| 68 | + output_file_path = f'{output_dir}back_test/{i}' |
| 69 | + from pathlib import Path |
| 70 | + Path(output_file_path).mkdir(parents=True, exist_ok=True) |
| 71 | + historical_data_url = f'{output_file_path}/historical.csv' |
| 72 | + test_data_url = f'{output_file_path}/test.csv' |
| 73 | + backtest_historical_data.to_csv(historical_data_url, index=False) |
| 74 | + backtest_test_data.to_csv(test_data_url, index=False) |
| 75 | + backtest_op_config_draft = operator_config.to_dict() |
| 76 | + backtest_spec = backtest_op_config_draft["spec"] |
| 77 | + backtest_spec["historical_data"]["url"] = historical_data_url |
| 78 | + backtest_spec["test_data"]["url"] = test_data_url |
| 79 | + backtest_spec["model"] = model |
| 80 | + backtest_spec["output_directory"]["url"] = output_dir |
| 81 | + cleaned_config = self.remove_none_values(backtest_op_config_draft) |
| 82 | + backtest_op_cofig = ForecastOperatorConfig.from_dict( |
| 83 | + obj_dict=cleaned_config) |
| 84 | + datasets = ForecastDatasets(backtest_op_cofig) |
| 85 | + |
| 86 | + ForecastOperatorModelFactory.get_model( |
| 87 | + operator_config, datasets |
| 88 | + ).generate_report() |
0 commit comments