Skip to content

Commit 8f1830a

Browse files
authored
[ENH] remove horizon from ETS constructor, parametrize ETS tests (#2898)
* remove horizon from ETS argument, parametrize tests * remove horizon from example * docstring
1 parent c86d391 commit 8f1830a

File tree

2 files changed

+64
-85
lines changed

2 files changed

+64
-85
lines changed

aeon/forecasting/_ets.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ class ETSForecaster(BaseForecaster):
4444
Seasonal smoothing parameter.
4545
phi : float, default=0.99
4646
Trend damping parameter (used only for damped trend models).
47-
horizon : int, default=1
48-
Forecasting horizon (number of time steps ahead to predict).
4947
5048
Attributes
5149
----------
@@ -82,7 +80,7 @@ class ETSForecaster(BaseForecaster):
8280
>>> from aeon.datasets import load_airline
8381
>>> y = load_airline()
8482
>>> forecaster = ETSForecaster(
85-
... alpha=0.4, beta=0.2, gamma=0.5, phi=0.8, horizon=1,
83+
... alpha=0.4, beta=0.2, gamma=0.5, phi=0.8,
8684
... error_type='additive', trend_type='multiplicative',
8785
... seasonality_type='multiplicative', seasonal_period=4
8886
... )
@@ -106,7 +104,6 @@ def __init__(
106104
beta: float = 0.01,
107105
gamma: float = 0.01,
108106
phi: float = 0.99,
109-
horizon: int = 1,
110107
):
111108
self.alpha = alpha
112109
self.beta = beta
@@ -130,7 +127,7 @@ def __init__(
130127
self.aic_ = 0
131128
self.residuals_ = []
132129
self.fitted_values_ = []
133-
super().__init__(horizon=horizon, axis=1)
130+
super().__init__(horizon=1, axis=1)
134131

135132
def _fit(self, y, exog=None):
136133
"""Fit Exponential Smoothing forecaster to series y.

aeon/forecasting/tests/test_ets.py

Lines changed: 62 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -10,88 +10,70 @@
1010
from aeon.forecasting._ets import _validate_parameter
1111

1212

13-
def test_ets_forecaster_additive():
14-
"""TestETSForecaster."""
15-
data = np.array(
16-
[3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12]
17-
) # Sample seasonal data
18-
forecaster = ETSForecaster(
19-
alpha=0.5,
20-
beta=0.3,
21-
gamma=0.4,
22-
phi=1,
23-
horizon=1,
24-
error_type="additive",
25-
trend_type="additive",
26-
seasonality_type="additive",
27-
seasonal_period=4,
28-
)
13+
@pytest.mark.parametrize(
14+
"params, expected",
15+
[
16+
(
17+
dict(
18+
alpha=0.5,
19+
beta=0.3,
20+
gamma=0.4,
21+
phi=1,
22+
error_type="additive",
23+
trend_type="additive",
24+
seasonality_type="additive",
25+
seasonal_period=4,
26+
),
27+
9.191190608800001,
28+
),
29+
(
30+
dict(
31+
alpha=0.7,
32+
beta=0.6,
33+
gamma=0.1,
34+
phi=0.97,
35+
error_type="multiplicative",
36+
trend_type="additive",
37+
seasonality_type="additive",
38+
seasonal_period=4,
39+
),
40+
16.20176819429869,
41+
),
42+
(
43+
dict(
44+
alpha=0.4,
45+
beta=0.2,
46+
gamma=0.5,
47+
phi=0.8,
48+
error_type="additive",
49+
trend_type="multiplicative",
50+
seasonality_type="multiplicative",
51+
seasonal_period=4,
52+
),
53+
12.301259229712382,
54+
),
55+
(
56+
dict(
57+
alpha=0.7,
58+
beta=0.5,
59+
gamma=0.2,
60+
phi=0.85,
61+
error_type="multiplicative",
62+
trend_type="multiplicative",
63+
seasonality_type="multiplicative",
64+
seasonal_period=4,
65+
),
66+
16.811888294476528,
67+
),
68+
],
69+
)
70+
def test_ets_forecaster(params, expected):
71+
"""Test ETSForecaster for multiple parameter combinations."""
72+
data = np.array([3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12])
73+
forecaster = ETSForecaster(**params)
2974
forecaster.fit(data)
3075
p = forecaster.predict()
31-
assert np.isclose(p, 9.191190608800001)
32-
33-
34-
def test_ets_forecaster_mult_error():
35-
"""TestETSForecaster."""
36-
data = np.array(
37-
[3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12]
38-
) # Sample seasonal data
39-
forecaster = ETSForecaster(
40-
alpha=0.7,
41-
beta=0.6,
42-
gamma=0.1,
43-
phi=0.97,
44-
horizon=1,
45-
error_type="multiplicative",
46-
trend_type="additive",
47-
seasonality_type="additive",
48-
seasonal_period=4,
49-
)
50-
forecaster.fit(data)
51-
p = forecaster.predict()
52-
assert np.isclose(p, 16.20176819429869)
53-
54-
55-
def test_ets_forecaster_mult_compnents():
56-
"""TestETSForecaster."""
57-
data = np.array(
58-
[3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12]
59-
) # Sample seasonal data
60-
forecaster = ETSForecaster(
61-
alpha=0.4,
62-
beta=0.2,
63-
gamma=0.5,
64-
phi=0.8,
65-
horizon=1,
66-
error_type="additive",
67-
trend_type="multiplicative",
68-
seasonality_type="multiplicative",
69-
seasonal_period=4,
70-
)
71-
forecaster.fit(data)
72-
p = forecaster.predict()
73-
assert np.isclose(p, 12.301259229712382)
74-
75-
76-
def test_ets_forecaster_multiplicative():
77-
"""TestETSForecaster."""
78-
data = np.array(
79-
[3, 10, 12, 13, 12, 10, 12, 3, 10, 12, 13, 12, 10, 12]
80-
) # Sample seasonal data
81-
forecaster = ETSForecaster(
82-
alpha=0.7,
83-
beta=0.5,
84-
gamma=0.2,
85-
phi=0.85,
86-
horizon=1,
87-
error_type="multiplicative",
88-
trend_type="multiplicative",
89-
seasonality_type="multiplicative",
90-
seasonal_period=4,
91-
)
92-
forecaster.fit(data)
93-
p = forecaster.predict()
94-
assert np.isclose(p, 16.811888294476528)
76+
assert np.isclose(p, expected)
9577

9678

9779
def test_incorrect_parameters():

0 commit comments

Comments
 (0)