@@ -929,19 +929,24 @@ def test_generate_files(operator_setup, model):
929
929
)
930
930
931
931
yaml_i = TEMPLATE_YAML .copy ()
932
- yaml_i ["spec" ]["horizon" ] = 10
932
+ yaml_i ["spec" ]["horizon" ] = 3
933
933
yaml_i ["spec" ]["model" ] = model
934
934
yaml_i ["spec" ]["historical_data" ] = {"format" : "pandas" }
935
+ yaml_i ["spec" ]["additional_data" ] = {"format" : "pandas" }
935
936
yaml_i ["spec" ]["target_column" ] = TARGET_COL .name
936
937
yaml_i ["spec" ]["datetime_column" ]["name" ] = HISTORICAL_DATETIME_COL .name
937
- yaml_i ["spec" ]["report_title" ] = "Skibidi ADS Skibidi"
938
938
yaml_i ["spec" ]["output_directory" ]["url" ] = operator_setup
939
- yaml_i ["spec" ]["generate_explanations_file " ] = False
939
+ yaml_i ["spec" ]["generate_explanation_files " ] = False
940
940
yaml_i ["spec" ]["generate_forecast_file" ] = False
941
941
yaml_i ["spec" ]["generate_metrics_file" ] = False
942
+ yaml_i ["spec" ]["generate_explanations" ] = True
942
943
943
944
df = pd .concat ([HISTORICAL_DATETIME_COL [:15 ], TARGET_COL [:15 ]], axis = 1 )
945
+ df_add = pd .concat ([HISTORICAL_DATETIME_COL [:18 ], ADD_COLS [:18 ]], axis = 1 )
946
+ print (f"df: { df } " )
947
+ print (f"df_add: { df_add } " )
944
948
yaml_i ["spec" ]["historical_data" ]["data" ] = df
949
+ yaml_i ["spec" ]["additional_data" ]["data" ] = df_add
945
950
operator_config = ForecastOperatorConfig .from_dict (yaml_i )
946
951
results = operate (operator_config )
947
952
files = os .listdir (operator_setup )
@@ -952,8 +957,28 @@ def test_generate_files(operator_setup, model):
952
957
assert (
953
958
"metrics.csv" not in files
954
959
), "Generated metrics file, but `generate_metrics_file` was set False"
960
+ assert (
961
+ "local_explanations.csv" not in files
962
+ ), "Generated metrics file, but `generate_explanation_files` was set False"
963
+ assert (
964
+ "global_explanations.csv" not in files
965
+ ), "Generated metrics file, but `generate_explanation_files` was set False"
955
966
assert not results .get_forecast ().empty
956
967
assert not results .get_metrics ().empty
968
+ assert not results .get_global_explanations ().empty
969
+ assert not results .get_local_explanations ().empty
970
+
971
+ yaml_i ["spec" ].pop ("generate_explanation_files" )
972
+ yaml_i ["spec" ].pop ("generate_forecast_file" )
973
+ yaml_i ["spec" ].pop ("generate_metrics_file" )
974
+ operator_config = ForecastOperatorConfig .from_dict (yaml_i )
975
+ results = operate (operator_config )
976
+ files = os .listdir (operator_setup )
977
+ assert "report.html" in files , "Failed to generate report"
978
+ assert "forecast.csv" in files , "Failed to generate forecast file"
979
+ assert "metrics.csv" in files , "Failed to generated metrics file"
980
+ assert "local_explanation.csv" in files , "Failed to generated local expl file"
981
+ assert "global_explanation.csv" in files , "Failed to generated global expl file"
957
982
958
983
959
984
if __name__ == "__main__" :
0 commit comments