Skip to content

Commit 547e925

Browse files
TonyBagnallAlex Banwell
andauthored
[ENH] Rework ETS Forecaster (#2939)
* iterative * iterative ETS * notebook * Fix bug in seasonality calculation * Add example for ETS iterative forecasting * ets tests * holding notebook * example * Run Itertive notebook --------- Co-authored-by: Alex Banwell <arb1g19@soton.ac.uk>
1 parent b9b2df6 commit 547e925

File tree

4 files changed

+293
-141
lines changed

4 files changed

+293
-141
lines changed

aeon/forecasting/_ets.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""ETSForecaster class.
22
33
An implementation of the exponential smoothing statistics forecasting algorithm.
4-
Implements additive and multiplicative error models.
4+
Implements additive and multiplicative error models. We recommend using the AutoETS
5+
version, but this is useful for demonstrations.
56
"""
67

78
__maintainer__ = []
@@ -85,7 +86,7 @@ class ETSForecaster(BaseForecaster):
8586
... seasonality_type='multiplicative', seasonal_period=4
8687
... )
8788
>>> forecaster.forecast(y)
88-
366.90200486015596
89+
365.5141941111267
8990
"""
9091

9192
_tags = {
@@ -193,7 +194,7 @@ def _get_int(x):
193194
self._gamma,
194195
self.phi,
195196
)
196-
self.forecast_ = _predict(
197+
self.forecast_ = _numba_predict(
197198
self._trend_type,
198199
self._seasonality_type,
199200
self.level_,
@@ -240,6 +241,29 @@ def _initialise(self, data):
240241
self._trend_type, self._seasonality_type, self._seasonal_period, data
241242
)
242243

244+
def iterative_forecast(self, y, prediction_horizon):
245+
"""Forecast with ETS specific iterative method.
246+
247+
Overrides the base class iterative_forecast to avoid refitting on each step.
248+
This simply rolls the ETS model forward
249+
"""
250+
self.fit(y)
251+
preds = np.zeros(prediction_horizon)
252+
preds[0] = self.forecast_
253+
for i in range(1, prediction_horizon):
254+
preds[i] = _numba_predict(
255+
self._trend_type,
256+
self._seasonality_type,
257+
self.level_,
258+
self.trend_,
259+
self.seasonality_,
260+
self.phi,
261+
i + 1,
262+
self.n_timepoints_,
263+
self._seasonal_period,
264+
)
265+
return preds
266+
243267

244268
@njit(fastmath=True, cache=True)
245269
def _numba_fit(
@@ -268,20 +292,18 @@ def _numba_fit(
268292
time_point = data[index]
269293

270294
# Calculate level, trend, and seasonal components
271-
fitted_value, error, level, trend, seasonality[t % seasonal_period] = (
272-
_update_states(
273-
error_type,
274-
trend_type,
275-
seasonality_type,
276-
level,
277-
trend,
278-
seasonality[s_index],
279-
time_point,
280-
alpha,
281-
beta,
282-
gamma,
283-
phi,
284-
)
295+
fitted_value, error, level, trend, seasonality[s_index] = _update_states(
296+
error_type,
297+
trend_type,
298+
seasonality_type,
299+
level,
300+
trend,
301+
seasonality[s_index],
302+
time_point,
303+
alpha,
304+
beta,
305+
gamma,
306+
phi,
285307
)
286308
residuals_[t] = error
287309
fitted_values_[t] = fitted_value
@@ -314,7 +336,7 @@ def _numba_fit(
314336

315337

316338
@njit(fastmath=True, cache=True)
317-
def _predict(
339+
def _numba_predict(
318340
trend_type,
319341
seasonality_type,
320342
level,
@@ -327,11 +349,11 @@ def _predict(
327349
):
328350
# Generate forecasts based on the final values of level, trend, and seasonals
329351
if phi == 1: # No damping case
330-
phi_h = 1
352+
phi_h = horizon
331353
else:
332354
# Geometric series formula for calculating phi + phi^2 + ... + phi^h
333355
phi_h = phi * (1 - phi**horizon) / (1 - phi)
334-
seasonal_index = (n_timepoints + horizon) % seasonal_period
356+
seasonal_index = (n_timepoints + horizon - 1) % seasonal_period
335357
return _predict_value(
336358
trend_type,
337359
seasonality_type,
@@ -392,7 +414,7 @@ def _update_states(
392414
level,
393415
trend,
394416
seasonality,
395-
data_item: int,
417+
data_item,
396418
alpha,
397419
beta,
398420
gamma,

aeon/forecasting/tests/test_ets.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
seasonality_type="additive",
2525
seasonal_period=4,
2626
),
27-
9.191190608800001,
27+
11.456563248800002,
2828
),
2929
(
3030
dict(
@@ -37,7 +37,7 @@
3737
seasonality_type="additive",
3838
seasonal_period=4,
3939
),
40-
16.20176819429869,
40+
15.507105356706465,
4141
),
4242
(
4343
dict(
@@ -50,7 +50,7 @@
5050
seasonality_type="multiplicative",
5151
seasonal_period=4,
5252
),
53-
12.301259229712382,
53+
13.168538863095991,
5454
),
5555
(
5656
dict(
@@ -63,7 +63,7 @@
6363
seasonality_type="multiplicative",
6464
seasonal_period=4,
6565
),
66-
16.811888294476528,
66+
15.223040987015944,
6767
),
6868
],
6969
)

examples/forecasting/ets.ipynb

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
{
2+
"cells": [
3+
{
4+
"metadata": {},
5+
"cell_type": "markdown",
6+
"source": [
7+
"# ETS Forecasting\n",
8+
"\n",
9+
"Examples for ETS notebook here\n",
10+
"\n",
11+
"`aeon` has native a implementation of ETS and will soon add Auto ETS. Details on this\n",
12+
" will be forthcoming when the implementation is stable"
13+
],
14+
"id": "70dcc762ca5e1b6c"
15+
},
16+
{
17+
"metadata": {},
18+
"cell_type": "markdown",
19+
"source": "",
20+
"id": "f1fa7d726479eff9"
21+
}
22+
],
23+
"metadata": {
24+
"kernelspec": {
25+
"display_name": "Python 3",
26+
"language": "python",
27+
"name": "python3"
28+
},
29+
"language_info": {
30+
"codemirror_mode": {
31+
"name": "ipython",
32+
"version": 2
33+
},
34+
"file_extension": ".py",
35+
"mimetype": "text/x-python",
36+
"name": "python",
37+
"nbconvert_exporter": "python",
38+
"pygments_lexer": "ipython2",
39+
"version": "2.7.6"
40+
}
41+
},
42+
"nbformat": 4,
43+
"nbformat_minor": 5
44+
}

examples/forecasting/iterative.ipynb

Lines changed: 202 additions & 116 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)