|
2 | 2 |
|
3 | 3 | from functools import partial
|
4 | 4 |
|
5 |
| -import numpy as np |
6 |
| - |
7 | 5 | from aeon.base._base import _clone_estimator
|
8 |
| -from aeon.utils.data_types import VALID_SERIES_INPUT_TYPES |
| 6 | +from aeon.testing.testing_data import FULL_TEST_DATA_DICT |
| 7 | +from aeon.utils.data_types import VALID_SERIES_INNER_TYPES |
9 | 8 |
|
10 | 9 |
|
11 | 10 | def _yield_forecasting_checks(estimator_class, estimator_instances, datatypes):
|
12 | 11 | """Yield all forecasting checks for an aeon forecaster."""
|
13 | 12 | # only class required
|
14 |
| - yield partial(check_forecasting_base_functionality, estimator_class=estimator_class) |
| 13 | + yield partial(check_forecaster_overrides_and_tags, estimator_class=estimator_class) |
15 | 14 |
|
16 | 15 | # test class instances
|
17 |
| - for _, estimator in enumerate(estimator_instances): |
18 |
| - # no data needed |
19 |
| - yield partial(check_forecaster_instance, estimator=estimator) |
20 |
| - |
21 |
| - |
22 |
| -def check_forecasting_base_functionality(estimator_class): |
23 |
| - """Test compliance with the base class contract.""" |
24 |
| - # Test they dont override final methods, because python does not enforce this |
25 |
| - assert "fit" not in estimator_class.__dict__ |
26 |
| - assert "predict" not in estimator_class.__dict__ |
27 |
| - assert "forecast" not in estimator_class.__dict__ |
28 |
| - fit_is_empty = estimator_class.get_class_tag(tag_name="fit_is_empty") |
29 |
| - assert not fit_is_empty == "_fit" not in estimator_class.__dict__ |
| 16 | + for i, estimator in enumerate(estimator_instances): |
| 17 | + # test all data types |
| 18 | + for datatype in datatypes[i]: |
| 19 | + yield partial( |
| 20 | + check_forecaster_output, |
| 21 | + estimator=estimator, |
| 22 | + datatype=datatype, |
| 23 | + ) |
| 24 | + |
| 25 | + |
| 26 | +def check_forecaster_overrides_and_tags(estimator_class): |
| 27 | + """Test compliance with the forecaster base class contract.""" |
| 28 | + # Test they don't override final methods, because Python does not enforce this |
| 29 | + final_methods = ["fit", "predict", "forecast"] |
| 30 | + for method in final_methods: |
| 31 | + if method in estimator_class.__dict__: |
| 32 | + raise ValueError( |
| 33 | + f"Forecaster {estimator_class} overrides the " |
| 34 | + f"method {method}. Override _{method} instead." |
| 35 | + ) |
| 36 | + |
| 37 | + # Test that all forecasters implement abstract predict. |
| 38 | + assert "_predict" in estimator_class.__dict__ |
| 39 | + |
| 40 | + # todo decide what to do with "fit_is_empty" and abstract "_fit" |
| 41 | + |
30 | 42 | # Test valid tag for X_inner_type
|
31 | 43 | X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type")
|
32 |
| - assert X_inner_type in VALID_SERIES_INPUT_TYPES |
| 44 | + if isinstance(X_inner_type, str): |
| 45 | + assert X_inner_type in VALID_SERIES_INNER_TYPES |
| 46 | + else: # must be a list |
| 47 | + assert all([t in VALID_SERIES_INNER_TYPES for t in X_inner_type]) |
| 48 | + |
33 | 49 | # Must have at least one set to True
|
34 | 50 | multi = estimator_class.get_class_tag(tag_name="capability:multivariate")
|
35 | 51 | uni = estimator_class.get_class_tag(tag_name="capability:univariate")
|
36 |
| - assert multi or uni |
| 52 | + assert multi or uni, ( |
| 53 | + "At least one of tag capability:multivariate or " |
| 54 | + "capability:univariate must be true." |
| 55 | + ) |
37 | 56 |
|
38 | 57 |
|
39 |
| -def check_forecaster_instance(estimator): |
40 |
| - """Test forecasters.""" |
| 58 | +def check_forecaster_output(estimator, datatype): |
| 59 | + """Test the forecaster output on valid data.""" |
41 | 60 | estimator = _clone_estimator(estimator)
|
42 |
| - pass |
43 |
| - # Sort |
44 |
| - # Check output correct: predict should return a float |
45 |
| - y = np.array([0.5, 0.7, 0.8, 0.9, 1.0]) |
46 |
| - estimator.fit(y) |
47 |
| - p = estimator.predict() |
48 |
| - assert isinstance(p, float) |
49 |
| - # forecast should return a float equal to fit/predict |
50 |
| - p2 = estimator.forecast(y) |
51 |
| - assert p == p2 |
| 61 | + |
| 62 | + estimator.fit( |
| 63 | + FULL_TEST_DATA_DICT[datatype]["train"][0], |
| 64 | + ) |
| 65 | + y_pred = estimator.predict() |
| 66 | + assert isinstance(y_pred, float), ( |
| 67 | + f"predict() output should be float, got" f" {type(y_pred)}" |
| 68 | + ) |
| 69 | + |
| 70 | + y_pred2 = estimator.forecast(FULL_TEST_DATA_DICT[datatype]["train"][0]) |
| 71 | + assert y_pred == y_pred2, ( |
| 72 | + f"predict() and forecast() output differ: {y_pred} !=" f" {y_pred2}" |
| 73 | + ) |
| 74 | + y_pred3 = estimator.predict(FULL_TEST_DATA_DICT[datatype]["train"][0]) |
| 75 | + assert y_pred == y_pred3, ( |
| 76 | + f"after fit(), predict() and predict(y_train) should be the same, but" |
| 77 | + f"output differ: {y_pred} != {y_pred3}" |
| 78 | + ) |
0 commit comments