Skip to content

Commit 32e748e

Browse files
committed
improve unit test
1 parent 9e3442e commit 32e748e

File tree

3 files changed

+38
-15
lines changed

3 files changed

+38
-15
lines changed

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,8 @@ def generate_report(self):
153153
model_description,
154154
other_sections,
155155
) = self._generate_report()
156-
report_title = self.config.spec.report_title or "Forecast Report"
157156
header_section = rc.Block(
158-
rc.Heading(report_title, level=1),
157+
rc.Heading(self.spec.report_title, level=1),
159158
rc.Text(
160159
f"You selected the {self.spec.model} model.\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
161160
),
@@ -572,7 +571,7 @@ def _save_report(
572571
if self.spec.generate_explanations:
573572
try:
574573
if not self.formatted_global_explanation.empty:
575-
if not self.spec.generate_explanations_file:
574+
if self.spec.generate_explanation_files:
576575
write_data(
577576
data=self.formatted_global_explanation,
578577
filename=os.path.join(
@@ -589,7 +588,7 @@ def _save_report(
589588
)
590589

591590
if not self.formatted_local_explanation.empty:
592-
if not self.spec.generate_explanations_file:
591+
if self.spec.generate_explanation_files:
593592
write_data(
594593
data=self.formatted_local_explanation,
595594
filename=os.path.join(

ads/opctl/operator/lowcode/forecast/operator_config.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2023, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import os
@@ -117,7 +117,7 @@ class ForecastOperatorSpec(DataClassSerializable):
117117
generate_metrics: bool = None
118118
generate_metrics_file: bool = None
119119
generate_explanations: bool = None
120-
generate_explanations_file: bool = None
120+
generate_explanation_files: bool = None
121121
explanations_accuracy_mode: str = None
122122
horizon: int = None
123123
model: str = None
@@ -136,9 +136,7 @@ def __post_init__(self):
136136
self.output_directory = self.output_directory or OutputDirectory(
137137
url=find_output_dirname(self.output_directory)
138138
)
139-
self.generate_model_pickle = (
140-
True if self.generate_model_pickle or self.what_if_analysis else False
141-
)
139+
self.generate_model_pickle = self.generate_model_pickle or self.what_if_analysis
142140
self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
143141
self.model = self.model or SupportedModels.Prophet
144142
self.confidence_interval_width = self.confidence_interval_width or 0.80
@@ -166,9 +164,9 @@ def __post_init__(self):
166164
if self.generate_forecast_file is not None
167165
else True
168166
)
169-
self.generate_explanations_file = (
170-
self.generate_explanations_file
171-
if self.generate_explanations_file is not None
167+
self.generate_explanation_files = (
168+
self.generate_explanation_files
169+
if self.generate_explanation_files is not None
172170
else True
173171
)
174172
# For Explanations Generation. When user doesn't specify defaults to False
@@ -191,6 +189,7 @@ def __post_init__(self):
191189
if self.generate_model_pickle is not None
192190
else False
193191
)
192+
self.report_title = self.report_title or "Forecast Report"
194193
self.report_theme = self.report_theme or "light"
195194
self.metrics_filename = self.metrics_filename or "metrics.csv"
196195
self.test_metrics_filename = self.test_metrics_filename or "test_metrics.csv"

tests/operators/forecast/test_errors.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -929,19 +929,24 @@ def test_generate_files(operator_setup, model):
929929
)
930930

931931
yaml_i = TEMPLATE_YAML.copy()
932-
yaml_i["spec"]["horizon"] = 10
932+
yaml_i["spec"]["horizon"] = 3
933933
yaml_i["spec"]["model"] = model
934934
yaml_i["spec"]["historical_data"] = {"format": "pandas"}
935+
yaml_i["spec"]["additional_data"] = {"format": "pandas"}
935936
yaml_i["spec"]["target_column"] = TARGET_COL.name
936937
yaml_i["spec"]["datetime_column"]["name"] = HISTORICAL_DATETIME_COL.name
937-
yaml_i["spec"]["report_title"] = "Skibidi ADS Skibidi"
938938
yaml_i["spec"]["output_directory"]["url"] = operator_setup
939-
yaml_i["spec"]["generate_explanations_file"] = False
939+
yaml_i["spec"]["generate_explanation_files"] = False
940940
yaml_i["spec"]["generate_forecast_file"] = False
941941
yaml_i["spec"]["generate_metrics_file"] = False
942+
yaml_i["spec"]["generate_explanations"] = True
942943

943944
df = pd.concat([HISTORICAL_DATETIME_COL[:15], TARGET_COL[:15]], axis=1)
945+
df_add = pd.concat([HISTORICAL_DATETIME_COL[:18], ADD_COLS[:18]], axis=1)
946+
print(f"df: {df}")
947+
print(f"df_add: {df_add}")
944948
yaml_i["spec"]["historical_data"]["data"] = df
949+
yaml_i["spec"]["additional_data"]["data"] = df_add
945950
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
946951
results = operate(operator_config)
947952
files = os.listdir(operator_setup)
@@ -952,8 +957,28 @@ def test_generate_files(operator_setup, model):
952957
assert (
953958
"metrics.csv" not in files
954959
), "Generated metrics file, but `generate_metrics_file` was set False"
960+
assert (
961+
"local_explanations.csv" not in files
962+
), "Generated metrics file, but `generate_explanation_files` was set False"
963+
assert (
964+
"global_explanations.csv" not in files
965+
), "Generated metrics file, but `generate_explanation_files` was set False"
955966
assert not results.get_forecast().empty
956967
assert not results.get_metrics().empty
968+
assert not results.get_global_explanations().empty
969+
assert not results.get_local_explanations().empty
970+
971+
yaml_i["spec"].pop("generate_explanation_files")
972+
yaml_i["spec"].pop("generate_forecast_file")
973+
yaml_i["spec"].pop("generate_metrics_file")
974+
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
975+
results = operate(operator_config)
976+
files = os.listdir(operator_setup)
977+
assert "report.html" in files, "Failed to generate report"
978+
assert "forecast.csv" in files, "Failed to generate forecast file"
979+
assert "metrics.csv" in files, "Failed to generated metrics file"
980+
assert "local_explanation.csv" in files, "Failed to generated local expl file"
981+
assert "global_explanation.csv" in files, "Failed to generated global expl file"
957982

958983

959984
if __name__ == "__main__":

0 commit comments

Comments
 (0)