Skip to content

Commit 589b98c

Browse files
committed
updated auto-select reporting
1 parent e30e3fa commit 589b98c

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,4 @@ class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta):
8787
PROPHET_INTERNAL_DATE_COL = "ds"
8888
RENDER_LIMIT = 5000
8989
AUTO_SELECT = "auto-select"
90+
BACKTEST_REPORT_NAME = "back_test.csv"

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
SpeedAccuracyMode,
4848
SupportedMetrics,
4949
SupportedModels,
50+
SpeedAccuracyMode,
51+
AUTO_SELECT,
52+
BACKTEST_REPORT_NAME
5053
)
5154
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
5255
from .forecast_datasets import ForecastDatasets
@@ -255,12 +258,9 @@ def generate_report(self):
255258

256259
backtest_sections = []
257260
output_dir = self.spec.output_directory.url
258-
backtest_report_name = "backtest_stats.csv"
259-
file_path = f"{output_dir}/{backtest_report_name}"
261+
file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
260262
if self.spec.model == AUTO_SELECT:
261-
backtest_sections.append(
262-
rc.Heading("Auto-select statistics", level=2)
263-
)
263+
backtest_sections.append(rc.Heading("Auto-Select Backtesting and Performance Metrics", level=2))
264264
if not os.path.exists(file_path):
265265
failure_msg = rc.Text(
266266
"auto-select could not be executed. Please check the "
@@ -269,19 +269,15 @@ def generate_report(self):
269269
backtest_sections.append(failure_msg)
270270
else:
271271
backtest_stats = pd.read_csv(file_path)
272-
average_dict = backtest_stats.mean().to_dict()
273-
del average_dict["backtest"]
272+
model_metric_map = backtest_stats.drop(columns=['metric', 'backtest'])
273+
average_dict = {k: round(v, 4) for k, v in model_metric_map.mean().to_dict().items()}
274274
best_model = min(average_dict, key=average_dict.get)
275-
backtest_text = rc.Heading("Back Testing Metrics", level=3)
276275
summary_text = rc.Text(
277-
f"Overall, the average scores for the models are {average_dict}, with {best_model}"
278-
f" being identified as the top-performing model during backtesting."
279-
)
276+
f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
277+
f" {best_model} being identified as the top-performing model during backtesting.")
280278
backtest_table = rc.DataTable(backtest_stats, index=True)
281279
liner_plot = get_auto_select_plot(backtest_stats)
282-
backtest_sections.extend(
283-
[backtest_text, backtest_table, summary_text, liner_plot]
284-
)
280+
backtest_sections.extend([backtest_table, summary_text, liner_plot])
285281

286282
forecast_plots = []
287283
if len(self.forecast_output.list_series_ids()) > 0:

ads/opctl/operator/lowcode/forecast/model_evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from ads.opctl import logger
1212
from ads.opctl.operator.lowcode.common.const import DataColumns
13+
from ads.opctl.operator.lowcode.forecast.const import BACKTEST_REPORT_NAME
1314
from .model.forecast_datasets import ForecastDatasets
1415
from .operator_config import ForecastOperatorConfig
1516
from ads.opctl.operator.lowcode.forecast.model.factory import SupportedModels
@@ -156,8 +157,8 @@ def find_best_model(self, datasets: ForecastDatasets, operator_config: ForecastO
156157
best_model = min(avg_backtests_metric, key=avg_backtests_metric.get)
157158
logger.info(f"Among models {self.models}, {best_model} model shows better performance during backtesting.")
158159
backtest_stats = pd.DataFrame(nonempty_metrics).rename_axis('backtest')
160+
backtest_stats["metric"] = operator_config.spec.metric
159161
backtest_stats.reset_index(inplace=True)
160162
output_dir = operator_config.spec.output_directory.url
161-
backtest_report_name = "backtest_stats.csv"
162-
backtest_stats.to_csv(f"{output_dir}/{backtest_report_name}", index=False)
163+
backtest_stats.to_csv(f"{output_dir}/{BACKTEST_REPORT_NAME}", index=False)
163164
return best_model

ads/opctl/operator/lowcode/forecast/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,12 @@ def _add_unit(num, unit):
261261

262262
def get_auto_select_plot(backtest_results):
263263
fig = go.Figure()
264-
columns = backtest_results.columns.tolist()
264+
back_test_csv_columns = backtest_results.columns.tolist()
265265
back_test_column = "backtest"
266-
columns.remove(back_test_column)
267-
for column in columns:
266+
metric_column = "metric"
267+
models = [x for x in back_test_csv_columns if x not in [back_test_column, metric_column]]
268+
for i, column in enumerate(models):
269+
color = 0 #int(i * 255 / len(columns))
268270
fig.add_trace(
269271
go.Scatter(
270272
x=backtest_results[back_test_column],

0 commit comments

Comments
 (0)