Skip to content

Commit b6f0a30

Browse files
MatthewMiddlehurstpattplatt
authored andcommitted
wrappers
1 parent 59fb9a4 commit b6f0a30

File tree

2 files changed

+97
-10
lines changed

2 files changed

+97
-10
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Adapter to use classification algorithms for collection anomaly detection."""
2+
3+
__maintainer__ = []
4+
5+
6+
from sklearn.base import ClassifierMixin
7+
from sklearn.ensemble import RandomForestClassifier
8+
9+
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
10+
from aeon.base._base import _clone_estimator
11+
from aeon.classification.feature_based import SummaryClassifier
12+
13+
14+
class ClassificationAdapter(BaseCollectionAnomalyDetector):
15+
"""
16+
Basic classifier adapter for collection anomaly detection.
17+
18+
This class wraps a classification algorithm to be used as an anomaly detector.
19+
Anomaly labels are required for training.
20+
21+
Parameters
22+
----------
23+
classifier : aeon classifier or ClassifierMixin
24+
The classification algorithm to be adapted.
25+
random_state : int, RandomState instance or None, default=None
26+
If `int`, random_state is the seed used by the random number generator;
27+
If `RandomState` instance, random_state is the random number generator;
28+
If `None`, the random number generator is the `RandomState` instance used
29+
by `np.random`.
30+
"""
31+
32+
_tags = {
33+
"X_inner_type": "numpy2D",
34+
"requires_y": True,
35+
}
36+
37+
def __init__(self, classifier, random_state=None):
38+
self.classifier = classifier
39+
self.random_state = random_state
40+
41+
super().__init__()
42+
43+
def _fit(self, X, y=None):
44+
if not isinstance(self.classifier, ClassifierMixin):
45+
raise ValueError(
46+
"The estimator must be an aeon classification algorithm "
47+
"or class that implements the ClassifierMixin interface."
48+
)
49+
50+
self.classifier_ = _clone_estimator(
51+
self.classifier, random_state=self.random_state
52+
)
53+
self.classifier_.fit(X, y)
54+
return self
55+
56+
def _predict(self, X):
57+
return self.classifier_.predict(X)
58+
59+
@classmethod
60+
def _get_test_params(cls, parameter_set="default"):
61+
return {
62+
"estimator": SummaryClassifier(
63+
estimator=RandomForestClassifier(n_estimators=5)
64+
)
65+
}

aeon/anomaly_detection/whole_series/_outlier_detection.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,55 @@
1-
"""Basic outlier detection classifier."""
1+
"""Adapter to use outlier detection algorithms for collection anomaly detection."""
22

3+
__maintainer__ = []
4+
5+
from sklearn.base import OutlierMixin
36
from sklearn.ensemble import IsolationForest
47

58
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
69
from aeon.base._base import _clone_estimator
710

811

9-
class OutlierDetectionClassifier(BaseCollectionAnomalyDetector):
10-
"""Basic outlier detection classifier."""
12+
class OutlierDetectionAdapter(BaseCollectionAnomalyDetector):
13+
"""
14+
Basic outlier detection adapter for collection anomaly detection.
15+
16+
This class wraps an sklearn outlier detection algorithm to be used as an anomaly
17+
detector.
18+
19+
Parameters
20+
----------
21+
detector : OutlierMixin
22+
The outlier detection algorithm to be adapted.
23+
random_state : int, RandomState instance or None, default=None
24+
If `int`, random_state is the seed used by the random number generator;
25+
If `RandomState` instance, random_state is the random number generator;
26+
If `None`, the random number generator is the `RandomState` instance used
27+
by `np.random`.
28+
"""
1129

1230
_tags = {
1331
"X_inner_type": "numpy2D",
1432
}
1533

16-
def __init__(self, estimator, random_state=None):
17-
self.estimator = estimator
34+
def __init__(self, detector, random_state=None):
35+
self.detector = detector
1836
self.random_state = random_state
1937

2038
super().__init__()
2139

2240
def _fit(self, X, y=None):
23-
self.estimator_ = _clone_estimator(
24-
self.estimator, random_state=self.random_state
25-
)
26-
self.estimator_.fit(X, y)
41+
if not isinstance(self.detector, OutlierMixin):
42+
raise ValueError(
43+
"The estimator must be an outlier detection algorithm "
44+
"that implements the OutlierMixin interface."
45+
)
46+
47+
self.detector_ = _clone_estimator(self.detector, random_state=self.random_state)
48+
self.detector_.fit(X, y)
2749
return self
2850

2951
def _predict(self, X):
30-
pred = self.estimator_.predict(X)
52+
pred = self.detector_.predict(X)
3153
pred[pred == -1] = 0
3254
return pred
3355

0 commit comments

Comments
 (0)