Skip to content

Commit e6043d4

Browse files
MatthewMiddlehurst authored and pattplatt committed
- Resolved merge conflicts with upstream main in estimator checks, testing data, and register modules
1 parent b9c2a5b commit e6043d4

File tree

7 files changed

+141
-88
lines changed

7 files changed

+141
-88
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Whole-series anomaly detection methods."""
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Basic outlier detection classifier."""
2+
3+
from aeon.anomaly_detection import IsolationForest
4+
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
5+
from aeon.base._base import _clone_estimator
6+
7+
8+
class OutlierDetectionClassifier(BaseCollectionAnomalyDetector):
    """Basic outlier detection classifier.

    Wraps an outlier detection estimator so it can be used as a collection
    anomaly detector. A clone of ``estimator`` is fitted on the 2D numpy
    collection and its ``predict`` output is returned with every ``-1``
    rewritten to ``0``.

    Parameters
    ----------
    estimator : object
        Estimator providing ``fit(X, y)`` and ``predict(X)``. It is cloned with
        ``_clone_estimator`` in ``_fit``; the clone is what gets fitted.
    random_state : default=None
        Forwarded to ``_clone_estimator`` when cloning ``estimator``.

    Attributes
    ----------
    estimator_ : object
        The fitted clone of ``estimator``, set in ``_fit``.
    """

    _tags = {
        "X_inner_type": "numpy2D",
    }

    def __init__(self, estimator, random_state=None):
        self.estimator = estimator
        self.random_state = random_state

        super().__init__()

    def _fit(self, X, y=None):
        """Clone the wrapped estimator and fit the clone on ``X`` and ``y``."""
        self.estimator_ = _clone_estimator(
            self.estimator, random_state=self.random_state
        )
        self.estimator_.fit(X, y)
        return self

    def _predict(self, X):
        """Predict with the fitted clone, rewriting ``-1`` predictions to ``0``."""
        pred = self.estimator_.predict(X)
        # NOTE(review): if the wrapped estimator follows the scikit-learn outlier
        # convention (-1 = outlier, 1 = inlier), this maps outliers to 0 and
        # leaves inliers at 1, which looks inverted relative to the
        # "0 (not anomalous) or 1 (anomalous)" labelling enforced by the base
        # class ``_check_y``. Behaviour is left unchanged here -- confirm intent.
        pred[pred == -1] = 0
        return pred

    @classmethod
    def _get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the parameter set to return. Unused; the same parameters
            are returned for every value.

        Returns
        -------
        dict
            Keyword arguments used to construct a test instance of the class.
        """
        return {"estimator": IsolationForest()}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""Abstract base class for whole-series/collection anomaly detectors."""
2+
3+
__maintainer__ = ["MatthewMiddlehurst"]
4+
__all__ = ["BaseCollectionAnomalyDetector"]
5+
6+
from abc import abstractmethod
7+
from typing import final
8+
9+
import numpy as np
10+
import pandas as pd
11+
12+
from aeon.base import BaseCollectionEstimator
13+
14+
15+
class BaseCollectionAnomalyDetector(BaseCollectionEstimator):
    """Collection anomaly detector base class.

    Provides the final ``fit`` and ``predict`` entry points, which validate and
    preprocess the input before delegating to the abstract ``_fit`` and
    ``_predict`` methods that subclasses must implement.
    """

    _tags = {
        "fit_is_empty": False,
        "requires_y": False,
    }

    def __init__(self):
        super().__init__()

    @final
    def fit(self, X, y=None):
        """Fit the detector to a collection.

        Parameters
        ----------
        X : collection
            Training collection; passed through ``_preprocess_collection``
            before being handed to ``_fit``.
        y : np.ndarray or pd.Series, default=None
            Optional labels, 0 (not anomalous) or 1 (anomalous), one per case.
            Validated by ``_check_y`` when given.

        Returns
        -------
        self
            Reference to self.

        Raises
        ------
        ValueError
            If the ``requires_y`` tag is true and ``y`` is None.
        """
        if self.get_tag("fit_is_empty"):
            self.is_fitted = True
            return self

        if self.get_tag("requires_y") and y is None:
            raise ValueError("Tag requires_y is true, but fit called with y=None")

        # reset estimator at the start of fit
        self.reset()

        X = self._preprocess_collection(X)
        if y is not None:
            y = self._check_y(y, self.metadata_["n_cases"])

        self._fit(X, y)

        # this should happen last
        self.is_fitted = True
        return self

    @final
    def predict(self, X):
        """Predict for each case in a collection.

        Parameters
        ----------
        X : collection
            Collection to predict on; preprocessed without storing metadata and
            passed through ``_check_shape`` before ``_predict`` is called.

        Returns
        -------
        The output of ``_predict`` for the preprocessed ``X``.
        """
        # detectors with an empty fit may predict without having been fitted
        if not self.get_tag("fit_is_empty"):
            self._check_is_fitted()

        X = self._preprocess_collection(X, store_metadata=False)
        # Check if X has the correct shape seen during fitting
        self._check_shape(X)

        return self._predict(X)

    @abstractmethod
    def _fit(self, X, y=None): ...

    @abstractmethod
    def _predict(self, X): ...

    def _check_y(self, y, n_cases):
        """Validate anomaly labels and return them as a numpy array.

        Parameters
        ----------
        y : np.ndarray or pd.Series
            Labels to validate. Must be 1-dimensional and contain only
            0 (not anomalous) or 1 (anomalous).
        n_cases : int
            Number of cases in the collection the labels belong to.

        Returns
        -------
        np.ndarray
            The labels; a ``pd.Series`` input is converted with ``to_numpy``.

        Raises
        ------
        TypeError
            If ``y`` is not a ``np.ndarray`` or ``pd.Series``, or is an array
            that is not 1-dimensional.
        ValueError
            If ``y`` contains values other than 0 and 1, or its length does not
            match ``n_cases``.
        """
        if not isinstance(y, (pd.Series, np.ndarray)):
            raise TypeError(
                f"y must be a np.array or a pd.Series, but found type: {type(y)}"
            )
        # Reject 0-d arrays here as well: they cannot be iterated below and
        # would otherwise fail with an unhelpful "iteration over a 0-d array".
        if isinstance(y, np.ndarray) and y.ndim != 1:
            raise TypeError(f"y must be 1-dimensional, found {y.ndim} dimensions")

        if not all(label == 0 or label == 1 for label in y):
            raise ValueError(
                "y input must only contain 0 (not anomalous) or 1 (anomalous) values."
            )

        # Check matching number of labels
        n_labels = y.shape[0]
        if n_cases != n_labels:
            raise ValueError(
                f"Mismatch in number of cases. Found X = {n_cases} and y = {n_labels}"
            )

        if isinstance(y, pd.Series):
            y = y.to_numpy()

        return y
