Skip to content

Commit f6122d7

Browse files
Revert "smoothing refactor"
This reverts commit c245ccb.
1 parent c245ccb commit f6122d7

20 files changed

+312
-582
lines changed

aeon/transformations/series/_dft.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,12 @@
44
__all__ = ["DFTSeriesTransformer"]
55

66

7-
from deprecated.sphinx import deprecated
7+
import numpy as np
88

9-
from aeon.transformations.series.smoothing import DiscreteFourierApproximation
9+
from aeon.transformations.series.base import BaseSeriesTransformer
1010

1111

12-
# TODO: Remove in v1.3.0
13-
@deprecated(
14-
version="1.2.0",
15-
reason="DFTSeriesTransformer is deprecated and will be removed in v1.3.0. "
16-
"Please use DiscreteFourierApproximation from "
17-
"transformations.series.smoothing instead.",
18-
category=FutureWarning,
19-
)
20-
class DFTSeriesTransformer(DiscreteFourierApproximation):
12+
class DFTSeriesTransformer(BaseSeriesTransformer):
2113
"""Filter a times series using Discrete Fourier Approximation (DFT).
2214
2315
Parameters
@@ -50,4 +42,47 @@ class DFTSeriesTransformer(DiscreteFourierApproximation):
5042
(2, 100)
5143
"""
5244

53-
...
45+
_tags = {
46+
"capability:multivariate": True,
47+
"X_inner_type": "np.ndarray",
48+
"fit_is_empty": True,
49+
}
50+
51+
def __init__(self, r=0.5, sort=False):
52+
self.r = r
53+
self.sort = sort
54+
super().__init__(axis=1)
55+
56+
def _transform(self, X, y=None):
57+
"""Transform X and return a transformed version.
58+
59+
Parameters
60+
----------
61+
X : np.ndarray
62+
time series in shape (n_channels, n_timepoints)
63+
y : ignored argument for interface compatibility
64+
65+
Returns
66+
-------
67+
transformed version of X
68+
"""
69+
# Compute DFT
70+
dft = np.fft.fft(X)
71+
72+
# Mask array of terms to keep and number of terms to keep
73+
mask = np.zeros_like(dft, dtype=bool)
74+
keep = max(int(self.r * dft.shape[1]), 1)
75+
76+
# If sort is set, sort the indices by the decreasing dft amplitude
77+
if self.sort:
78+
sorted_indices = np.argsort(np.abs(dft))[:, ::-1]
79+
for i in range(dft.shape[0]):
80+
mask[i, sorted_indices[i, 0:keep]] = True
81+
# Else, keep the first terms
82+
else:
83+
mask[:, 0:keep] = True
84+
85+
# Invert DFT with masked terms
86+
X_ = np.fft.ifft(dft * mask).real
87+
88+
return X_

aeon/transformations/series/_exp_smoothing.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,14 @@
33
__maintainer__ = ["Datadote"]
44
__all__ = ["ExpSmoothingSeriesTransformer"]
55

6+
from typing import Union
67

7-
from deprecated.sphinx import deprecated
8+
import numpy as np
89

9-
from aeon.transformations.series.smoothing import ExponentialSmoothing
10+
from aeon.transformations.series.base import BaseSeriesTransformer
1011

1112

12-
# TODO: Remove in v1.3.0
13-
@deprecated(
14-
version="1.2.0",
15-
reason="ExpSmoothingSeriesTransformer is deprecated and will be removed in v1.3.0. "
16-
"Please use ExponentialSmoothing from "
17-
"transformations.series.smoothing instead.",
18-
category=FutureWarning,
19-
)
20-
class ExpSmoothingSeriesTransformer(ExponentialSmoothing):
13+
class ExpSmoothingSeriesTransformer(BaseSeriesTransformer):
2114
"""Filter a time series using exponential smoothing.
2215
2316
- Exponential smoothing (EXP) is a generalisaton of moving average smoothing that
@@ -61,4 +54,42 @@ class ExpSmoothingSeriesTransformer(ExponentialSmoothing):
6154
[10. 9.5 8.75 7.875]]
6255
"""
6356

