Skip to content

Commit 3e34210

Browse files
authored
[ENH] Refactor to set forecast_ in fit and to require y in predict (#2927)
* require y in predict * forecast_ * remove calls to predict * doc examples * regression forecaster rework * other tests * docstrings * refactor forecasting checks * test * test * regression notebook * test_ets * Revert "test_ets" This reverts commit 827a12a. * Revert "regression notebook" This reverts commit 2797802. * ets * notebook * naive testing
1 parent e9401fd commit 3e34210

File tree

11 files changed

+334
-194
lines changed

11 files changed

+334
-194
lines changed

aeon/forecasting/_ets.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@ class ETSForecaster(BaseForecaster):
8484
... error_type='additive', trend_type='multiplicative',
8585
... seasonality_type='multiplicative', seasonal_period=4
8686
... )
87-
>>> forecaster.fit(y)
88-
ETSForecaster(...)
89-
>>> forecaster.predict()
87+
>>> forecaster.forecast(y)
9088
366.90200486015596
9189
"""
9290

@@ -195,9 +193,21 @@ def _get_int(x):
195193
self._gamma,
196194
self.phi,
197195
)
196+
self.forecast_ = _predict(
197+
self._trend_type,
198+
self._seasonality_type,
199+
self.level_,
200+
self.trend_,
201+
self.seasonality_,
202+
self.phi,
203+
self.horizon,
204+
self.n_timepoints_,
205+
self._seasonal_period,
206+
)
207+
198208
return self
199209

200-
def _predict(self, y=None, exog=None):
210+
def _predict(self, y, exog=None):
201211
"""
202212
Predict the next horizon steps ahead.
203213
@@ -214,18 +224,7 @@ def _predict(self, y=None, exog=None):
214224
float
215225
single prediction self.horizon steps ahead of y.
216226
"""
217-
fitted_value = _predict(
218-
self._trend_type,
219-
self._seasonality_type,
220-
self.level_,
221-
self.trend_,
222-
self.seasonality_,
223-
self.phi,
224-
self.horizon,
225-
self.n_timepoints_,
226-
self._seasonal_period,
227-
)
228-
return fitted_value
227+
return self.forecast_
229228

230229
def _initialise(self, data):
231230
"""

aeon/forecasting/_naive.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,39 +41,33 @@ def _fit(self, y, exog=None):
4141
y_squeezed = y.squeeze()
4242

4343
if self.strategy == "last":
44-
self._fitted_scalar_value = y_squeezed[-1]
44+
self.forecast_ = y_squeezed[-1]
4545
elif self.strategy == "mean":
46-
self._fitted_scalar_value = np.mean(y_squeezed)
46+
self.forecast_ = np.mean(y_squeezed)
4747
elif self.strategy == "seasonal_last":
48-
self._fitted_last_season = y_squeezed[-self.seasonal_period :]
48+
season = y_squeezed[-self.seasonal_period :]
49+
idx = (self.horizon - 1) % self.seasonal_period
50+
self.forecast_ = season[idx]
4951
else:
5052
raise ValueError(
5153
f"Unknown strategy: {self.strategy}. "
5254
"Valid strategies are 'last', 'mean', 'seasonal_last'."
5355
)
5456
return self
5557

56-
def _predict(self, y=None, exog=None):
57-
if y is None:
58-
if self.strategy == "last" or self.strategy == "mean":
59-
return self._fitted_scalar_value
58+
def _predict(self, y, exog=None):
59+
y_squeezed = y.squeeze()
6060

61-
# For "seasonal_last" strategy
62-
prediction_index = (self.horizon - 1) % self.seasonal_period
63-
return self._fitted_last_season[prediction_index]
61+
if self.strategy == "last":
62+
return y_squeezed[-1]
63+
elif self.strategy == "mean":
64+
return np.mean(y_squeezed)
65+
elif self.strategy == "seasonal_last":
66+
period = y_squeezed[-self.seasonal_period :]
67+
idx = (self.horizon - 1) % self.seasonal_period
68+
return period[idx]
6469
else:
65-
y_squeezed = y.squeeze()
66-
67-
if self.strategy == "last":
68-
return y_squeezed[-1]
69-
elif self.strategy == "mean":
70-
return np.mean(y_squeezed)
71-
elif self.strategy == "seasonal_last":
72-
period = y_squeezed[-self.seasonal_period :]
73-
idx = (self.horizon - 1) % self.seasonal_period
74-
return period[idx]
75-
else:
76-
raise ValueError(
77-
f"Unknown strategy: {self.strategy}. "
78-
"Valid strategies are 'last', 'mean', 'seasonal_last'."
79-
)
70+
raise ValueError(
71+
f"Unknown strategy: {self.strategy}. "
72+
"Valid strategies are 'last', 'mean', 'seasonal_last'."
73+
)

