Skip to content

[ENH] Rework ETS Forecaster #2939

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 43 additions & 21 deletions aeon/forecasting/_ets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""ETSForecaster class.

An implementation of the exponential smoothing statistics forecasting algorithm.
Implements additive and multiplicative error models.
Implements additive and multiplicative error models. We recommend using the AutoETS
version, but this is useful for demonstrations.
"""

__maintainer__ = []
Expand Down Expand Up @@ -85,7 +86,7 @@ class ETSForecaster(BaseForecaster):
... seasonality_type='multiplicative', seasonal_period=4
... )
>>> forecaster.forecast(y)
366.90200486015596
365.5141941111267
"""

_tags = {
Expand Down Expand Up @@ -193,7 +194,7 @@ def _get_int(x):
self._gamma,
self.phi,
)
self.forecast_ = _predict(
self.forecast_ = _numba_predict(
self._trend_type,
self._seasonality_type,
self.level_,
Expand Down Expand Up @@ -240,6 +241,29 @@ def _initialise(self, data):
self._trend_type, self._seasonality_type, self._seasonal_period, data
)

def iterative_forecast(self, y, prediction_horizon):
"""Forecast with ETS specific iterative method.

Overrides the base class iterative_forecast to avoid refitting on each step.
This simply rolls the ETS model forward
"""
self.fit(y)
preds = np.zeros(prediction_horizon)
preds[0] = self.forecast_
for i in range(1, prediction_horizon):
preds[i] = _numba_predict(
self._trend_type,
self._seasonality_type,
self.level_,
self.trend_,
self.seasonality_,
self.phi,
i + 1,
self.n_timepoints_,
self._seasonal_period,
)
return preds


@njit(fastmath=True, cache=True)
def _numba_fit(
Expand Down Expand Up @@ -268,20 +292,18 @@ def _numba_fit(
time_point = data[index]

# Calculate level, trend, and seasonal components
fitted_value, error, level, trend, seasonality[t % seasonal_period] = (
_update_states(
error_type,
trend_type,
seasonality_type,
level,
trend,
seasonality[s_index],
time_point,
alpha,
beta,
gamma,
phi,
)
fitted_value, error, level, trend, seasonality[s_index] = _update_states(
error_type,
trend_type,
seasonality_type,
level,
trend,
seasonality[s_index],
time_point,
alpha,
beta,
gamma,
phi,
)
residuals_[t] = error
fitted_values_[t] = fitted_value
Expand Down Expand Up @@ -314,7 +336,7 @@ def _numba_fit(


@njit(fastmath=True, cache=True)
def _predict(
def _numba_predict(
trend_type,
seasonality_type,
level,
Expand All @@ -327,11 +349,11 @@ def _predict(
):
# Generate forecasts based on the final values of level, trend, and seasonals
if phi == 1: # No damping case
phi_h = 1
phi_h = horizon
else:
# Geometric series formula for calculating phi + phi^2 + ... + phi^h
phi_h = phi * (1 - phi**horizon) / (1 - phi)
seasonal_index = (n_timepoints + horizon) % seasonal_period
seasonal_index = (n_timepoints + horizon - 1) % seasonal_period
return _predict_value(
trend_type,
seasonality_type,
Expand Down Expand Up @@ -392,7 +414,7 @@ def _update_states(
level,
trend,
seasonality,
data_item: int,
data_item,
alpha,
beta,
gamma,
Expand Down
8 changes: 4 additions & 4 deletions aeon/forecasting/tests/test_ets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
seasonality_type="additive",
seasonal_period=4,
),
9.191190608800001,
11.456563248800002,
),
(
dict(
Expand All @@ -37,7 +37,7 @@
seasonality_type="additive",
seasonal_period=4,
),
16.20176819429869,
15.507105356706465,
),
(
dict(
Expand All @@ -50,7 +50,7 @@
seasonality_type="multiplicative",
seasonal_period=4,
),
12.301259229712382,
13.168538863095991,
),
(
dict(
Expand All @@ -63,7 +63,7 @@
seasonality_type="multiplicative",
seasonal_period=4,
),
16.811888294476528,
15.223040987015944,
),
],
)
Expand Down
44 changes: 44 additions & 0 deletions examples/forecasting/ets.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"# ETS Forecasting\n",
"\n",
"Examples for ETS notebook here\n",
"\n",
"`aeon` has native a implementation of ETS and will soon add Auto ETS. Details on this\n",
" will be forthcoming when the implementation is stable"
],
"id": "70dcc762ca5e1b6c"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "",
"id": "f1fa7d726479eff9"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
318 changes: 202 additions & 116 deletions examples/forecasting/iterative.ipynb

Large diffs are not rendered by default.