Skip to content

Commit 3022d7b

Browse files
committed
iterative
1 parent 551641a commit 3022d7b

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

aeon/forecasting/_ets.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _get_int(x):
193193
self._gamma,
194194
self.phi,
195195
)
196-
self.forecast_ = _predict(
196+
self.forecast_ = _numba_predict(
197197
self._trend_type,
198198
self._seasonality_type,
199199
self.level_,
@@ -240,6 +240,29 @@ def _initialise(self, data):
240240
self._trend_type, self._seasonality_type, self._seasonal_period, data
241241
)
242242

243+
def iterative_forecast(self, y, prediction_horizon):
244+
"""Forecast with ETS specific iterative method.
245+
246+
Overrides the base class iterative_forecast to avoid refitting on each step.
247+
This simply rolls the ETS model forward
248+
"""
249+
self.fit(y)
250+
preds = np.zeros(prediction_horizon)
251+
preds[0] = self.forecast_
252+
for i in range(1, prediction_horizon):
253+
preds[i] = _numba_predict(
254+
self._trend_type,
255+
self._seasonality_type,
256+
self.level_,
257+
self.trend_,
258+
self.seasonality_,
259+
self.phi,
260+
i + 1,
261+
self.n_timepoints_,
262+
self._seasonal_period,
263+
)
264+
return preds
265+
243266

244267
@njit(fastmath=True, cache=True)
245268
def _numba_fit(
@@ -268,20 +291,18 @@ def _numba_fit(
268291
time_point = data[index]
269292

270293
# Calculate level, trend, and seasonal components
271-
fitted_value, error, level, trend, seasonality[t % seasonal_period] = (
272-
_update_states(
273-
error_type,
274-
trend_type,
275-
seasonality_type,
276-
level,
277-
trend,
278-
seasonality[s_index],
279-
time_point,
280-
alpha,
281-
beta,
282-
gamma,
283-
phi,
284-
)
294+
fitted_value, error, level, trend, seasonality[s_index] = _update_states(
295+
error_type,
296+
trend_type,
297+
seasonality_type,
298+
level,
299+
trend,
300+
seasonality[s_index],
301+
time_point,
302+
alpha,
303+
beta,
304+
gamma,
305+
phi,
285306
)
286307
residuals_[t] = error
287308
fitted_values_[t] = fitted_value
@@ -314,7 +335,7 @@ def _numba_fit(
314335

315336

316337
@njit(fastmath=True, cache=True)
317-
def _predict(
338+
def _numba_predict(
318339
trend_type,
319340
seasonality_type,
320341
level,
@@ -392,7 +413,7 @@ def _update_states(
392413
level,
393414
trend,
394415
seasonality,
395-
data_item: int,
416+
data_item,
396417
alpha,
397418
beta,
398419
gamma,

0 commit comments

Comments
 (0)