
Commit 8bc78e3: speed up slowest test
Parent: aef41ca


tests/operators/forecast/test_errors.py

Lines changed: 76 additions & 44 deletions
@@ -259,6 +259,7 @@ def setup_faulty_rossman():
     additional_data_path = f"{data_folder}/rs_2_add_encoded.csv"
     return historical_data_path, additional_data_path
 
+
 def setup_small_rossman():
     curr_dir = pathlib.Path(__file__).parent.resolve()
     data_folder = f"{curr_dir}/../data/"
@@ -396,7 +397,7 @@ def test_0_series(operator_setup, model):
         historical_data_path=historical_data_path,
         additional_data_path=additional_data_path,
         test_data_path=test_data_path,
-        preprocessing={"enabled": False}
+        preprocessing={"enabled": False},
     )
     with pytest.raises(DataMismatchError):
         run_yaml(
@@ -465,36 +466,36 @@ def test_disabling_outlier_treatment(operator_setup):
         axis=1,
     )
     outliers = [1000, -800]
-    hist_data_0.at[40, 'Sales'] = outliers[0]
-    hist_data_0.at[75, 'Sales'] = outliers[1]
+    hist_data_0.at[40, "Sales"] = outliers[0]
+    hist_data_0.at[75, "Sales"] = outliers[1]
     historical_data_path, additional_data_path, test_data_path = setup_artificial_data(
         tmpdirname, hist_data_0
     )
 
     yaml_i, output_data_path = populate_yaml(
-        tmpdirname=tmpdirname,
-        model="arima",
-        historical_data_path=historical_data_path
+        tmpdirname=tmpdirname, model="arima", historical_data_path=historical_data_path
     )
     yaml_i["spec"].pop("target_category_columns")
     yaml_i["spec"].pop("additional_data")
 
     # running default pipeline where outlier will be treated
     run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path)
     forecast_without_outlier = pd.read_csv(f"{tmpdirname}/results/forecast.csv")
-    input_vals_without_outlier = set(forecast_without_outlier['input_value'])
+    input_vals_without_outlier = set(forecast_without_outlier["input_value"])
     assert all(
-        item not in input_vals_without_outlier for item in outliers), "forecast file should not contain any outliers"
+        item not in input_vals_without_outlier for item in outliers
+    ), "forecast file should not contain any outliers"
 
     # switching off outlier_treatment
     preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
     preprocessing = {"enabled": True, "steps": preprocessing_steps}
     yaml_i["spec"]["preprocessing"] = preprocessing
     run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path)
     forecast_with_outlier = pd.read_csv(f"{tmpdirname}/results/forecast.csv")
-    input_vals_with_outlier = set(forecast_with_outlier['input_value'])
+    input_vals_with_outlier = set(forecast_with_outlier["input_value"])
     assert all(
-        item in input_vals_with_outlier for item in outliers), "forecast file should contain all the outliers"
+        item in input_vals_with_outlier for item in outliers
+    ), "forecast file should contain all the outliers"
 
 
 @pytest.mark.parametrize("model", MODELS)
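
Note on the hunk above: beyond the quote and line-wrap changes, it shows the assertion idiom the test relies on, checking set membership of the injected outliers against the forecast's "input_value" column. A tiny self-contained sketch of that idiom (the DataFrame below is made up; only the column name and outlier values come from the test):

import pandas as pd

# Toy stand-in for forecast.csv: the test inspects the "input_value" column.
forecast = pd.DataFrame({"input_value": [10.0, 12.0, 1000.0, 11.0, -800.0]})
outliers = [1000, -800]

input_vals = set(forecast["input_value"])
# With outlier treatment switched off, every injected outlier must still be present.
assert all(item in input_vals for item in outliers)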
@@ -529,7 +530,7 @@ def split_df(df):
         historical_data_path=historical_data_path,
         additional_data_path=additional_data_path,
         test_data_path=test_data_path,
-        preprocessing={"enabled": True, "steps": preprocessing_steps}
+        preprocessing={"enabled": True, "steps": preprocessing_steps},
     )
     with pytest.raises(DataMismatchError):
         # 4 columns in historical data, but only 1 cat col specified
@@ -561,8 +562,8 @@ def test_all_series_failure(model):
     )
     preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
     yaml_i["spec"]["model"] = model
-    yaml_i['spec']['horizon'] = 10
-    yaml_i['spec']['preprocessing'] = preprocessing_steps
+    yaml_i["spec"]["horizon"] = 10
+    yaml_i["spec"]["preprocessing"] = preprocessing_steps
     if yaml_i["spec"].get("additional_data") is not None and model != "autots":
         yaml_i["spec"]["generate_explanations"] = True
     if model == "autots":
@@ -571,14 +572,15 @@ def test_all_series_failure(model):
         yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}
 
     module_to_patch = {
-        "arima": 'pmdarima.auto_arima',
-        "autots": 'autots.AutoTS',
-        "automlx": 'automlx.Pipeline',
-        "prophet": 'prophet.Prophet',
-        "neuralprophet": 'neuralprophet.NeuralProphet'
+        "arima": "pmdarima.auto_arima",
+        "autots": "autots.AutoTS",
+        "automlx": "automlx.Pipeline",
+        "prophet": "prophet.Prophet",
+        "neuralprophet": "neuralprophet.NeuralProphet",
     }
-    with patch(module_to_patch[model], side_effect=Exception("Custom exception message")):
-
+    with patch(
+        module_to_patch[model], side_effect=Exception("Custom exception message")
+    ):
         run(yaml_i, backend="operator.local", debug=False)
 
     report_path = f"{output_data_path}/report.html"
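
The reflowed with-statement above is ordinary unittest.mock usage: each backend's entry point (pmdarima.auto_arima, autots.AutoTS, automlx.Pipeline, prophet.Prophet, neuralprophet.NeuralProphet) is patched so that every call raises, and the test then checks that the operator records the failure rather than crashing. A self-contained sketch of the same pattern against a stand-in target (run_pipeline below is illustrative, not the operator's API):

import statistics
from unittest.mock import patch


def run_pipeline():
    # Stand-in for run(...): call the (patched) entry point per series and
    # record failures per series id, mirroring the errors.json layout.
    errors = {}
    for series_id in ("1", "13"):
        try:
            statistics.mean([1, 2, 3])
        except Exception as exc:
            errors[series_id] = {"error": str(exc)}
    return errors


# side_effect makes every call to the patched target raise this exception.
with patch("statistics.mean", side_effect=Exception("Custom exception message")):
    errors = run_pipeline()

assert "Custom exception message" in errors["1"]["error"]
assert "Custom exception message" in errors["13"]["error"]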
@@ -588,17 +590,26 @@ def test_all_series_failure(model):
     assert os.path.exists(error_path), f"Error file not found at {error_path}"
 
     # Additionally, you can read the content of the error.json and assert its content
-    with open(error_path, 'r') as error_file:
+    with open(error_path, "r") as error_file:
         error_content = json.load(error_file)
-        assert "Custom exception message" in error_content["1"]["error"], "Error message mismatch"
-        assert "Custom exception message" in error_content["13"]["error"], "Error message mismatch"
+        assert (
+            "Custom exception message" in error_content["1"]["error"]
+        ), "Error message mismatch"
+        assert (
+            "Custom exception message" in error_content["13"]["error"]
+        ), "Error message mismatch"
 
     if yaml_i["spec"]["generate_explanations"]:
         global_fn = f"{tmpdirname}/results/global_explanation.csv"
-        assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
+        assert os.path.exists(
+            global_fn
+        ), f"Global explanation file not found at {report_path}"
 
         local_fn = f"{tmpdirname}/results/local_explanation.csv"
-        assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
+        assert os.path.exists(
+            local_fn
+        ), f"Local explanation file not found at {report_path}"
+
 
 @pytest.mark.parametrize("model", MODELS)
 def test_arima_automlx_errors(operator_setup, model):
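
The wrapped assertions above also pin down the shape of errors.json: a JSON object keyed by series id, where each value carries an "error" message. A small helper sketch for reading that shape (only the layout shown in the test is assumed; the helper name is hypothetical):

import json


def failed_series(error_path):
    # Return {series_id: error_message} for every series recorded in errors.json.
    with open(error_path, "r") as error_file:
        error_content = json.load(error_file)
    return {series_id: entry["error"] for series_id, entry in error_content.items()}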
@@ -611,29 +622,38 @@ def test_arima_automlx_errors(operator_setup, model):
     )
 
     """
-    Arima was failing for constant trend when there are constant columns and when there are boolean columns .
-    We added label encoding for boolean and are dropping columns with constant value for arima with constant trend.
+    Arima was failing for constant trend when there are constant columns and when there are boolean columns .
+    We added label encoding for boolean and are dropping columns with constant value for arima with constant trend.
     This test checks that report, metrics, explanations are generated for this case.
     """
 
     """
-    series 13 in this data has missing dates and automlx fails for this with DatetimeIndex error. This test checks that
+    series 13 in this data has missing dates and automlx fails for this with DatetimeIndex error. This test checks that
     outputs get generated and that error is shown in errors.json
     """
 
     """
-    explanations generation is failing when boolean columns are passed.
-    TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced
+    explanations generation is failing when boolean columns are passed.
+    TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced
     any supported types according to the casting rule ''safe''
     Added label encoding before passing data to explainer
     """
     preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
-    yaml_i['spec']['horizon'] = 10
-    yaml_i['spec']['preprocessing'] = preprocessing_steps
-    yaml_i['spec']['generate_explanations'] = True
-    yaml_i['spec']['model'] = model
+    yaml_i["spec"]["horizon"] = 10
+    yaml_i["spec"]["preprocessing"] = preprocessing_steps
+    yaml_i["spec"]["generate_explanations"] = True
+    yaml_i["spec"]["model"] = model
+    if model == "autots":
+        yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
+    if model == "automlx":
+        yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}
 
-    run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path, test_metrics_check=False)
+    run_yaml(
+        tmpdirname=tmpdirname,
+        yaml_i=yaml_i,
+        output_data_path=output_data_path,
+        test_metrics_check=False,
+    )
 
     report_path = f"{tmpdirname}/results/report.html"
     assert os.path.exists(report_path), f"Report file not found at {report_path}"
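
This hunk carries the speed-up the commit message refers to: for autots the test now restricts the search with the "superfast" model_list preset, and for automlx it caps the search with a time_budget of 1 second, so test_arima_automlx_errors no longer runs the backends' full default searches. As a rough illustration, if the "model_list" kwarg is forwarded straight to the library, the effect is similar to constructing AutoTS like this (a sketch assuming autots is installed; the operator's own plumbing is not shown, and only the kwarg values come from the diff):

from autots import AutoTS

# "superfast" is an AutoTS preset that limits the search to its quickest model
# families; forecast_length mirrors the horizon of 10 set in the test.
model = AutoTS(
    forecast_length=10,
    model_list="superfast",
)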
@@ -642,23 +662,28 @@ def test_arima_automlx_errors(operator_setup, model):
     assert os.path.exists(forecast_path), f"Forecast file not found at {report_path}"
     assert not pd.read_csv(forecast_path).empty
 
-
     error_path = f"{tmpdirname}/results/errors.json"
     if model == "arima":
         assert not os.path.exists(error_path), f"Error file not found at {error_path}"
     elif model == "automlx":
         assert os.path.exists(error_path), f"Error file not found at {error_path}"
-        with open(error_path, 'r') as error_file:
+        with open(error_path, "r") as error_file:
             error_content = json.load(error_file)
-            assert "Input data does not have a consistent (in terms of diff) DatetimeIndex." in error_content["13"][
-                "error"], "Error message mismatch"
+            assert (
+                "Input data does not have a consistent (in terms of diff) DatetimeIndex."
+                in error_content["13"]["error"]
+            ), "Error message mismatch"
 
     if model != "autots":
         global_fn = f"{tmpdirname}/results/global_explanation.csv"
-        assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
+        assert os.path.exists(
+            global_fn
+        ), f"Global explanation file not found at {report_path}"
 
         local_fn = f"{tmpdirname}/results/local_explanation.csv"
-        assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
+        assert os.path.exists(
+            local_fn
+        ), f"Local explanation file not found at {report_path}"
 
         glb_expl = pd.read_csv(global_fn, index_col=0)
         loc_expl = pd.read_csv(local_fn)
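
The automlx message asserted above ("Input data does not have a consistent (in terms of diff) DatetimeIndex.") describes what missing dates do to a time index: the consecutive differences stop being uniform, so no single frequency fits. A quick pandas sketch of that condition (pandas only; this is not automlx's internal check):

import pandas as pd

# Ten daily timestamps with one date removed: the index diffs are no longer uniform.
idx = pd.date_range("2024-01-01", periods=10, freq="D").delete(4)

print(pd.Series(idx).diff().dropna().unique())  # two distinct deltas: 1 day and 2 days
print(pd.infer_freq(idx))                       # None: no consistent frequency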
@@ -680,13 +705,20 @@ def test_date_format(operator_setup, model):
         historical_data_path=historical_data_path,
         additional_data_path=additional_data_path,
     )
-    yaml_i['spec']['horizon'] = 10
+    yaml_i["spec"]["horizon"] = 10
     yaml_i["spec"]["model"] = model
     if model == "autots":
         yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
 
-    run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path, test_metrics_check=False)
-    assert pd.read_csv(additional_data_path)['Date'].equals(pd.read_csv(f"{tmpdirname}/results/forecast.csv")['Date'])
+    run_yaml(
+        tmpdirname=tmpdirname,
+        yaml_i=yaml_i,
+        output_data_path=output_data_path,
+        test_metrics_check=False,
+    )
+    assert pd.read_csv(additional_data_path)["Date"].equals(
+        pd.read_csv(f"{tmpdirname}/results/forecast.csv")["Date"]
+    )
 
 
 if __name__ == "__main__":
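
The wrapped assertion above compares the Date column of the additional data against the Date column written to forecast.csv using pandas Series.equals, which requires the same values in the same order with the same dtype, so any reformatting or re-parsing of the dates would fail it. A small illustration of that strictness (pandas only; the date strings are made up):

import pandas as pd

a = pd.Series(["01-07-2024", "01-08-2024"])
b = pd.Series(["01-07-2024", "01-08-2024"])
c = pd.to_datetime(a, format="%d-%m-%Y")  # parsing changes dtype and representation

print(a.equals(b))  # True: identical values and dtype
print(a.equals(c))  # False: datetime64 vs. object strings no longer compare equal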
