Skip to content

Commit aee9fc9

Browse files
[ENH] Forecasting testing (#2891)
* call forecasting checks * basic cleanup * fix * add assert error messages * remove fit_is_empty check (see #2893) * remove extra argument to fit * remove extra argument to predict * correct forecast argument * Update _yield_forecasting_checks.py * add test for predict --------- Co-authored-by: Tony Bagnall <ajb@uea.ac.uk>
1 parent fa63e6e commit aee9fc9

File tree

3 files changed

+68
-32
lines changed

3 files changed

+68
-32
lines changed

aeon/testing/estimator_checking/_yield_anomaly_detection_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def check_anomaly_detector_overrides_and_tags(estimator_class):
3838

3939
# Test that fit_is_empty is correctly set
4040
fit_is_empty = estimator_class.get_class_tag(tag_name="fit_is_empty")
41-
assert not fit_is_empty == "_fit" not in estimator_class.__dict__
41+
assert fit_is_empty == ("_fit" not in estimator_class.__dict__)
4242

4343
# Test valid tag for X_inner_type
4444
X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type")

aeon/testing/estimator_checking/_yield_estimator_checks.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from aeon.classification.early_classification import BaseEarlyClassifier
2222
from aeon.clustering import BaseClusterer
2323
from aeon.clustering.deep_learning.base import BaseDeepClusterer
24+
from aeon.forecasting import BaseForecaster
2425
from aeon.regression import BaseRegressor
2526
from aeon.regression.deep_learning.base import BaseDeepRegressor
2627
from aeon.segmentation import BaseSegmenter
@@ -40,6 +41,9 @@
4041
from aeon.testing.estimator_checking._yield_early_classification_checks import (
4142
_yield_early_classification_checks,
4243
)
44+
from aeon.testing.estimator_checking._yield_forecasting_checks import (
45+
_yield_forecasting_checks,
46+
)
4347
from aeon.testing.estimator_checking._yield_multithreading_checks import (
4448
_yield_multithreading_checks,
4549
)
@@ -158,6 +162,11 @@ def _yield_all_aeon_checks(
158162
estimator_class, estimator_instances, datatypes
159163
)
160164

165+
if issubclass(estimator_class, BaseForecaster):
166+
yield from _yield_forecasting_checks(
167+
estimator_class, estimator_instances, datatypes
168+
)
169+
161170
if issubclass(estimator_class, BaseTransformer):
162171
yield from _yield_transformation_checks(
163172
estimator_class, estimator_instances, datatypes

aeon/testing/estimator_checking/_yield_forecasting_checks.py

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,50 +2,77 @@
22

33
from functools import partial
44

5-
import numpy as np
6-
75
from aeon.base._base import _clone_estimator
8-
from aeon.utils.data_types import VALID_SERIES_INPUT_TYPES
6+
from aeon.testing.testing_data import FULL_TEST_DATA_DICT
7+
from aeon.utils.data_types import VALID_SERIES_INNER_TYPES
98

109

1110
def _yield_forecasting_checks(estimator_class, estimator_instances, datatypes):
1211
"""Yield all forecasting checks for an aeon forecaster."""
1312
# only class required
14-
yield partial(check_forecasting_base_functionality, estimator_class=estimator_class)
13+
yield partial(check_forecaster_overrides_and_tags, estimator_class=estimator_class)
1514

1615
# test class instances
17-
for _, estimator in enumerate(estimator_instances):
18-
# no data needed
19-
yield partial(check_forecaster_instance, estimator=estimator)
20-
21-
22-
def check_forecasting_base_functionality(estimator_class):
23-
"""Test compliance with the base class contract."""
24-
# Test they dont override final methods, because python does not enforce this
25-
assert "fit" not in estimator_class.__dict__
26-
assert "predict" not in estimator_class.__dict__
27-
assert "forecast" not in estimator_class.__dict__
28-
fit_is_empty = estimator_class.get_class_tag(tag_name="fit_is_empty")
29-
assert not fit_is_empty == "_fit" not in estimator_class.__dict__
16+
for i, estimator in enumerate(estimator_instances):
17+
# test all data types
18+
for datatype in datatypes[i]:
19+
yield partial(
20+
check_forecaster_output,
21+
estimator=estimator,
22+
datatype=datatype,
23+
)
24+
25+
26+
def check_forecaster_overrides_and_tags(estimator_class):
27+
"""Test compliance with the forecaster base class contract."""
28+
# Test they don't override final methods, because Python does not enforce this
29+
final_methods = ["fit", "predict", "forecast"]
30+
for method in final_methods:
31+
if method in estimator_class.__dict__:
32+
raise ValueError(
33+
f"Forecaster {estimator_class} overrides the "
34+
f"method {method}. Override _{method} instead."
35+
)
36+
37+
# Test that all forecasters implement abstract predict.
38+
assert "_predict" in estimator_class.__dict__
39+
40+
# todo decide what to do with "fit_is_empty" and abstract "_fit"
41+
3042
# Test valid tag for X_inner_type
3143
X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type")
32-
assert X_inner_type in VALID_SERIES_INPUT_TYPES
44+
if isinstance(X_inner_type, str):
45+
assert X_inner_type in VALID_SERIES_INNER_TYPES
46+
else: # must be a list
47+
assert all([t in VALID_SERIES_INNER_TYPES for t in X_inner_type])
48+
3349
# Must have at least one set to True
3450
multi = estimator_class.get_class_tag(tag_name="capability:multivariate")
3551
uni = estimator_class.get_class_tag(tag_name="capability:univariate")
36-
assert multi or uni
52+
assert multi or uni, (
53+
"At least one of tag capability:multivariate or "
54+
"capability:univariate must be true."
55+
)
3756

3857

39-
def check_forecaster_instance(estimator):
40-
"""Test forecasters."""
58+
def check_forecaster_output(estimator, datatype):
59+
"""Test the forecaster output on valid data."""
4160
estimator = _clone_estimator(estimator)
42-
pass
43-
# Sort
44-
# Check output correct: predict should return a float
45-
y = np.array([0.5, 0.7, 0.8, 0.9, 1.0])
46-
estimator.fit(y)
47-
p = estimator.predict()
48-
assert isinstance(p, float)
49-
# forecast should return a float equal to fit/predict
50-
p2 = estimator.forecast(y)
51-
assert p == p2
61+
62+
estimator.fit(
63+
FULL_TEST_DATA_DICT[datatype]["train"][0],
64+
)
65+
y_pred = estimator.predict()
66+
assert isinstance(y_pred, float), (
67+
f"predict() output should be float, got" f" {type(y_pred)}"
68+
)
69+
70+
y_pred2 = estimator.forecast(FULL_TEST_DATA_DICT[datatype]["train"][0])
71+
assert y_pred == y_pred2, (
72+
f"predict() and forecast() output differ: {y_pred} !=" f" {y_pred2}"
73+
)
74+
y_pred3 = estimator.predict(FULL_TEST_DATA_DICT[datatype]["train"][0])
75+
assert y_pred == y_pred3, (
76+
f"after fit(), predict() and predict(y_train) should be the same, but"
77+
f"output differ: {y_pred} != {y_pred3}"
78+
)

0 commit comments

Comments
 (0)