@@ -259,6 +259,7 @@ def setup_faulty_rossman():
     additional_data_path = f"{data_folder}/rs_2_add_encoded.csv"
     return historical_data_path, additional_data_path

+
 def setup_small_rossman():
     curr_dir = pathlib.Path(__file__).parent.resolve()
     data_folder = f"{curr_dir}/../data/"
@@ -396,7 +397,7 @@ def test_0_series(operator_setup, model):
         historical_data_path=historical_data_path,
         additional_data_path=additional_data_path,
         test_data_path=test_data_path,
-        preprocessing={"enabled": False}
+        preprocessing={"enabled": False},
     )
     with pytest.raises(DataMismatchError):
         run_yaml(
@@ -465,36 +466,36 @@ def test_disabling_outlier_treatment(operator_setup):
         axis=1,
     )
     outliers = [1000, -800]
-    hist_data_0.at[40, 'Sales'] = outliers[0]
-    hist_data_0.at[75, 'Sales'] = outliers[1]
+    hist_data_0.at[40, "Sales"] = outliers[0]
+    hist_data_0.at[75, "Sales"] = outliers[1]
     historical_data_path, additional_data_path, test_data_path = setup_artificial_data(
         tmpdirname, hist_data_0
     )

     yaml_i, output_data_path = populate_yaml(
-        tmpdirname=tmpdirname,
-        model="arima",
-        historical_data_path=historical_data_path
+        tmpdirname=tmpdirname, model="arima", historical_data_path=historical_data_path
     )
     yaml_i["spec"].pop("target_category_columns")
     yaml_i["spec"].pop("additional_data")

     # running default pipeline where outlier will be treated
     run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path)
     forecast_without_outlier = pd.read_csv(f"{tmpdirname}/results/forecast.csv")
-    input_vals_without_outlier = set(forecast_without_outlier['input_value'])
+    input_vals_without_outlier = set(forecast_without_outlier["input_value"])
     assert all(
-        item not in input_vals_without_outlier for item in outliers), "forecast file should not contain any outliers"
+        item not in input_vals_without_outlier for item in outliers
+    ), "forecast file should not contain any outliers"

     # switching off outlier_treatment
     preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
     preprocessing = {"enabled": True, "steps": preprocessing_steps}
     yaml_i["spec"]["preprocessing"] = preprocessing
     run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path)
     forecast_with_outlier = pd.read_csv(f"{tmpdirname}/results/forecast.csv")
-    input_vals_with_outlier = set(forecast_with_outlier['input_value'])
+    input_vals_with_outlier = set(forecast_with_outlier["input_value"])
     assert all(
-        item in input_vals_with_outlier for item in outliers), "forecast file should contain all the outliers"
+        item in input_vals_with_outlier for item in outliers
+    ), "forecast file should contain all the outliers"


 @pytest.mark.parametrize("model", MODELS)
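
The two runs above differ only in the spec's preprocessing block. A minimal sketch of that toggle, using the key names from the test; everything outside those keys is illustrative:

# Sketch of the spec change that keeps outliers in the forecast input.
# Key names ("enabled", "steps", "missing_value_imputation", "outlier_treatment")
# come from the test above; the first run leaves "preprocessing" unset, so the
# default pipeline treats the injected outliers.
preprocessing = {
    "enabled": True,
    "steps": {"missing_value_imputation": True, "outlier_treatment": False},
}
# yaml_i["spec"]["preprocessing"] = preprocessing  # second run: outliers preserved
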
@@ -529,7 +530,7 @@ def split_df(df):
         historical_data_path=historical_data_path,
         additional_data_path=additional_data_path,
         test_data_path=test_data_path,
-        preprocessing={"enabled": True, "steps": preprocessing_steps}
+        preprocessing={"enabled": True, "steps": preprocessing_steps},
     )
     with pytest.raises(DataMismatchError):
         # 4 columns in historical data, but only 1 cat col specified
@@ -561,8 +562,8 @@ def test_all_series_failure(model):
     )
     preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
     yaml_i["spec"]["model"] = model
-    yaml_i['spec']['horizon'] = 10
-    yaml_i['spec']['preprocessing'] = preprocessing_steps
+    yaml_i["spec"]["horizon"] = 10
+    yaml_i["spec"]["preprocessing"] = preprocessing_steps
     if yaml_i["spec"].get("additional_data") is not None and model != "autots":
         yaml_i["spec"]["generate_explanations"] = True
     if model == "autots":
@@ -571,14 +572,15 @@ def test_all_series_failure(model):
         yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}

     module_to_patch = {
-        "arima": 'pmdarima.auto_arima',
-        "autots": 'autots.AutoTS',
-        "automlx": 'automlx.Pipeline',
-        "prophet": 'prophet.Prophet',
-        "neuralprophet": 'neuralprophet.NeuralProphet'
+        "arima": "pmdarima.auto_arima",
+        "autots": "autots.AutoTS",
+        "automlx": "automlx.Pipeline",
+        "prophet": "prophet.Prophet",
+        "neuralprophet": "neuralprophet.NeuralProphet",
     }
-    with patch(module_to_patch[model], side_effect=Exception("Custom exception message")):
-
+    with patch(
+        module_to_patch[model], side_effect=Exception("Custom exception message")
+    ):
         run(yaml_i, backend="operator.local", debug=False)

     report_path = f"{output_data_path}/report.html"
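
The hunk above relies on unittest.mock.patch with side_effect to make every model fit raise, so each series fails and lands in errors.json. A minimal, self-contained sketch of the same pattern, patching the stdlib json.loads purely as a stand-in for a model entry point:

from unittest.mock import patch

# side_effect makes every call to the patched target raise the given exception;
# json.loads is only a stand-in for a real model entry point.
with patch("json.loads", side_effect=Exception("Custom exception message")):
    import json
    try:
        json.loads("{}")
    except Exception as exc:
        assert "Custom exception message" in str(exc)
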
@@ -588,17 +590,26 @@ def test_all_series_failure(model):
     assert os.path.exists(error_path), f"Error file not found at {error_path}"

     # Additionally, you can read the content of the error.json and assert its content
-    with open(error_path, 'r') as error_file:
+    with open(error_path, "r") as error_file:
         error_content = json.load(error_file)
-        assert "Custom exception message" in error_content["1"]["error"], "Error message mismatch"
-        assert "Custom exception message" in error_content["13"]["error"], "Error message mismatch"
+        assert (
+            "Custom exception message" in error_content["1"]["error"]
+        ), "Error message mismatch"
+        assert (
+            "Custom exception message" in error_content["13"]["error"]
+        ), "Error message mismatch"

     if yaml_i["spec"]["generate_explanations"]:
         global_fn = f"{tmpdirname}/results/global_explanation.csv"
-        assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
+        assert os.path.exists(
+            global_fn
+        ), f"Global explanation file not found at {report_path}"

         local_fn = f"{tmpdirname}/results/local_explanation.csv"
-        assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
+        assert os.path.exists(
+            local_fn
+        ), f"Local explanation file not found at {report_path}"
+

 @pytest.mark.parametrize("model", MODELS)
 def test_arima_automlx_errors(operator_setup, model):
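
The assertions just above imply that errors.json maps each failed series ID to an object carrying the captured message. A hedged sketch of that shape; the series keys and message text come from the assertions, while the overall layout is inferred from this test rather than taken from operator documentation:

# Inferred shape of errors.json; "1" and "13" are series identifiers used above.
errors_json_like = {
    "1": {"error": "Custom exception message"},
    "13": {"error": "Custom exception message"},
}
assert all("error" in entry for entry in errors_json_like.values())
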
@@ -611,29 +622,38 @@ def test_arima_automlx_errors(operator_setup, model):
     )

     """
-    Arima was failing for constant trend when there are constant columns and when there are boolean columns.
-    We added label encoding for boolean and are dropping columns with constant value for arima with constant trend.
+    Arima was failing for constant trend when there are constant columns and when there are boolean columns.
+    We added label encoding for boolean and are dropping columns with constant value for arima with constant trend.
     This test checks that report, metrics, explanations are generated for this case.
     """

     """
-    series 13 in this data has missing dates and automlx fails for this with DatetimeIndex error. This test checks that
+    series 13 in this data has missing dates and automlx fails for this with DatetimeIndex error. This test checks that
     outputs get generated and that error is shown in errors.json
     """

     """
-    explanations generation is failing when boolean columns are passed.
-    TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced
+    explanations generation is failing when boolean columns are passed.
+    TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced
     any supported types according to the casting rule ''safe''
     Added label encoding before passing data to explainer
     """
     preprocessing_steps = {"missing_value_imputation": True, "outlier_treatment": False}
-    yaml_i['spec']['horizon'] = 10
-    yaml_i['spec']['preprocessing'] = preprocessing_steps
-    yaml_i['spec']['generate_explanations'] = True
-    yaml_i['spec']['model'] = model
+    yaml_i["spec"]["horizon"] = 10
+    yaml_i["spec"]["preprocessing"] = preprocessing_steps
+    yaml_i["spec"]["generate_explanations"] = True
+    yaml_i["spec"]["model"] = model
+    if model == "autots":
+        yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
+    if model == "automlx":
+        yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}

-    run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path, test_metrics_check=False)
+    run_yaml(
+        tmpdirname=tmpdirname,
+        yaml_i=yaml_i,
+        output_data_path=output_data_path,
+        test_metrics_check=False,
+    )

     report_path = f"{tmpdirname}/results/report.html"
     assert os.path.exists(report_path), f"Report file not found at {report_path}"
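
The docstrings above describe two data fixes: label encoding boolean columns and dropping constant-value columns before fitting ARIMA with a constant trend (and before passing data to the explainer). A minimal pandas sketch of both steps; the helper below is illustrative only, not the operator's actual implementation:

import pandas as pd

def _encode_and_prune(df: pd.DataFrame) -> pd.DataFrame:
    # Illustrative sketch only, not the operator's actual implementation.
    df = df.copy()
    # Label-encode boolean columns so downstream models and explainers see numerics.
    for col in df.select_dtypes(include="bool").columns:
        df[col] = df[col].astype(int)
    # Drop constant-value columns, which break ARIMA when a constant trend is used.
    constant_cols = [c for c in df.columns if df[c].nunique(dropna=False) <= 1]
    return df.drop(columns=constant_cols)
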
@@ -642,23 +662,28 @@ def test_arima_automlx_errors(operator_setup, model):
     assert os.path.exists(forecast_path), f"Forecast file not found at {report_path}"
     assert not pd.read_csv(forecast_path).empty

-
     error_path = f"{tmpdirname}/results/errors.json"
     if model == "arima":
         assert not os.path.exists(error_path), f"Error file not found at {error_path}"
     elif model == "automlx":
         assert os.path.exists(error_path), f"Error file not found at {error_path}"
-        with open(error_path, 'r') as error_file:
+        with open(error_path, "r") as error_file:
             error_content = json.load(error_file)
-            assert "Input data does not have a consistent (in terms of diff) DatetimeIndex." in error_content["13"][
-                "error"], "Error message mismatch"
+            assert (
+                "Input data does not have a consistent (in terms of diff) DatetimeIndex."
+                in error_content["13"]["error"]
+            ), "Error message mismatch"

     if model != "autots":
         global_fn = f"{tmpdirname}/results/global_explanation.csv"
-        assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
+        assert os.path.exists(
+            global_fn
+        ), f"Global explanation file not found at {report_path}"

         local_fn = f"{tmpdirname}/results/local_explanation.csv"
-        assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
+        assert os.path.exists(
+            local_fn
+        ), f"Local explanation file not found at {report_path}"

     glb_expl = pd.read_csv(global_fn, index_col=0)
     loc_expl = pd.read_csv(local_fn)
@@ -680,13 +705,20 @@ def test_date_format(operator_setup, model):
         historical_data_path=historical_data_path,
         additional_data_path=additional_data_path,
     )
-    yaml_i['spec']['horizon'] = 10
+    yaml_i["spec"]["horizon"] = 10
     yaml_i["spec"]["model"] = model
     if model == "autots":
         yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}

-    run_yaml(tmpdirname=tmpdirname, yaml_i=yaml_i, output_data_path=output_data_path, test_metrics_check=False)
-    assert pd.read_csv(additional_data_path)['Date'].equals(pd.read_csv(f"{tmpdirname}/results/forecast.csv")['Date'])
+    run_yaml(
+        tmpdirname=tmpdirname,
+        yaml_i=yaml_i,
+        output_data_path=output_data_path,
+        test_metrics_check=False,
+    )
+    assert pd.read_csv(additional_data_path)["Date"].equals(
+        pd.read_csv(f"{tmpdirname}/results/forecast.csv")["Date"]
+    )


 if __name__ == "__main__":