@@ -948,6 +948,38 @@ def test_prophet_floor_cap(operator_setup, model):
948
948
), "`max` not obeyed in prophet"
949
949
950
950
951
+ def _check_results_obj (results ):
952
+ assert not results .get_forecast ().empty
953
+ assert not results .get_metrics ().empty
954
+ assert not results .get_global_explanations ().empty
955
+ assert not results .get_local_explanations ().empty
956
+
957
+
958
+ def _check_no_skippable_files (yaml_i , check_report = True ):
959
+ files = os .listdir (yaml_i ["spec" ]["output_directory" ]["url" ])
960
+
961
+ if "errors.json" in files :
962
+ with open (
963
+ os .path .join (yaml_i ["spec" ]["output_directory" ]["url" ], "errors.json" )
964
+ ) as f :
965
+ assert False , f"Failed due to errors.json being created: { f .read ()} "
966
+ if check_report :
967
+ assert "report.html" in files , "Failed to generate report"
968
+
969
+ assert (
970
+ "forecast.csv" not in files
971
+ ), "Generated forecast file, but `generate_forecast_file` was set False"
972
+ assert (
973
+ "metrics.csv" not in files
974
+ ), "Generated metrics file, but `generate_metrics_file` was set False"
975
+ assert (
976
+ "local_explanations.csv" not in files
977
+ ), "Generated metrics file, but `generate_explanation_files` was set False"
978
+ assert (
979
+ "global_explanations.csv" not in files
980
+ ), "Generated metrics file, but `generate_explanation_files` was set False"
981
+
982
+
951
983
@pytest .mark .parametrize ("model" , ["prophet" ])
952
984
def test_generate_files (operator_setup , model ):
953
985
yaml_i = TEMPLATE_YAML .copy ()
@@ -970,35 +1002,15 @@ def test_generate_files(operator_setup, model):
970
1002
yaml_i ["spec" ]["additional_data" ]["data" ] = df_add
971
1003
operator_config = ForecastOperatorConfig .from_dict (yaml_i )
972
1004
results = operate (operator_config )
973
- files = os .listdir (yaml_i ["spec" ]["output_directory" ]["url" ])
974
- if "errors.json" in files :
975
- with open (
976
- os .path .join (yaml_i ["spec" ]["output_directory" ]["url" ], "errors.json" )
977
- ) as f :
978
- assert False , f"Failed due to errors.json being created: { f .read ()} "
979
- assert "report.html" in files , "Failed to generate report"
980
- assert (
981
- "forecast.csv" not in files
982
- ), "Generated forecast file, but `generate_forecast_file` was set False"
983
- assert (
984
- "metrics.csv" not in files
985
- ), "Generated metrics file, but `generate_metrics_file` was set False"
986
- assert (
987
- "local_explanations.csv" not in files
988
- ), "Generated metrics file, but `generate_explanation_files` was set False"
989
- assert (
990
- "global_explanations.csv" not in files
991
- ), "Generated metrics file, but `generate_explanation_files` was set False"
992
- assert not results .get_forecast ().empty
993
- assert not results .get_metrics ().empty
994
- assert not results .get_global_explanations ().empty
995
- assert not results .get_local_explanations ().empty
1005
+ _check_results_obj (results )
1006
+ _check_no_skippable_files (yaml_i )
996
1007
997
1008
yaml_i ["spec" ].pop ("generate_explanation_files" )
998
1009
yaml_i ["spec" ].pop ("generate_forecast_file" )
999
1010
yaml_i ["spec" ].pop ("generate_metrics_file" )
1000
1011
operator_config = ForecastOperatorConfig .from_dict (yaml_i )
1001
1012
results = operate (operator_config )
1013
+ _check_results_obj (results )
1002
1014
files = os .listdir (yaml_i ["spec" ]["output_directory" ]["url" ])
1003
1015
if "errors.json" in files :
1004
1016
with open (
@@ -1012,6 +1024,12 @@ def test_generate_files(operator_setup, model):
1012
1024
assert "local_explanation.csv" in files , "Failed to generated local expl file"
1013
1025
assert "global_explanation.csv" in files , "Failed to generated global expl file"
1014
1026
1027
+ # Test that the results object still generates when report.html has an error
1028
+ yaml_i ["spec" ]["output_directory" ]["url" ] = "s3://test@test/test_dir"
1029
+ operator_config = ForecastOperatorConfig .from_dict (yaml_i )
1030
+ results = operate (operator_config )
1031
+ _check_results_obj (results )
1032
+
1015
1033
1016
1034
if __name__ == "__main__" :
1017
1035
pass
0 commit comments