|
1 |
| -"""Basic outlier detection classifier.""" |
| 1 | +"""Adapter to use outlier detection algorithms for collection anomaly detection.""" |
2 | 2 |
|
| 3 | +__maintainer__ = [] |
| 4 | + |
| 5 | +from sklearn.base import OutlierMixin |
3 | 6 | from sklearn.ensemble import IsolationForest
|
4 | 7 |
|
5 | 8 | from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
|
6 | 9 | from aeon.base._base import _clone_estimator
|
7 | 10 |
|
8 | 11 |
|
9 |
| -class OutlierDetectionClassifier(BaseCollectionAnomalyDetector): |
10 |
| - """Basic outlier detection classifier.""" |
| 12 | +class OutlierDetectionAdapter(BaseCollectionAnomalyDetector): |
| 13 | + """ |
| 14 | + Basic outlier detection adapter for collection anomaly detection. |
| 15 | +
|
| 16 | + This class wraps an sklearn outlier detection algorithm to be used as an anomaly |
| 17 | + detector. |
| 18 | +
|
| 19 | + Parameters |
| 20 | + ---------- |
| 21 | + detector : OutlierMixin |
| 22 | + The outlier detection algorithm to be adapted. |
| 23 | + random_state : int, RandomState instance or None, default=None |
| 24 | + If `int`, random_state is the seed used by the random number generator; |
| 25 | + If `RandomState` instance, random_state is the random number generator; |
| 26 | + If `None`, the random number generator is the `RandomState` instance used |
| 27 | + by `np.random`. |
| 28 | + """ |
11 | 29 |
|
12 | 30 | _tags = {
|
13 | 31 | "X_inner_type": "numpy2D",
|
14 | 32 | }
|
15 | 33 |
|
16 |
| - def __init__(self, estimator, random_state=None): |
17 |
| - self.estimator = estimator |
| 34 | + def __init__(self, detector, random_state=None): |
| 35 | + self.detector = detector |
18 | 36 | self.random_state = random_state
|
19 | 37 |
|
20 | 38 | super().__init__()
|
21 | 39 |
|
22 | 40 | def _fit(self, X, y=None):
|
23 |
| - self.estimator_ = _clone_estimator( |
24 |
| - self.estimator, random_state=self.random_state |
25 |
| - ) |
26 |
| - self.estimator_.fit(X, y) |
| 41 | + if not isinstance(self.detector, OutlierMixin): |
| 42 | + raise ValueError( |
| 43 | + "The estimator must be an outlier detection algorithm " |
| 44 | + "that implements the OutlierMixin interface." |
| 45 | + ) |
| 46 | + |
| 47 | + self.detector_ = _clone_estimator(self.detector, random_state=self.random_state) |
| 48 | + self.detector_.fit(X, y) |
27 | 49 | return self
|
28 | 50 |
|
29 | 51 | def _predict(self, X):
|
30 |
| - pred = self.estimator_.predict(X) |
| 52 | + pred = self.detector_.predict(X) |
31 | 53 | pred[pred == -1] = 0
|
32 | 54 | return pred
|
33 | 55 |
|
|
0 commit comments