Lines changed: 2 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,9 @@
11
"""Tests for all collection anomaly detectors."""
22

3-
from functools import partial
4-
5-
import numpy as np
6-
7-
from aeon.base._base import _clone_estimator
8-
from aeon.testing.testing_data import FULL_TEST_DATA_DICT
9-
from aeon.utils.data_types import COLLECTIONS_DATA_TYPES
10-
from aeon.utils.validation import get_n_cases
11-
123

134
def _yield_collection_anomaly_detection_checks(
145
estimator_class, estimator_instances, datatypes
156
):
167
"""Yield all collection anomaly detection checks for an aeon estimator."""
17-
# only class required
18-
yield partial(
19-
check_collection_detector_overrides_and_tags, estimator_class=estimator_class
20-
)
21-
22-
# test class instances
23-
for i, estimator in enumerate(estimator_instances):
24-
# test all data types
25-
for datatype in datatypes[i]:
26-
yield partial(
27-
check_collection_anomaly_detector_output,
28-
estimator=estimator,
29-
datatype=datatype,
30-
)
31-
32-
33-
def check_collection_detector_overrides_and_tags(estimator_class):
34-
"""Test compliance with the detector base class contract."""
35-
# Test valid tag for X_inner_type
36-
X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type")
37-
if isinstance(X_inner_type, str):
38-
assert X_inner_type in COLLECTIONS_DATA_TYPES
39-
else: # must be a list
40-
assert all([t in COLLECTIONS_DATA_TYPES for t in X_inner_type])
41-
42-
# one of X_inner_types must be capable of storing unequal length
43-
if estimator_class.get_class_tag("capability:unequal_length"):
44-
valid_unequal_types = ["np-list", "df-list", "pd-multiindex"]
45-
if isinstance(X_inner_type, str):
46-
assert X_inner_type in valid_unequal_types
47-
else: # must be a list
48-
assert any([t in valid_unequal_types for t in X_inner_type])
49-
50-
51-
def check_collection_anomaly_detector_output(estimator, datatype):
52-
"""Test the collection anomaly detector output on valid data."""
53-
estimator = _clone_estimator(estimator)
54-
55-
estimator.fit(
56-
FULL_TEST_DATA_DICT[datatype]["train"][0],
57-
FULL_TEST_DATA_DICT[datatype]["train"][1],
58-
)
59-
60-
y_pred = estimator.predict(FULL_TEST_DATA_DICT[datatype]["test"][0])
61-
assert isinstance(y_pred, np.ndarray)
62-
# collections need n_cases predictions
63-
assert len(y_pred) == get_n_cases(FULL_TEST_DATA_DICT[datatype]["test"][0])
64-
65-
ot = estimator.get_tag("anomaly_output_type")
66-
if ot == "anomaly_scores":
67-
assert np.issubdtype(y_pred.dtype, np.floating) or np.issubdtype(
68-
y_pred.dtype, np.integer
69-
), "y_pred must be of floating point or int type"
70-
assert not np.array_equal(
71-
np.unique(y_pred), [0, 1]
72-
), "y_pred cannot contain only 0s and 1s"
73-
elif ot == "binary":
74-
assert np.issubdtype(y_pred.dtype, np.integer) or np.issubdtype(
75-
y_pred.dtype, np.bool_
76-
), "y_pred must be of int or bool type for binary output"
77-
assert all(
78-
val in [0, 1] for val in np.unique(y_pred)
79-
), "y_pred must contain only 0s, 1s, True, or False"
80-
else:
81-
raise ValueError(f"Unknown anomaly output type: {ot}")
8+
# nothing currently!
9+
return []

