Skip to content

Commit 59fb9a4

Browse files
MatthewMiddlehurstpattplatt
authored andcommitted
base docs
1 parent c2036fc commit 59fb9a4

File tree

1 file changed

+111
-4
lines changed
  • aeon/anomaly_detection/whole_series

1 file changed

+111
-4
lines changed

aeon/anomaly_detection/whole_series/base.py

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,27 @@
1-
"""Abstract base class for whole-series/collection anomaly detectors."""
1+
"""
2+
Abstract base class for whole-series/collection anomaly detectors.
3+
4+
class name: BaseCollectionAnomalyDetector
5+
6+
Defining methods:
7+
fitting - fit(self, X, y)
8+
predicting - predict(self, X)
9+
10+
Data validation:
11+
data processing - _preprocess_collection(self, X, store_metadata=True)
12+
shape verification - _check_shape(self, X)
13+
14+
State:
15+
fitted model/strategy - by convention, any attributes ending in "_"
16+
fitted state flag - is_fitted
17+
train input metadata - metadata_
18+
resetting state - reset(self)
19+
20+
Tags:
21+
default estimator tags - _tags
22+
tag retrieval - get_tag(self, tag_name)
23+
tag setting - set_tag(self, tag_name, value)
24+
"""
225

326
__maintainer__ = ["MatthewMiddlehurst"]
427
__all__ = ["BaseCollectionAnomalyDetector"]
@@ -13,7 +36,23 @@
1336

1437

1538
class BaseCollectionAnomalyDetector(BaseCollectionEstimator):
16-
"""Collection anomaly detector base class."""
39+
"""
40+
Abstract base class for collection anomaly detectors.
41+
42+
The base detector specifies the methods and method signatures that all
43+
collection anomaly detectors have to implement. Attributes with an underscore
44+
suffix are set in the method fit.
45+
46+
Attributes
47+
----------
48+
is_fitted : bool
49+
True if the estimator has been fitted, False otherwise.
50+
Unused if ``"fit_is_empty"`` tag is set to True.
51+
metadata_ : dict
52+
Dictionary containing metadata about the `fit` input data.
53+
_tags_dynamic : dict
54+
Dictionary containing dynamic tag values which have been set at runtime.
55+
"""
1756

1857
_tags = {
1958
"fit_is_empty": False,
@@ -25,7 +64,42 @@ def __init__(self):
2564

2665
@final
2766
def fit(self, X, y=None):
28-
"""Fit."""
67+
"""Fit collection anomaly detector to training data.
68+
69+
Parameters
70+
----------
71+
X : np.ndarray or list
72+
Input data, any number of channels, equal length series of shape ``(
73+
n_cases, n_channels, n_timepoints)``
74+
or 2D np.array (univariate, equal length series) of shape
75+
``(n_cases, n_timepoints)``
76+
or list of numpy arrays (any number of channels, unequal length series)
77+
of shape ``[n_cases]``, 2D np.array ``(n_channels, n_timepoints_i)``,
78+
where ``n_timepoints_i`` is length of series ``i``. Other types are
79+
allowed and converted into one of the above.
80+
81+
Different estimators have different capabilities to handle different
82+
types of input. If ``self.get_tag("capability:multivariate")`` is False,
83+
they cannot handle multivariate series, so either ``n_channels == 1`` is
84+
true or X is 2D of shape ``(n_cases, n_timepoints)``. If ``self.get_tag(
85+
"capability:unequal_length")`` is False, they cannot handle unequal
86+
length input. In both situations, a ``ValueError`` is raised if X has a
87+
characteristic that the estimator does not have the capability for is
88+
passed.
89+
y : np.ndarray
90+
1D np.array of int, of shape ``(n_cases)`` - anomaly labels
91+
(ground truth) for fitting indices corresponding to instance indices in X.
92+
93+
Returns
94+
-------
95+
self : BaseCollectionAnomalyDetector
96+
Reference to self.
97+
98+
Notes
99+
-----
100+
Changes state by creating a fitted model that updates attributes
101+
ending in "_" and sets is_fitted flag to True.
102+
"""
29103
if self.get_tag("fit_is_empty"):
30104
self.is_fitted = True
31105
return self
@@ -49,7 +123,35 @@ def fit(self, X, y=None):
49123

50124
@final
51125
def predict(self, X):
52-
"""Predict."""
126+
"""Predicts anomalies for time series in X.
127+
128+
Parameters
129+
----------
130+
X : np.ndarray or list
131+
Input data, any number of channels, equal length series of shape ``(
132+
n_cases, n_channels, n_timepoints)``
133+
or 2D np.array (univariate, equal length series) of shape
134+
``(n_cases, n_timepoints)``
135+
or list of numpy arrays (any number of channels, unequal length series)
136+
of shape ``[n_cases]``, 2D np.array ``(n_channels, n_timepoints_i)``,
137+
where ``n_timepoints_i`` is length of series ``i``
138+
other types are allowed and converted into one of the above.
139+
140+
Different estimators have different capabilities to handle different
141+
types of input. If ``self.get_tag("capability:multivariate")`` is False,
142+
they cannot handle multivariate series, so either ``n_channels == 1`` is
143+
true or X is 2D of shape ``(n_cases, n_timepoints)``. If ``self.get_tag(
144+
"capability:unequal_length")`` is False, they cannot handle unequal
145+
length input. In both situations, a ``ValueError`` is raised if X has a
146+
characteristic that the estimator does not have the capability for is
147+
passed.
148+
149+
Returns
150+
-------
151+
predictions : np.ndarray
152+
1D np.array of float, of shape (n_cases) - predicted anomalies
153+
indices correspond to instance indices in X
154+
"""
53155
fit_empty = self.get_tag("fit_is_empty")
54156
if not fit_empty:
55157
self._check_is_fitted()
@@ -67,6 +169,11 @@ def _fit(self, X, y=None): ...
67169
def _predict(self, X): ...
68170

69171
def _check_y(self, y, n_cases):
172+
"""Check y input is valid.
173+
174+
Must be 1-dimensional and contain only 0s (no anomaly) and 1s (anomaly).
175+
Must match the number of cases in X.
176+
"""
70177
if not isinstance(y, (pd.Series, np.ndarray)):
71178
raise TypeError(
72179
f"y must be a np.array or a pd.Series, but found type: {type(y)}"

0 commit comments

Comments
 (0)