Skip to content

Commit eba9dee

Browse files
MatthewMiddlehurstpattplatt
authored andcommitted
tests
1 parent b6f0a30 commit eba9dee

File tree

1 file changed

+62
-2
lines changed

1 file changed

+62
-2
lines changed
Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,69 @@
11
"""Tests for all collection anomaly detectors."""
22

3+
from functools import partial
4+
5+
from aeon.base._base import _clone_estimator
6+
from aeon.testing.testing_data import FULL_TEST_DATA_DICT
7+
from aeon.testing.utils.estimator_checks import _assert_predict_labels
8+
from aeon.utils.data_types import COLLECTIONS_DATA_TYPES
9+
310

411
def _yield_collection_anomaly_detection_checks(
512
estimator_class, estimator_instances, datatypes
613
):
714
"""Yield all collection anomaly detection checks for an aeon estimator."""
8-
# nothing currently!
9-
return []
15+
# only class required
16+
yield partial(
17+
check_collection_detector_overrides_and_tags, estimator_class=estimator_class
18+
)
19+
20+
# test class instances
21+
for i, estimator in enumerate(estimator_instances):
22+
# test all data types
23+
for datatype in datatypes[i]:
24+
yield partial(
25+
check_collection_detector_output, estimator=estimator, datatype=datatype
26+
)
27+
28+
29+
def check_collection_detector_overrides_and_tags(estimator_class):
30+
"""Test compliance with the detector base class contract."""
31+
# Test they don't override final methods, because Python does not enforce this
32+
final_methods = [
33+
"fit",
34+
"predict",
35+
]
36+
for method in final_methods:
37+
if method in estimator_class.__dict__:
38+
raise ValueError(
39+
f"Collection anomaly detector {estimator_class} overrides the "
40+
f"method {method}. Override _{method} instead."
41+
)
42+
43+
# Test valid tag for X_inner_type
44+
X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type")
45+
if isinstance(X_inner_type, str):
46+
assert X_inner_type in COLLECTIONS_DATA_TYPES
47+
else: # must be a list
48+
assert all([t in COLLECTIONS_DATA_TYPES for t in X_inner_type])
49+
50+
# one of X_inner_types must be capable of storing unequal length
51+
if estimator_class.get_class_tag("capability:unequal_length"):
52+
valid_unequal_types = ["np-list", "df-list", "pd-multiindex"]
53+
if isinstance(X_inner_type, str):
54+
assert X_inner_type in valid_unequal_types
55+
else: # must be a list
56+
assert any([t in valid_unequal_types for t in X_inner_type])
57+
58+
59+
def check_collection_detector_output(estimator, datatype):
60+
"""Test detector outputs the correct data types and values."""
61+
estimator = _clone_estimator(estimator)
62+
63+
# run fit and predict
64+
estimator.fit(
65+
FULL_TEST_DATA_DICT[datatype]["train"][0],
66+
FULL_TEST_DATA_DICT[datatype]["train"][1],
67+
)
68+
y_pred = estimator.predict(FULL_TEST_DATA_DICT[datatype]["test"][0])
69+
_assert_predict_labels(y_pred, datatype, unique_labels=[0, 1])

0 commit comments

Comments
 (0)