aeon/testing/estimator_checking/_yield_estimator_checks.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from sklearn.exceptions import NotFittedError
1313

1414
from aeon.anomaly_detection.base import BaseAnomalyDetector
15-
from aeon.anomaly_detection.collection.base import BaseCollectionAnomalyDetector
16-
from aeon.anomaly_detection.series.base import BaseSeriesAnomalyDetector
15+
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
1716
from aeon.base import BaseAeonEstimator
1817
from aeon.base._base import _clone_estimator
1918
from aeon.classification import BaseClassifier
@@ -37,6 +36,9 @@
3736
from aeon.testing.estimator_checking._yield_collection_anomaly_detection_checks import (
3837
_yield_collection_anomaly_detection_checks,
3938
)
39+
from aeon.testing.estimator_checking._yield_collection_transformation_checks import (
40+
_yield_collection_transformation_checks,
41+
)
4042
from aeon.testing.estimator_checking._yield_early_classification_checks import (
4143
_yield_early_classification_checks,
4244
)
@@ -148,13 +150,13 @@ def _yield_all_aeon_checks(
148150
estimator_class, estimator_instances, datatypes
149151
)
150152

151-
if issubclass(estimator_class, BaseSeriesAnomalyDetector):
152-
yield from _yield_series_anomaly_detection_checks(
153+
if issubclass(estimator_class, BaseCollectionAnomalyDetector):
154+
yield from _yield_collection_anomaly_detection_checks(
153155
estimator_class, estimator_instances, datatypes
154156
)
155157

156-
if issubclass(estimator_class, BaseCollectionAnomalyDetector):
157-
yield from _yield_collection_anomaly_detection_checks(
158+
if issubclass(estimator_class, BaseSimilaritySearch):
159+
yield from _yield_similarity_search_checks(
158160
estimator_class, estimator_instances, datatypes
159161
)
160162

aeon/testing/testing_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import numpy as np
44

5-
from aeon.anomaly_detection.collection.base import BaseCollectionAnomalyDetector
6-
from aeon.anomaly_detection.series.base import BaseSeriesAnomalyDetector
5+
from aeon.anomaly_detection.base import BaseAnomalyDetector
6+
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
77
from aeon.base import BaseCollectionEstimator, BaseSeriesEstimator
88
from aeon.classification import BaseClassifier
99
from aeon.classification.early_classification import BaseEarlyClassifier
@@ -862,8 +862,8 @@ def _get_task_for_estimator(estimator):
862862
or isinstance(estimator, BaseEarlyClassifier)
863863
or isinstance(estimator, BaseClusterer)
864864
or isinstance(estimator, BaseCollectionTransformer)
865+
or isinstance(estimator, BaseSimilaritySearch)
865866
or isinstance(estimator, BaseCollectionAnomalyDetector)
866-
or isinstance(estimator, BaseCollectionSimilaritySearch)
867867
):
868868
data_label = "Classification"
869869
# collection data with continuous target labels

aeon/utils/base/_register.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
]
1717

1818
from aeon.anomaly_detection.base import BaseAnomalyDetector
19-
from aeon.anomaly_detection.collection.base import BaseCollectionAnomalyDetector
20-
from aeon.anomaly_detection.series.base import BaseSeriesAnomalyDetector
19+
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
2120
from aeon.base import BaseAeonEstimator, BaseCollectionEstimator, BaseSeriesEstimator
2221
from aeon.classification.base import BaseClassifier
2322
from aeon.classification.early_classification import BaseEarlyClassifier
@@ -40,10 +39,7 @@
4039
"series-estimator": BaseSeriesEstimator,
4140
"transformer": BaseTransformer,
4241
"anomaly-detector": BaseAnomalyDetector,
43-
"similarity-search": BaseSimilaritySearch,
44-
# estimator types
4542
"collection-anomaly-detector": BaseCollectionAnomalyDetector,
46-
"collection-similarity-search": BaseCollectionSimilaritySearch,
4743
"collection-transformer": BaseCollectionTransformer,
4844
"classifier": BaseClassifier,
4945
"clusterer": BaseClusterer,

0 commit comments

Comments
 (0)