Skip to content

Commit d777a12

Browse files
fixes
1 parent f6122d7 commit d777a12

File tree

12 files changed

+78
-23
lines changed

12 files changed

+78
-23
lines changed

aeon/anomaly_detection/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ class BaseAnomalyDetector(BaseAeonEstimator):
1717
# todo
1818
}
1919

20-
def __init__(self, axis):
20+
def __init__(self):
2121
super().__init__()
2222

2323
@abstractmethod
24-
def fit(self, X, y=None, axis=1):
24+
def fit(self, X, y=None):
2525
"""Fit time series anomaly detector to X.
2626
2727
If the tag ``fit_is_empty`` is true, this just sets the ``is_fitted`` tag to
@@ -54,7 +54,7 @@ def fit(self, X, y=None, axis=1):
5454
...
5555

5656
@abstractmethod
57-
def predict(self, X, axis=1) -> np.ndarray:
57+
def predict(self, X) -> np.ndarray:
5858
"""Find anomalies in X.
5959
6060
Parameters
@@ -79,7 +79,7 @@ def predict(self, X, axis=1) -> np.ndarray:
7979
...
8080

8181
@abstractmethod
82-
def fit_predict(self, X, y=None, axis=1) -> np.ndarray:
82+
def fit_predict(self, X, y=None) -> np.ndarray:
8383
"""Fit time series anomaly detector and find anomalies for X.
8484
8585
Parameters

aeon/anomaly_detection/collection/base.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ class name: BaseCollectionAnomalyDetector
3232
import numpy as np
3333
import pandas as pd
3434

35+
from aeon.anomaly_detection.base import BaseAnomalyDetector
3536
from aeon.base import BaseCollectionEstimator
3637

3738

38-
class BaseCollectionAnomalyDetector(BaseCollectionEstimator):
39+
class BaseCollectionAnomalyDetector(BaseCollectionEstimator, BaseAnomalyDetector):
3940
"""
4041
Abstract base class for collection anomaly detectors.
4142
@@ -162,12 +163,65 @@ def predict(self, X):
162163

163164
return self._predict(X)
164165

165-
@abstractmethod
166-
def _fit(self, X, y=None): ...
166+
@final
167+
def fit_predict(self, X, y=None, axis=1) -> np.ndarray:
168+
"""Fit time series anomaly detector and find anomalies for X.
169+
170+
Parameters
171+
----------
172+
X : one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES
173+
The time series to fit the model to.
174+
A valid aeon time series data structure. See
175+
aeon.base._base_series.VALID_INPUT_TYPES for aeon supported types.
176+
y : one of aeon.base._base_series.VALID_SERIES_INPUT_TYPES, default=None
177+
The target values for the time series.
178+
A valid aeon time series data structure. See
179+
aeon.base._base_series.VALID_SERIES_INPUT_TYPES for aeon supported types.
180+
axis : int, default=1
181+
The time point axis of the input series if it is 2D. If ``axis==0``, it is
182+
assumed each column is a time series and each row is a time point. i.e. the
183+
shape of the data is ``(n_timepoints, n_channels)``. ``axis==1`` indicates
184+
the time series are in rows, i.e. the shape of the data is
185+
``(n_channels, n_timepoints)``.
186+
187+
Returns
188+
-------
189+
np.ndarray
190+
A boolean, int or float array of length len(X), where each element indicates
191+
whether the corresponding subsequence is anomalous or its anomaly score.
192+
"""
193+
if self.get_tag("requires_y"):
194+
if y is None:
195+
raise ValueError("Tag requires_y is true, but fit called with y=None")
196+
197+
# reset estimator at the start of fit
198+
self.reset()
199+
200+
X = self._preprocess_series(X, axis, True)
201+
202+
if self.get_tag("fit_is_empty"):
203+
self.is_fitted = True
204+
return self._predict(X)
205+
206+
if y is not None:
207+
y = self._check_y(y)
208+
209+
pred = self._fit_predict(X, y)
210+
211+
# this should happen last
212+
self.is_fitted = True
213+
return pred
214+
215+
def _fit(self, X, y):
216+
return self
167217

168218
@abstractmethod
169219
def _predict(self, X): ...
170220

221+
def _fit_predict(self, X, y):
222+
self._fit(X, y)
223+
return self._predict(X)
224+
171225
def _check_y(self, y, n_cases):
172226
"""Check y input is valid.
173227

aeon/anomaly_detection/series/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
import numpy as np
1010
import pandas as pd
1111

12+
from aeon.anomaly_detection.base import BaseAnomalyDetector
1213
from aeon.base import BaseSeriesEstimator
1314
from aeon.base._base_series import VALID_SERIES_INPUT_TYPES
1415

1516

16-
class BaseSeriesAnomalyDetector(BaseSeriesEstimator):
17+
class BaseSeriesAnomalyDetector(BaseSeriesEstimator, BaseAnomalyDetector):
1718
"""Base class for series anomaly detection algorithms.
1819
1920
Anomaly detection algorithms are used to identify anomalous subsequences in time