aeon/forecasting/_regression.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,10 @@ def _fit(self, y, exog=None):
7070
self.regressor_ = LinearRegression()
7171
else:
7272
self.regressor_ = self.regressor
73-
73+
self._n_exog = 0
7474
# Combine y and exog for windowing
7575
if exog is not None:
76+
self._n_exog = exog.shape[0]
7677
if exog.ndim == 1:
7778
exog = exog.reshape(1, -1)
7879
if exog.shape[1] != y.shape[1]:
@@ -96,16 +97,18 @@ def _fit(self, y, exog=None):
9697
X = X[:, :, :].reshape(X.shape[0], -1)
9798

9899
# Ignore the final horizon values for X
99-
X = X[: -self.horizon]
100+
X_train = X[: -self.horizon]
100101

101102
# Extract y_train from the original series
102103
y_train = y.squeeze()[self.window + self.horizon - 1 :]
103104

104-
self.last_ = combined_data[:, -self.window :]
105-
self.regressor_.fit(X=X, y=y_train)
105+
self.regressor_.fit(X=X_train, y=y_train)
106+
107+
last = X[[-1]]
108+
self.forecast_ = self.regressor_.predict(last)[0]
106109
return self
107110

108-
def _predict(self, y=None, exog=None):
111+
def _predict(self, y, exog=None):
109112
"""
110113
Predict the next horizon steps ahead.
111114
@@ -122,26 +125,52 @@ def _predict(self, y=None, exog=None):
122125
float
123126
single prediction self.horizon steps ahead of y.
124127
"""
125-
if y is None:
126-
# Flatten the last window to be compatible with sklearn regressors
127-
last_window_flat = self.last_.reshape(1, -1)
128-
return self.regressor_.predict(last_window_flat)[0]
129-
128+
y = y[:, -self.window :]
129+
y = y.squeeze()
130+
# Test data compliant for regression based
131+
if len(y) < self.window:
132+
raise ValueError(
133+
f" Series passed in predict length = {len(y)} but this "
134+
f"RegressionForecaster was trained on window length = "
135+
f"{self.window}"
136+
)
130137
# Combine y and exog for prediction
131138
if exog is not None:
139+
if exog.shape[0] != self._n_exog:
140+
raise ValueError(
141+
f" Forecaster passed {exog.shape[0]} exogenous variables in "
142+
f"predict but this RegressionForecaster was trained on"
143+
f" {self._n_exog} variables in fit"
144+
)
145+
132146
if exog.ndim == 1:
133147
exog = exog.reshape(1, -1)
134-
if exog.shape[1] != y.shape[1]:
135-
raise ValueError("y and exog must have the same number of time points.")
148+
if exog.shape[1] < self.window:
149+
raise ValueError(
150+
f" Exogenous variables passed in predict of length = {len(y)} but "
151+
f"this RegressionForecaster was trained on window length = "
152+
f"{self.window}"
153+
)
154+
155+
exog = exog[:, -self.window :]
136156
combined_data = np.vstack([y, exog])
137157
else:
158+
if self._n_exog > 0:
159+
raise ValueError(
160+
f" predict passed no exogenous variables, but this "
161+
f"RegressionForecaster was trained on {self._n_exog} exog in fit"
162+
)
138163
combined_data = y
139164

140165
# Extract the last window and flatten for prediction
141-
last_window = combined_data[:, -self.window :]
142-
last_window_flat = last_window.reshape(1, -1)
166+
last_window = combined_data.reshape(1, -1)
167+
168+
return self.regressor_.predict(last_window)[0]
143169

144-
return self.regressor_.predict(last_window_flat)[0]
170+
def _forecast(self, y, exog=None):
171+
"""Forecast values for time series X."""
172+
self.fit(y, exog)
173+
return self.forecast_
145174

146175
@classmethod
147176
def _get_test_params(cls, parameter_set: str = "default"):

aeon/forecasting/base.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,13 @@ def fit(self, y, exog=None):
9696
def _fit(self, y, exog=None): ...
9797

9898
@final
99-
def predict(self, y=None, exog=None):
99+
def predict(self, y, exog=None):
100100
"""Predict the next horizon steps ahead.
101101
102102
Parameters
103103
----------
104-
y : np.ndarray, default = None
105-
A time series to predict the next horizon value for. If None,
106-
predict the next horizon value after series seen in fit.
104+
y : np.ndarray
105+
A time series to predict the next horizon value for.
107106
exog : np.ndarray, default =None
108107
Optional exogenous time series data assumed to be aligned with y.
109108
@@ -113,28 +112,25 @@ def predict(self, y=None, exog=None):
113112
single prediction self.horizon steps ahead of y.
114113
"""
115114
self._check_is_fitted()
116-
if y is not None:
117-
self._check_X(y, self.axis)
118-
y = self._convert_y(y, self.axis)
115+
self._check_X(y, self.axis)
116+
y = self._convert_y(y, self.axis)
119117
if exog is not None:
120118
exog = self._convert_y(exog, self.axis)
121-
122119
return self._predict(y, exog)
123120

124121
@abstractmethod
125-
def _predict(self, y=None, exog=None): ...
122+
def _predict(self, y, exog=None): ...
126123

