Skip to content

Commit 3d8f148

Browse files
authored
Updated Auto-Select Reporting in Forecast Operator (#1013)
2 parents eabb40e + 6ea37e5 commit 3d8f148

File tree

4 files changed

+16
-19
lines changed

4 files changed

+16
-19
lines changed

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,4 @@ class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta):
8787
PROPHET_INTERNAL_DATE_COL = "ds"
8888
RENDER_LIMIT = 5000
8989
AUTO_SELECT = "auto-select"
90+
BACKTEST_REPORT_NAME = "back_test.csv"

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
SpeedAccuracyMode,
4848
SupportedMetrics,
4949
SupportedModels,
50+
BACKTEST_REPORT_NAME
5051
)
5152
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
5253
from .forecast_datasets import ForecastDatasets
@@ -256,12 +257,9 @@ def generate_report(self):
256257

257258
backtest_sections = []
258259
output_dir = self.spec.output_directory.url
259-
backtest_report_name = "backtest_stats.csv"
260-
file_path = f"{output_dir}/{backtest_report_name}"
260+
file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
261261
if self.spec.model == AUTO_SELECT:
262-
backtest_sections.append(
263-
rc.Heading("Auto-select statistics", level=2)
264-
)
262+
backtest_sections.append(rc.Heading("Auto-Select Backtesting and Performance Metrics", level=2))
265263
if not os.path.exists(file_path):
266264
failure_msg = rc.Text(
267265
"auto-select could not be executed. Please check the "
@@ -270,19 +268,15 @@ def generate_report(self):
270268
backtest_sections.append(failure_msg)
271269
else:
272270
backtest_stats = pd.read_csv(file_path)
273-
average_dict = backtest_stats.mean().to_dict()
274-
del average_dict["backtest"]
271+
model_metric_map = backtest_stats.drop(columns=['metric', 'backtest'])
272+
average_dict = {k: round(v, 4) for k, v in model_metric_map.mean().to_dict().items()}
275273
best_model = min(average_dict, key=average_dict.get)
276-
backtest_text = rc.Heading("Back Testing Metrics", level=3)
277274
summary_text = rc.Text(
278-
f"Overall, the average scores for the models are {average_dict}, with {best_model}"
279-
f" being identified as the top-performing model during backtesting."
280-
)
275+
f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
276+
f" {best_model} being identified as the top-performing model during backtesting.")
281277
backtest_table = rc.DataTable(backtest_stats, index=True)
282278
liner_plot = get_auto_select_plot(backtest_stats)
283-
backtest_sections.extend(
284-
[backtest_text, backtest_table, summary_text, liner_plot]
285-
)
279+
backtest_sections.extend([backtest_table, summary_text, liner_plot])
286280

287281
forecast_plots = []
288282
if len(self.forecast_output.list_series_ids()) > 0:

ads/opctl/operator/lowcode/forecast/model_evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from ads.opctl import logger
1212
from ads.opctl.operator.lowcode.common.const import DataColumns
13+
from ads.opctl.operator.lowcode.forecast.const import BACKTEST_REPORT_NAME
1314
from .model.forecast_datasets import ForecastDatasets
1415
from .operator_config import ForecastOperatorConfig
1516
from ads.opctl.operator.lowcode.forecast.model.factory import SupportedModels
@@ -156,8 +157,8 @@ def find_best_model(self, datasets: ForecastDatasets, operator_config: ForecastO
156157
best_model = min(avg_backtests_metric, key=avg_backtests_metric.get)
157158
logger.info(f"Among models {self.models}, {best_model} model shows better performance during backtesting.")
158159
backtest_stats = pd.DataFrame(nonempty_metrics).rename_axis('backtest')
160+
backtest_stats["metric"] = operator_config.spec.metric
159161
backtest_stats.reset_index(inplace=True)
160162
output_dir = operator_config.spec.output_directory.url
161-
backtest_report_name = "backtest_stats.csv"
162-
backtest_stats.to_csv(f"{output_dir}/{backtest_report_name}", index=False)
163+
backtest_stats.to_csv(f"{output_dir}/{BACKTEST_REPORT_NAME}", index=False)
163164
return best_model

ads/opctl/operator/lowcode/forecast/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,11 @@ def _add_unit(num, unit):
261261

262262
def get_auto_select_plot(backtest_results):
263263
fig = go.Figure()
264-
columns = backtest_results.columns.tolist()
264+
back_test_csv_columns = backtest_results.columns.tolist()
265265
back_test_column = "backtest"
266-
columns.remove(back_test_column)
267-
for column in columns:
266+
metric_column = "metric"
267+
models = [x for x in back_test_csv_columns if x not in [back_test_column, metric_column]]
268+
for i, column in enumerate(models):
268269
fig.add_trace(
269270
go.Scatter(
270271
x=backtest_results[back_test_column],

0 commit comments

Comments
 (0)