aeon/anomaly_detection/series/distance_based/_kmeans.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class KMeansAD(BaseSeriesAnomalyDetector):
6565
Examples
6666
--------
6767
>>> import numpy as np
68-
>>> from aeon.anomaly_detection.distance_based import KMeansAD
68+
>>> from aeon.anomaly_detection.series.distance_based import KMeansAD
6969
>>> X = np.array([1, 2, 3, 4, 1, 2, 3, 3, 2, 8, 9, 8, 1, 2, 3, 4], dtype=np.float64)
7070
>>> detector = KMeansAD(n_clusters=3, window_size=4, stride=1, random_state=0)
7171
>>> detector.fit_predict(X)

aeon/anomaly_detection/series/distance_based/_merlin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class MERLIN(BaseSeriesAnomalyDetector):
4343
Examples
4444
--------
4545
>>> import numpy as np
46-
>>> from aeon.anomaly_detection.distance_based import MERLIN
46+
>>> from aeon.anomaly_detection.series.distance_based import MERLIN
4747
>>> X = np.array([1, 2, 3, 4, 1, 2, 3, 4, 2, 3, 4, 5, 1, 2, 3, 4])
4848
>>> detector = MERLIN(min_length=4, max_length=5)
4949
>>> detector.fit_predict(X)

aeon/anomaly_detection/series/distribution_based/_dwt_mlead.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class DWT_MLEAD(BaseSeriesAnomalyDetector):
7878
Examples
7979
--------
8080
>>> import numpy as np
81-
>>> from aeon.anomaly_detection.distribution_based import DWT_MLEAD
81+
>>> from aeon.anomaly_detection.series.distribution_based import DWT_MLEAD
8282
>>> X = np.array([1, 2, 3, 4, 1, 2, 3, 3, 2, 8, 9, 8, 1, 2, 3, 4], dtype=np.float64)
8383
>>> detector = DWT_MLEAD(
8484
... start_level=1, quantile_boundary_type='percentile', quantile_epsilon=0.01

aeon/anomaly_detection/series/outlier_detection/_stray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class STRAY(BaseSeriesAnomalyDetector):
5454
5555
Examples
5656
--------
57-
>>> from aeon.anomaly_detection.outlier_detection import STRAY
57+
>>> from aeon.anomaly_detection.series.outlier_detection import STRAY
5858
>>> from aeon.datasets import load_airline
5959
>>> import numpy as np
6060
>>> X = load_airline()

aeon/anomaly_detection/series/tests/test_pyod_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
from sklearn.utils import check_random_state
88

9-
from aeon.anomaly_detection.series.outlier_detection import PyODAdapter
9+
from aeon.anomaly_detection.series import PyODAdapter
1010
from aeon.utils.validation._dependencies import _check_soft_dependencies
1111

1212

aeon/testing/estimator_checking/_yield_estimator_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import numpy as np
1212
from sklearn.exceptions import NotFittedError
1313

14+
from aeon.anomaly_detection.collection.base import BaseCollectionAnomalyDetector
1415
from aeon.anomaly_detection.series.base import BaseSeriesAnomalyDetector
15-
from aeon.anomaly_detection.whole_series.base import BaseCollectionAnomalyDetector
1616
from aeon.base import BaseAeonEstimator
1717
from aeon.base._base import _clone_estimator
1818
from aeon.classification import BaseClassifier

aeon/utils/base/_identifier.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,16 @@ def get_identifier(estimator):
4747
if len(identifiers) == 0:
4848
raise TypeError("Error, no identifiers could be determined for estimator")
4949

50-
if len(identifiers) > 1 and "anomaly-detector" in identifiers:
51-
identifiers.remove("anomaly-detector")
5250
if len(identifiers) > 1 and "estimator" in identifiers:
5351
identifiers.remove("estimator")
54-
if len(identifiers) > 1 and "series-estimator" in identifiers:
55-
identifiers.remove("series-estimator")
5652
if len(identifiers) > 1 and "collection-estimator" in identifiers:
5753
identifiers.remove("collection-estimator")
54+
if len(identifiers) > 1 and "series-estimator" in identifiers:
55+
identifiers.remove("series-estimator")
5856
if len(identifiers) > 1 and "transformer" in identifiers:
5957
identifiers.remove("transformer")
58+
if len(identifiers) > 1 and "anomaly-detector" in identifiers:
59+
identifiers.remove("anomaly-detector")
6060

6161
if len(identifiers) > 1:
6262
TypeError(

0 commit comments

Comments
 (0)