Skip to content

Commit 90ffb9b

Browse files
authored
enable report title (#1118)
2 parents 92e22f8 + 8f3a070 commit 90ffb9b

File tree

4 files changed

+188
-55
lines changed

4 files changed

+188
-55
lines changed

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 53 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,8 @@ def generate_report(self):
153153
model_description,
154154
other_sections,
155155
) = self._generate_report()
156-
157156
header_section = rc.Block(
158-
rc.Heading("Forecast Report", level=1),
157+
rc.Heading(self.spec.report_title, level=1),
159158
rc.Text(
160159
f"You selected the {self.spec.model} model.\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
161160
),
@@ -476,10 +475,11 @@ def _save_report(
476475
unique_output_dir = self.spec.output_directory.url
477476
results = ForecastResults()
478477

479-
if ObjectStorageDetails.is_oci_path(unique_output_dir):
480-
storage_options = default_signer()
481-
else:
482-
storage_options = {}
478+
storage_options = (
479+
default_signer()
480+
if ObjectStorageDetails.is_oci_path(unique_output_dir)
481+
else {}
482+
)
483483

484484
# report-creator html report
485485
if self.spec.generate_report:
@@ -510,12 +510,13 @@ def _save_report(
510510
if self.target_cat_col
511511
else result_df.drop(DataColumns.Series, axis=1)
512512
)
513-
write_data(
514-
data=result_df,
515-
filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
516-
format="csv",
517-
storage_options=storage_options,
518-
)
513+
if self.spec.generate_forecast_file:
514+
write_data(
515+
data=result_df,
516+
filename=os.path.join(unique_output_dir, self.spec.forecast_filename),
517+
format="csv",
518+
storage_options=storage_options,
519+
)
519520
results.set_forecast(result_df)
520521

521522
# metrics csv report
@@ -529,15 +530,16 @@ def _save_report(
529530
metrics_df_formatted = metrics_df.reset_index().rename(
530531
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
531532
)
532-
write_data(
533-
data=metrics_df_formatted,
534-
filename=os.path.join(
535-
unique_output_dir, self.spec.metrics_filename
536-
),
537-
format="csv",
538-
storage_options=storage_options,
539-
index=False,
540-
)
533+
if self.spec.generate_metrics_file:
534+
write_data(
535+
data=metrics_df_formatted,
536+
filename=os.path.join(
537+
unique_output_dir, self.spec.metrics_filename
538+
),
539+
format="csv",
540+
storage_options=storage_options,
541+
index=False,
542+
)
541543
results.set_metrics(metrics_df_formatted)
542544
else:
543545
logger.warning(
@@ -550,15 +552,16 @@ def _save_report(
550552
test_metrics_df_formatted = test_metrics_df.reset_index().rename(
551553
{"index": "metrics", "Series 1": metrics_col_name}, axis=1
552554
)
553-
write_data(
554-
data=test_metrics_df_formatted,
555-
filename=os.path.join(
556-
unique_output_dir, self.spec.test_metrics_filename
557-
),
558-
format="csv",
559-
storage_options=storage_options,
560-
index=False,
561-
)
555+
if self.spec.generate_metrics_file:
556+
write_data(
557+
data=test_metrics_df_formatted,
558+
filename=os.path.join(
559+
unique_output_dir, self.spec.test_metrics_filename
560+
),
561+
format="csv",
562+
storage_options=storage_options,
563+
index=False,
564+
)
562565
results.set_test_metrics(test_metrics_df_formatted)
563566
else:
564567
logger.warning(
@@ -568,31 +571,33 @@ def _save_report(
568571
if self.spec.generate_explanations:
569572
try:
570573
if not self.formatted_global_explanation.empty:
571-
write_data(
572-
data=self.formatted_global_explanation,
573-
filename=os.path.join(
574-
unique_output_dir, self.spec.global_explanation_filename
575-
),
576-
format="csv",
577-
storage_options=storage_options,
578-
index=True,
579-
)
574+
if self.spec.generate_explanation_files:
575+
write_data(
576+
data=self.formatted_global_explanation,
577+
filename=os.path.join(
578+
unique_output_dir, self.spec.global_explanation_filename
579+
),
580+
format="csv",
581+
storage_options=storage_options,
582+
index=True,
583+
)
580584
results.set_global_explanations(self.formatted_global_explanation)
581585
else:
582586
logger.warning(
583587
f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
584588
)
585589

586590
if not self.formatted_local_explanation.empty:
587-
write_data(
588-
data=self.formatted_local_explanation,
589-
filename=os.path.join(
590-
unique_output_dir, self.spec.local_explanation_filename
591-
),
592-
format="csv",
593-
storage_options=storage_options,
594-
index=True,
595-
)
591+
if self.spec.generate_explanation_files:
592+
write_data(
593+
data=self.formatted_local_explanation,
594+
filename=os.path.join(
595+
unique_output_dir, self.spec.local_explanation_filename
596+
),
597+
format="csv",
598+
storage_options=storage_options,
599+
index=True,
600+
)
596601
results.set_local_explanations(self.formatted_local_explanation)
597602
else:
598603
logger.warning(

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import os
@@ -18,19 +18,23 @@
1818

1919
from .const import SpeedAccuracyMode, SupportedMetrics, SupportedModels
2020

21+
2122
@dataclass
2223
class AutoScaling(DataClassSerializable):
2324
"""Class representing simple autoscaling policy"""
25+
2426
minimum_instance: int = 1
2527
maximum_instance: int = None
2628
cool_down_in_seconds: int = 600
2729
scale_in_threshold: int = 10
2830
scale_out_threshold: int = 80
2931
scaling_metric: str = "CPU_UTILIZATION"
3032

33+
3134
@dataclass(repr=True)
3235
class ModelDeploymentServer(DataClassSerializable):
3336
"""Class representing model deployment server specification for whatif-analysis."""
37+
3438
display_name: str = None
3539
initial_shape: str = None
3640
description: str = None
@@ -42,10 +46,13 @@ class ModelDeploymentServer(DataClassSerializable):
4246
@dataclass(repr=True)
4347
class WhatIfAnalysis(DataClassSerializable):
4448
"""Class representing operator specification for whatif-analysis."""
49+
4550
model_display_name: str = None
4651
compartment_id: str = None
4752
project_id: str = None
48-
model_deployment: ModelDeploymentServer = field(default_factory=ModelDeploymentServer)
53+
model_deployment: ModelDeploymentServer = field(
54+
default_factory=ModelDeploymentServer
55+
)
4956

5057

5158
@dataclass(repr=True)
@@ -106,8 +113,11 @@ class ForecastOperatorSpec(DataClassSerializable):
106113
datetime_column: DateTimeColumn = field(default_factory=DateTimeColumn)
107114
target_category_columns: List[str] = field(default_factory=list)
108115
generate_report: bool = None
116+
generate_forecast_file: bool = None
109117
generate_metrics: bool = None
118+
generate_metrics_file: bool = None
110119
generate_explanations: bool = None
120+
generate_explanation_files: bool = None
111121
explanations_accuracy_mode: str = None
112122
horizon: int = None
113123
model: str = None
@@ -126,7 +136,7 @@ def __post_init__(self):
126136
self.output_directory = self.output_directory or OutputDirectory(
127137
url=find_output_dirname(self.output_directory)
128138
)
129-
self.generate_model_pickle = True if self.generate_model_pickle or self.what_if_analysis else False
139+
self.generate_model_pickle = self.generate_model_pickle or self.what_if_analysis
130140
self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
131141
self.model = self.model or SupportedModels.Prophet
132142
self.confidence_interval_width = self.confidence_interval_width or 0.80
@@ -144,6 +154,21 @@ def __post_init__(self):
144154
self.generate_metrics = (
145155
self.generate_metrics if self.generate_metrics is not None else True
146156
)
157+
self.generate_metrics_file = (
158+
self.generate_metrics_file
159+
if self.generate_metrics_file is not None
160+
else True
161+
)
162+
self.generate_forecast_file = (
163+
self.generate_forecast_file
164+
if self.generate_forecast_file is not None
165+
else True
166+
)
167+
self.generate_explanation_files = (
168+
self.generate_explanation_files
169+
if self.generate_explanation_files is not None
170+
else True
171+
)
147172
# For Explanations Generation. When user doesn't specify defaults to False
148173
self.generate_explanations = (
149174
self.generate_explanations
@@ -164,6 +189,7 @@ def __post_init__(self):
164189
if self.generate_model_pickle is not None
165190
else False
166191
)
192+
self.report_title = self.report_title or "Forecast Report"
167193
self.report_theme = self.report_theme or "light"
168194
self.metrics_filename = self.metrics_filename or "metrics.csv"
169195
self.test_metrics_filename = self.test_metrics_filename or "test_metrics.csv"

docs/source/user_guide/operators/forecast_operator/development.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,35 @@ Before running operators on a job, users must configure their output directory.
125125
horizon: 3
126126
target_column: y
127127
128+
129+
Exclude Writing Certain Output Files
130+
====================================
131+
132+
You can choose to exclude certain files from being written to the output folder. This may be because you are calling the API, and not using the output folder. The yaml options below are ``True`` by default, but can be set to ``False`` to prevent file generation.
133+
134+
.. code-block:: yaml
135+
136+
kind: operator
137+
type: forecast
138+
version: v1
139+
spec:
140+
datetime_column:
141+
name: ds
142+
historical_data:
143+
url: oci://<bucket_name>@<namespace_name>/example_yosemite_temps.csv
144+
output_directory:
145+
url: oci://<bucket_name>@<namespace_name>/my_results/
146+
horizon: 3
147+
target_column: y
148+
generate_report: True
149+
generate_forecast_file: False
150+
generate_metrics_file: False
151+
generate_explanations: True
152+
generate_explanations_file: False
153+
154+
The above example will save a report.html to ``oci://<bucket_name>@<namespace_name>/my_results/``, but it will NOT save other files.
155+
156+
128157
Ingesting and Interpreting Outputs
129158
==================================
130159

tests/operators/forecast/test_errors.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from ads.opctl.operator.lowcode.forecast.utils import smape
3232
from ads.opctl.operator.cmd import run
33+
from ads.opctl.operator.lowcode.forecast.__main__ import operate
3334
import os
3435
import json
3536
import math
@@ -885,17 +886,36 @@ def test_auto_select(operator_setup):
885886

886887

887888
@pytest.mark.parametrize("model", ["prophet"])
888-
def test_prophet_floor_cap(operator_setup, model):
889-
from ads.opctl.operator.lowcode.forecast.__main__ import operate
889+
def test_report_title(operator_setup, model):
890+
yaml_i = TEMPLATE_YAML.copy()
891+
yaml_i["spec"]["horizon"] = 10
892+
yaml_i["spec"]["model"] = model
893+
yaml_i["spec"]["historical_data"] = {"format": "pandas"}
894+
yaml_i["spec"]["target_column"] = TARGET_COL.name
895+
yaml_i["spec"]["datetime_column"]["name"] = HISTORICAL_DATETIME_COL.name
896+
yaml_i["spec"]["report_title"] = "Skibidi ADS Skibidi"
897+
yaml_i["spec"]["output_directory"]["url"] = operator_setup
890898

899+
df = pd.concat([HISTORICAL_DATETIME_COL[:15], TARGET_COL[:15]], axis=1)
900+
yaml_i["spec"]["historical_data"]["data"] = df
901+
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
902+
results = operate(operator_config)
903+
with open(os.path.join(operator_setup, "report.html")) as f:
904+
for line in f:
905+
if "Skibidi ADS Skibidi" in line:
906+
return True
907+
assert False, "Report Title was not set"
908+
909+
910+
@pytest.mark.parametrize("model", ["prophet"])
911+
def test_prophet_floor_cap(operator_setup, model):
891912
yaml_i = TEMPLATE_YAML.copy()
892913
yaml_i["spec"]["horizon"] = 10
893914
yaml_i["spec"]["model"] = model
894915
yaml_i["spec"]["historical_data"] = {"format": "pandas"}
895-
yaml_i["spec"]["target_column"] = "target"
896916
yaml_i["spec"]["datetime_column"]["name"] = HISTORICAL_DATETIME_COL.name
897917
yaml_i["spec"]["output_directory"]["url"] = operator_setup
898-
yaml_i["spec"]["model_kwargs"] = {"max": 20, "min": 0}
918+
yaml_i["spec"]["target_column"] = "target"
899919

900920
target_column = pd.Series(np.arange(20, -6, -2), name="target")
901921
df = pd.concat(
@@ -926,5 +946,58 @@ def test_prophet_floor_cap(operator_setup, model):
926946
), "`max` not obeyed in prophet"
927947

928948

949+
@pytest.mark.parametrize("model", ["prophet"])
950+
def test_generate_files(operator_setup, model):
951+
yaml_i = TEMPLATE_YAML.copy()
952+
yaml_i["spec"]["horizon"] = 3
953+
yaml_i["spec"]["model"] = model
954+
yaml_i["spec"]["historical_data"] = {"format": "pandas"}
955+
yaml_i["spec"]["additional_data"] = {"format": "pandas"}
956+
yaml_i["spec"]["target_column"] = TARGET_COL.name
957+
yaml_i["spec"]["datetime_column"]["name"] = HISTORICAL_DATETIME_COL.name
958+
yaml_i["spec"]["output_directory"]["url"] = operator_setup
959+
yaml_i["spec"]["generate_explanation_files"] = False
960+
yaml_i["spec"]["generate_forecast_file"] = False
961+
yaml_i["spec"]["generate_metrics_file"] = False
962+
yaml_i["spec"]["generate_explanations"] = True
963+
964+
df = pd.concat([HISTORICAL_DATETIME_COL[:15], TARGET_COL[:15]], axis=1)
965+
df_add = pd.concat([HISTORICAL_DATETIME_COL[:18], ADD_COLS[:18]], axis=1)
966+
yaml_i["spec"]["historical_data"]["data"] = df
967+
yaml_i["spec"]["additional_data"]["data"] = df_add
968+
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
969+
results = operate(operator_config)
970+
files = os.listdir(operator_setup)
971+
assert "report.html" in files, "Failed to generate report"
972+
assert (
973+
"forecast.csv" not in files
974+
), "Generated forecast file, but `generate_forecast_file` was set False"
975+
assert (
976+
"metrics.csv" not in files
977+
), "Generated metrics file, but `generate_metrics_file` was set False"
978+
assert (
979+
"local_explanations.csv" not in files
980+
), "Generated metrics file, but `generate_explanation_files` was set False"
981+
assert (
982+
"global_explanations.csv" not in files
983+
), "Generated metrics file, but `generate_explanation_files` was set False"
984+
assert not results.get_forecast().empty
985+
assert not results.get_metrics().empty
986+
assert not results.get_global_explanations().empty
987+
assert not results.get_local_explanations().empty
988+
989+
yaml_i["spec"].pop("generate_explanation_files")
990+
yaml_i["spec"].pop("generate_forecast_file")
991+
yaml_i["spec"].pop("generate_metrics_file")
992+
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
993+
results = operate(operator_config)
994+
files = os.listdir(operator_setup)
995+
assert "report.html" in files, "Failed to generate report"
996+
assert "forecast.csv" in files, "Failed to generate forecast file"
997+
assert "metrics.csv" in files, "Failed to generated metrics file"
998+
assert "local_explanation.csv" in files, "Failed to generated local expl file"
999+
assert "global_explanation.csv" in files, "Failed to generated global expl file"
1000+
1001+
9291002
if __name__ == "__main__":
9301003
pass

0 commit comments

Comments
 (0)