Skip to content

Commit e2fc23b

Browse files
authored
[BUG] Forecasting regressor fix and notebook (#2885)
* stop returning an array * notebook * Create window.png * fix test and typo * remove unfinished sentence
1 parent a987498 commit e2fc23b

File tree

7 files changed

+441
-195
lines changed

7 files changed

+441
-195
lines changed

aeon/classification/compose/_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ClassifierPipeline(BaseCollectionPipeline, BaseClassifier):
4040
A transform or list of transformers to use prior to classification.
4141
List of tuples (str, transformer) of transformers can also be passed, where
4242
the str is used to name the transformer.
43-
The objecst are cloned prior, as such the state of the input will not be
43+
The objects are cloned prior, as such the state of the input will not be
4444
modified by fitting the pipeline.
4545
estimator : aeon or sklearn classifier
4646
A classifier to use at the end of the pipeline.

aeon/forecasting/_regression.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def _fit(self, y, exog=None):
6666
else:
6767
self.regressor_ = self.regressor
6868
y = y.squeeze()
69+
if self.window < 1 or self.window > len(y) - 3:
70+
raise ValueError(
71+
f" window value {self.window} is invalid for series " f"length {len(y)}"
72+
)
6973
X = np.lib.stride_tricks.sliding_window_view(y, window_shape=self.window)
7074
# Ignore the final horizon values: need to store these for pred with empty y
7175
X = X[: -self.horizon]
@@ -91,35 +95,13 @@ def _predict(self, y=None, exog=None):
9195
9296
Returns
9397
-------
94-
np.ndarray
98+
float
9599
single prediction self.horizon steps ahead of y.
96100
"""
97101
if y is None:
98-
return self.regressor_.predict(self.last_)
102+
return self.regressor_.predict(self.last_)[0]
99103
last = y[:, -self.window :]
100-
return self.regressor_.predict(last)
101-
102-
def _forecast(self, y, exog=None):
103-
"""
104-
Forecast the next horizon steps ahead.
105-
106-
Parameters
107-
----------
108-
y : np.ndarray
109-
A time series to predict the next horizon value for.
110-
exog : np.ndarray, default=None
111-
Optional exogenous time series data. Included for interface
112-
compatibility but ignored in this estimator.
113-
114-
Returns
115-
-------
116-
np.ndarray
117-
single prediction self.horizon steps ahead of y.
118-
119-
NOTE: deal with horizons
120-
"""
121-
self.fit(y, exog)
122-
return self.predict()
104+
return self.regressor_.predict(last)[0]
123105

124106
@classmethod
125107
def _get_test_params(cls, parameter_set: str = "default"):

aeon/forecasting/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def predict(self, y=None, exog=None):
111111
y = self._convert_y(y, self.axis)
112112
if exog is not None:
113113
raise NotImplementedError("Exogenous variables not yet supported")
114-
return self._predict(y, exog)
114+
x = self._predict(y, exog)
115+
return x
115116

116117
@abstractmethod
117118
def _predict(self, y=None, exog=None): ...

aeon/forecasting/tests/test_regressor.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
"""Test the regression forecaster."""
22

3+
import numpy as np
4+
import pytest
35
from sklearn.linear_model import LinearRegression
46

5-
from aeon.datasets import load_airline
67
from aeon.forecasting import RegressionForecaster
8+
from aeon.regression import DummyRegressor
79

810

911
def test_regression_forecaster():
1012
"""Test the regression forecaster."""
11-
y = load_airline()
13+
y = np.random.rand(100)
1214
f = RegressionForecaster(window=10)
1315
f.fit(y)
1416
p = f.predict()
@@ -20,3 +22,13 @@ def test_regression_forecaster():
2022
f2.fit(y)
2123
p2 = f2.predict()
2224
assert p == p2
25+
f2 = RegressionForecaster(regressor=DummyRegressor(), window=10)
26+
f2.fit(y)
27+
f2.predict()
28+
29+
with pytest.raises(ValueError):
30+
f = RegressionForecaster(window=-1)
31+
f.fit(y)
32+
with pytest.raises(ValueError):
33+
f = RegressionForecaster(window=101)
34+
f.fit(y)

0 commit comments

Comments
 (0)