@@ -193,7 +193,7 @@ def _get_int(x):
193
193
self ._gamma ,
194
194
self .phi ,
195
195
)
196
- self .forecast_ = _predict (
196
+ self .forecast_ = _numba_predict (
197
197
self ._trend_type ,
198
198
self ._seasonality_type ,
199
199
self .level_ ,
@@ -240,6 +240,29 @@ def _initialise(self, data):
240
240
self ._trend_type , self ._seasonality_type , self ._seasonal_period , data
241
241
)
242
242
243
+ def iterative_forecast (self , y , prediction_horizon ):
244
+ """Forecast with ETS specific iterative method.
245
+
246
+ Overrides the base class iterative_forecast to avoid refitting on each step.
247
+ This simply rolls the ETS model forward
248
+ """
249
+ self .fit (y )
250
+ preds = np .zeros (prediction_horizon )
251
+ preds [0 ] = self .forecast_
252
+ for i in range (1 , prediction_horizon ):
253
+ preds [i ] = _numba_predict (
254
+ self ._trend_type ,
255
+ self ._seasonality_type ,
256
+ self .level_ ,
257
+ self .trend_ ,
258
+ self .seasonality_ ,
259
+ self .phi ,
260
+ i + 1 ,
261
+ self .n_timepoints_ ,
262
+ self ._seasonal_period ,
263
+ )
264
+ return preds
265
+
243
266
244
267
@njit (fastmath = True , cache = True )
245
268
def _numba_fit (
@@ -268,20 +291,18 @@ def _numba_fit(
268
291
time_point = data [index ]
269
292
270
293
# Calculate level, trend, and seasonal components
271
- fitted_value , error , level , trend , seasonality [t % seasonal_period ] = (
272
- _update_states (
273
- error_type ,
274
- trend_type ,
275
- seasonality_type ,
276
- level ,
277
- trend ,
278
- seasonality [s_index ],
279
- time_point ,
280
- alpha ,
281
- beta ,
282
- gamma ,
283
- phi ,
284
- )
294
+ fitted_value , error , level , trend , seasonality [s_index ] = _update_states (
295
+ error_type ,
296
+ trend_type ,
297
+ seasonality_type ,
298
+ level ,
299
+ trend ,
300
+ seasonality [s_index ],
301
+ time_point ,
302
+ alpha ,
303
+ beta ,
304
+ gamma ,
305
+ phi ,
285
306
)
286
307
residuals_ [t ] = error
287
308
fitted_values_ [t ] = fitted_value
@@ -314,7 +335,7 @@ def _numba_fit(
314
335
315
336
316
337
@njit (fastmath = True , cache = True )
317
- def _predict (
338
+ def _numba_predict (
318
339
trend_type ,
319
340
seasonality_type ,
320
341
level ,
@@ -392,7 +413,7 @@ def _update_states(
392
413
level ,
393
414
trend ,
394
415
seasonality ,
395
- data_item : int ,
416
+ data_item ,
396
417
alpha ,
397
418
beta ,
398
419
gamma ,
0 commit comments