64-
...
57+
_tags = {
58+
"capability:multivariate": True,
59+
"X_inner_type": "np.ndarray",
60+
"fit_is_empty": True,
61+
}
62+
63+
def __init__(
64+
self, alpha: float = 0.2, window_size: Union[int, float, None] = None
65+
) -> None:
66+
if not 0 <= alpha <= 1:
67+
raise ValueError(f"alpha must be in range [0, 1], got {alpha}")
68+
if window_size is not None and window_size <= 0:
69+
raise ValueError(f"window_size must be > 0, got {window_size}")
70+
super().__init__(axis=1)
71+
self.alpha = alpha if window_size is None else 2.0 / (window_size + 1)
72+
self.window_size = window_size
73+
74+
def _transform(self, X, y=None):
75+
"""Transform X and return a transformed version.
76+
77+
private _transform containing core logic, called from transform
78+
79+
Parameters
80+
----------
81+
X : np.ndarray
82+
Data to be transformed
83+
y : ignored argument for interface compatibility
84+
Additional data, e.g., labels for transformation
85+
86+
Returns
87+
-------
88+
Xt: 2D np.ndarray
89+
transformed version of X
90+
"""
91+
Xt = np.zeros_like(X, dtype="float")
92+
Xt[:, 0] = X[:, 0]
93+
for i in range(1, Xt.shape[1]):
94+
Xt[:, i] = self.alpha * X[:, i] + (1 - self.alpha) * Xt[:, i - 1]
95+
return Xt

aeon/transformations/series/_gauss.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,12 @@
44
__all__ = ["GaussSeriesTransformer"]
55

66

7-
from deprecated.sphinx import deprecated
7+
from scipy.ndimage import gaussian_filter1d
88

9-
from aeon.transformations.series.smoothing import GaussianFilter
9+
from aeon.transformations.series.base import BaseSeriesTransformer
1010

1111

12-
# TODO: Remove in v1.3.0
13-
@deprecated(
14-
version="1.2.0",
15-
reason="GaussSeriesTransformer is deprecated and will be removed in v1.3.0. "
16-
"Please use GaussianFilter from "
17-
"transformations.series.smoothing instead.",
18-
category=FutureWarning,
19-
)
20-
class GaussSeriesTransformer(GaussianFilter):
12+
class GaussSeriesTransformer(BaseSeriesTransformer):
2113
"""Filter a times series using Gaussian filter.
2214
2315
Parameters
@@ -53,4 +45,31 @@ class GaussSeriesTransformer(GaussianFilter):
5345
(2, 100)
5446
"""
5547

56-
...
48+
_tags = {
49+
"capability:multivariate": True,
50+
"X_inner_type": "np.ndarray",
51+
"fit_is_empty": True,
52+
}
53+
54+
def __init__(self, sigma=1, order=0):
55+
self.sigma = sigma
56+
self.order = order
57+
super().__init__(axis=1)
58+
59+
def _transform(self, X, y=None):
60+
"""Transform X and return a transformed version.
61+
62+
Parameters
63+
----------
64+
X : np.ndarray
65+
time series in shape (n_channels, n_timepoints)
66+
y : ignored argument for interface compatibility
67+
68+
Returns
69+
-------
70+
transformed version of X
71+
"""
72+
# Compute Gaussian filter
73+
X_ = gaussian_filter1d(X, self.sigma, self.axis, self.order)
74+
75+
return X_

aeon/transformations/series/_moving_average.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,12 @@
33
__maintainer__ = ["Datadote"]
44
__all__ = ["MovingAverageSeriesTransformer"]
55

6+
import numpy as np
67

7-
from deprecated.sphinx import deprecated
8+
from aeon.transformations.series.base import BaseSeriesTransformer
89

9-
from aeon.transformations.series.smoothing import MovingAverage
1010

11-
12-
# TODO: Remove in v1.3.0
13-
@deprecated(
14-
version="1.2.0",
15-
reason="MovingAverageSeriesTransformer is deprecated and will be removed in "
16-
"v1.3.0. Please use MovingAverage from "
17-
"transformations.series.smoothing instead.",
18-
category=FutureWarning,
19-
)
20-
class MovingAverageSeriesTransformer(MovingAverage):
11+
class MovingAverageSeriesTransformer(BaseSeriesTransformer):
2112
"""Calculate the moving average of an array of numbers.
2213
2314
Slides a window across the input array, and returns the averages for each window.
@@ -50,4 +41,38 @@ class MovingAverageSeriesTransformer(MovingAverage):
5041
[[-2.5 -1.5 -0.5 0.5 1.5 2.5]]
5142
"""
5243

53-
...
44+
_tags = {
45+
"capability:multivariate": True,
46+
"X_inner_type": "np.ndarray",
47+
"fit_is_empty": True,
48+
}
49+
50+
def __init__(self, window_size: int = 5) -> None:
51+
super().__init__(axis=0)
52+
if window_size <= 0:
53+
raise ValueError(f"window_size must be > 0, got {window_size}")
54+
self.window_size = window_size
55+
56+
def _transform(self, X, y=None):
57+
"""Transform X and return a transformed version.
58+
59+
private _transform containing core logic, called from transform
60+
61+
Parameters
62+
----------
63+
X : np.ndarray
64+
Data to be transformed
65+
y : ignored argument for interface compatibility
66+
Additional data, e.g., labels for transformation
67+
68+
Returns
69+
-------
70+
Xt: 2D np.ndarray
71+
transformed version of X
72+
"""
73+
csum = np.cumsum(X, axis=0)
74+
csum[self.window_size :, :] = (
75+
csum[self.window_size :, :] - csum[: -self.window_size, :]
76+
)
77+
Xt = csum[self.window_size - 1 :, :] / self.window_size
78+
return Xt

