Skip to content

Commit fdbe233

Browse files
committed
update unit tests
1 parent 3e065c0 commit fdbe233

File tree

1 file changed

+41
-23
lines changed

1 file changed

+41
-23
lines changed

tests/operators/forecast/test_errors.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,38 @@ def test_prophet_floor_cap(operator_setup, model):
948948
), "`max` not obeyed in prophet"
949949

950950

951+
def _check_results_obj(results):
952+
assert not results.get_forecast().empty
953+
assert not results.get_metrics().empty
954+
assert not results.get_global_explanations().empty
955+
assert not results.get_local_explanations().empty
956+
957+
958+
def _check_no_skippable_files(yaml_i, check_report=True):
959+
files = os.listdir(yaml_i["spec"]["output_directory"]["url"])
960+
961+
if "errors.json" in files:
962+
with open(
963+
os.path.join(yaml_i["spec"]["output_directory"]["url"], "errors.json")
964+
) as f:
965+
assert False, f"Failed due to errors.json being created: {f.read()}"
966+
if check_report:
967+
assert "report.html" in files, "Failed to generate report"
968+
969+
assert (
970+
"forecast.csv" not in files
971+
), "Generated forecast file, but `generate_forecast_file` was set False"
972+
assert (
973+
"metrics.csv" not in files
974+
), "Generated metrics file, but `generate_metrics_file` was set False"
975+
assert (
976+
"local_explanations.csv" not in files
977+
), "Generated metrics file, but `generate_explanation_files` was set False"
978+
assert (
979+
"global_explanations.csv" not in files
980+
), "Generated metrics file, but `generate_explanation_files` was set False"
981+
982+
951983
@pytest.mark.parametrize("model", ["prophet"])
952984
def test_generate_files(operator_setup, model):
953985
yaml_i = TEMPLATE_YAML.copy()
@@ -970,35 +1002,15 @@ def test_generate_files(operator_setup, model):
9701002
yaml_i["spec"]["additional_data"]["data"] = df_add
9711003
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
9721004
results = operate(operator_config)
973-
files = os.listdir(yaml_i["spec"]["output_directory"]["url"])
974-
if "errors.json" in files:
975-
with open(
976-
os.path.join(yaml_i["spec"]["output_directory"]["url"], "errors.json")
977-
) as f:
978-
assert False, f"Failed due to errors.json being created: {f.read()}"
979-
assert "report.html" in files, "Failed to generate report"
980-
assert (
981-
"forecast.csv" not in files
982-
), "Generated forecast file, but `generate_forecast_file` was set False"
983-
assert (
984-
"metrics.csv" not in files
985-
), "Generated metrics file, but `generate_metrics_file` was set False"
986-
assert (
987-
"local_explanations.csv" not in files
988-
), "Generated metrics file, but `generate_explanation_files` was set False"
989-
assert (
990-
"global_explanations.csv" not in files
991-
), "Generated metrics file, but `generate_explanation_files` was set False"
992-
assert not results.get_forecast().empty
993-
assert not results.get_metrics().empty
994-
assert not results.get_global_explanations().empty
995-
assert not results.get_local_explanations().empty
1005+
_check_results_obj(results)
1006+
_check_no_skippable_files(yaml_i)
9961007

9971008
yaml_i["spec"].pop("generate_explanation_files")
9981009
yaml_i["spec"].pop("generate_forecast_file")
9991010
yaml_i["spec"].pop("generate_metrics_file")
10001011
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
10011012
results = operate(operator_config)
1013+
_check_results_obj(results)
10021014
files = os.listdir(yaml_i["spec"]["output_directory"]["url"])
10031015
if "errors.json" in files:
10041016
with open(
@@ -1012,6 +1024,12 @@ def test_generate_files(operator_setup, model):
10121024
assert "local_explanation.csv" in files, "Failed to generated local expl file"
10131025
assert "global_explanation.csv" in files, "Failed to generated global expl file"
10141026

1027+
# Test that the results object still generates when report.html has an error
1028+
yaml_i["spec"]["output_directory"]["url"] = "s3://test@test/test_dir"
1029+
operator_config = ForecastOperatorConfig.from_dict(yaml_i)
1030+
results = operate(operator_config)
1031+
_check_results_obj(results)
1032+
10151033

10161034
if __name__ == "__main__":
10171035
pass

0 commit comments

Comments
 (0)