127124
@final
128125
def forecast(self, y, exog=None):
129-
"""Forecast the next horizon steps ahead.
126+
"""Forecast the next horizon steps ahead of ``y``.
130127
131-
By default this is simply fit followed by predict.
128+
By default this is simply fit followed by returning forecast_.
132129
133130
Parameters
134131
----------
135-
y : np.ndarray, default = None
136-
A time series to predict the next horizon value for. If None,
137-
predict the next horizon value after series seen in fit.
132+
y : np.ndarray
133+
A time series to predict the next horizon value for.
138134
exog : np.ndarray, default =None
139135
Optional exogenous time series data assumed to be aligned with y.
140136
@@ -150,9 +146,9 @@ def forecast(self, y, exog=None):
150146
return self._forecast(y, exog)
151147

152148
def _forecast(self, y, exog=None):
153-
"""Forecast values for time series X."""
149+
"""Forecast horizon steps ahead for time series ``y``."""
154150
self.fit(y, exog)
155-
return self._predict(y, exog)
151+
return self.forecast_
156152

157153
@final
158154
def direct_forecast(self, y, prediction_horizon, exog=None):
@@ -174,7 +170,10 @@ def direct_forecast(self, y, prediction_horizon, exog=None):
174170
The number of future time steps to forecast.
175171
exog : np.ndarray, default =None
176172
Optional exogenous time series data assumed to be aligned with y.
177-
predictions : np.ndarray
173+
174+
Returns
175+
-------
176+
np.ndarray
178177
An array of shape `(prediction_horizon,)` containing the forecasts for
179178
each horizon.
180179
@@ -210,15 +209,15 @@ def direct_forecast(self, y, prediction_horizon, exog=None):
210209

211210
def iterative_forecast(self, y, prediction_horizon):
212211
"""
213-
Forecast ``prediction_horizon`` prediction using a single model from `y`.
212+
Forecast ``prediction_horizon`` prediction using a single model fit on `y`.
214213
215214
This function implements the iterative forecasting strategy (also called
216-
recursive or iterated). This involves a single model fit on y which is then
217-
used to make ``prediction_horizon`` ahead using its own predictions as
218-
inputs for future forecasts. This is done by taking
219-
the prediction at step ``i`` and feeding it back into the model to help
220-
predict for step ``i+1``. The basic contract of
221-
`iterative_forecast` is that `fit` is only ever called once.
215+
recursive or iterated). This involves a single model fit on ``y`` which is then
216+
used to make ``prediction_horizon`` ahead forecasts using its own predictions as
217+
inputs for future forecasts. This is done by taking the prediction at step
218+
``i`` and feeding it back into the model to help predict for step ``i+1``.
219+
The basic contract of `iterative_forecast` is that `fit` is only ever called
220+
once.
222221
223222
y : np.ndarray
224223
The time series to make forecasts about.
@@ -227,7 +226,7 @@ def iterative_forecast(self, y, prediction_horizon):
227226
228227
Returns
229228
-------
230-
predictions : np.ndarray
229+
np.ndarray
231230
An array of shape `(prediction_horizon,)` containing the forecasts for
232231
each horizon.
233232

aeon/forecasting/tests/test_base.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import pandas as pd
55
import pytest
66

7-
from aeon.forecasting import NaiveForecaster, RegressionForecaster
7+
from aeon.forecasting import BaseForecaster, NaiveForecaster, RegressionForecaster
88

99

1010
def test_base_forecaster():
1111
"""Test base forecaster functionality."""
1212
f = NaiveForecaster()
1313
y = np.random.rand(50)
1414
f.fit(y)
15-
p1 = f.predict()
15+
p1 = f.predict(y)
1616
assert p1 == y[-1]
1717
p2 = f.forecast(y)
1818
p3 = f._forecast(y)
@@ -54,8 +54,8 @@ def test_direct_forecast():
5454
assert p == preds[i]
5555

5656

57-
def test_recursive_forecast():
58-
"""Test recursive forecasting."""
57+
def test_iterative_forecast():
58+
"""Test iterative forecasting."""
5959
y = np.random.rand(50)
6060
f = RegressionForecaster(window=4)
6161
preds = f.iterative_forecast(y, prediction_horizon=10)
@@ -79,3 +79,21 @@ def test_direct_forecast_with_exog():
7979
# Check that predictions are different from when no exog is used
8080
preds_no_exog = f.direct_forecast(y, prediction_horizon=10)
8181
assert not np.array_equal(preds, preds_no_exog)
82+
83+
84+
def test_fit_is_empty():
85+
"""Test empty fit."""
86+
87+
class _EmptyFit(BaseForecaster):
88+
_tags = {"fit_is_empty": True}
89+
90+
def _fit(self, y):
91+
return self
92+
93+
def _predict(self, y):
94+
return 0
95+
96+
dummy = _EmptyFit(horizon=1, axis=1)
97+
y = np.arange(50)
98+
dummy.fit(y)
99+
assert dummy.is_fitted

0 commit comments

Comments
 (0)