Skip to content

Commit 551641a

Browse files
authored
[ENH] Clone estimator in direct forecast (#2936)
* clone estimator * clone estimator * tests * test update
1 parent 3e34210 commit 551641a

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

aeon/forecasting/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pandas as pd
1515

1616
from aeon.base import BaseSeriesEstimator
17+
from aeon.base._base import _clone_estimator
1718
from aeon.utils.data_types import VALID_SERIES_INNER_TYPES
1819

1920

@@ -203,8 +204,9 @@ def direct_forecast(self, y, prediction_horizon, exog=None):
203204

204205
preds = np.zeros(prediction_horizon)
205206
for i in range(0, prediction_horizon):
206-
self.horizon = i + 1
207-
preds[i] = self.forecast(y, exog)
207+
f = _clone_estimator(self)
208+
f.horizon = i + 1
209+
preds[i] = f.forecast(y, exog)
208210
return preds
209211

210212
def iterative_forecast(self, y, prediction_horizon):

aeon/forecasting/tests/test_base.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import pandas as pd
55
import pytest
66

7-
from aeon.forecasting import BaseForecaster, NaiveForecaster, RegressionForecaster
7+
from aeon.forecasting import NaiveForecaster, RegressionForecaster
8+
from aeon.forecasting.base import BaseForecaster
89

910

1011
def test_base_forecaster():
@@ -55,7 +56,7 @@ def test_direct_forecast():
5556

5657

5758
def test_iterative_forecast():
58-
"""Test iterative forecasting."""
59+
"""Test terativeforecasting."""
5960
y = np.random.rand(50)
6061
f = RegressionForecaster(window=4)
6162
preds = f.iterative_forecast(y, prediction_horizon=10)
@@ -67,6 +68,17 @@ def test_iterative_forecast():
6768
y = np.append(y, p)
6869

6970

71+
def test_output_equivalence():
72+
"""Test output same for one ahead forecast."""
73+
y = np.random.rand(50)
74+
f = RegressionForecaster(window=4)
75+
p1 = f.forecast(y)
76+
p2 = f.fit(y).predict(y)
77+
p3 = f.iterative_forecast(y, 1)
78+
p4 = f.direct_forecast(y, 1)
79+
assert np.allclose(p1, p2, p3[0], p4[0])
80+
81+
7082
def test_direct_forecast_with_exog():
7183
"""Test direct forecasting with exogenous variables."""
7284
y = np.arange(50)

0 commit comments

Comments
 (0)