aeon/transformations/series/_sg.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,12 @@
44
__all__ = ["SGSeriesTransformer"]
55

66

7-
from deprecated.sphinx import deprecated
7+
from scipy.signal import savgol_filter
88

9-
from aeon.transformations.series.smoothing import SavitzkyGolayFilter
9+
from aeon.transformations.series.base import BaseSeriesTransformer
1010

1111

12-
# TODO: Remove in v1.3.0
13-
@deprecated(
14-
version="1.2.0",
15-
reason="SGSeriesTransformer is deprecated and will be removed in v1.3.0. "
16-
"Please use SavitzkyGolayFilter from "
17-
"transformations.series.smoothing instead.",
18-
category=FutureWarning,
19-
)
20-
class SGSeriesTransformer(SavitzkyGolayFilter):
12+
class SGSeriesTransformer(BaseSeriesTransformer):
2113
"""Filter a times series using Savitzky-Golay (SG).
2214
2315
Parameters
@@ -53,4 +45,31 @@ class SGSeriesTransformer(SavitzkyGolayFilter):
5345
(2, 100)
5446
"""
5547

56-
...
48+
_tags = {
49+
"capability:multivariate": True,
50+
"X_inner_type": "np.ndarray",
51+
"fit_is_empty": True,
52+
}
53+
54+
def __init__(self, window_length=5, polyorder=2):
55+
self.window_length = window_length
56+
self.polyorder = polyorder
57+
super().__init__(axis=1)
58+
59+
def _transform(self, X, y=None):
60+
"""Transform X and return a transformed version.
61+
62+
Parameters
63+
----------
64+
X : np.ndarray
65+
time series in shape (n_channels, n_timepoints)
66+
y : ignored argument for interface compatibility
67+
68+
Returns
69+
-------
70+
transformed version of X
71+
"""
72+
# Compute SG
73+
X_ = savgol_filter(X, self.window_length, self.polyorder)
74+
75+
return X_

aeon/transformations/series/_siv.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,13 @@
44
__all__ = ["SIVSeriesTransformer"]
55

66

7-
from deprecated.sphinx import deprecated
7+
import numpy as np
8+
from scipy.ndimage import median_filter
89

9-
from aeon.transformations.series.smoothing import RecursiveMedianSieve
10+
from aeon.transformations.series.base import BaseSeriesTransformer
1011

1112

12-
# TODO: Remove in v1.3.0
13-
@deprecated(
14-
version="1.2.0",
15-
reason="SIVSeriesTransformer is deprecated and will be removed in v1.3.0. "
16-
"Please use RecursiveMedianSieve from "
17-
"transformations.series.smoothing instead.",
18-
category=FutureWarning,
19-
)
20-
class SIVSeriesTransformer(RecursiveMedianSieve):
13+
class SIVSeriesTransformer(BaseSeriesTransformer):
2114
"""Filter a times series using Recursive Median Sieve (SIV).
2215
2316
Parameters
@@ -55,4 +48,40 @@ class SIVSeriesTransformer(RecursiveMedianSieve):
5548
(2, 100)
5649
"""
5750

58-
...
51+
_tags = {
52+
"capability:multivariate": True,
53+
"X_inner_type": "np.ndarray",
54+
"fit_is_empty": True,
55+
}
56+
57+
def __init__(self, window_length=None):
58+
self.window_length = window_length
59+
super().__init__(axis=1)
60+
61+
def _transform(self, X, y=None):
62+
"""Transform X and return a transformed version.
63+
64+
Parameters
65+
----------
66+
X : np.ndarray
67+
time series in shape (n_channels, n_timepoints)
68+
y : ignored argument for interface compatibility
69+
70+
Returns
71+
-------
72+
transformed version of X
73+
"""
74+
window_length = self.window_length
75+
if window_length is None:
76+
window_length = [3, 5, 7]
77+
if not isinstance(window_length, list):
78+
window_length = [window_length]
79+
80+
# Compute SIV
81+
X_ = X
82+
83+
for w in window_length:
84+
footprint = np.ones((1, w))
85+
X_ = median_filter(X_, footprint=footprint)
86+
87+
return X_

0 commit comments

Comments
 (0)