Skip to content

Commit 3dc043b

Browse files
TinaJin0228TonyBagnallMatthewMiddlehurst
authored
[ENH] enhance naive forecaster with new strategies (#2869)
* naive forecaster with new strategies * few modifications according to reviews * made modifications and add documentation * modifications according to reviews * try to pass the test * delete one comment * modify 1)more test cases for seasonal_last 2)raise exception for y input in _predict, and related documentation * fix typo * take y * Update _naive.py --------- Co-authored-by: Tony Bagnall <a.j.bagnall@soton.ac.uk> Co-authored-by: MatthewMiddlehurst <pfm15hbu@gmail.com>
1 parent a163862 commit 3dc043b

File tree

4 files changed

+214
-66
lines changed

4 files changed

+214
-66
lines changed

aeon/forecasting/_naive.py

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,79 @@
1-
"""Naive Forecaster."""
1+
"""Naive forecaster with multiple strategies."""
2+
3+
__maintainer__ = []
4+
__all__ = ["NaiveForecaster"]
5+
6+
7+
import numpy as np
28

39
from aeon.forecasting.base import BaseForecaster
410

511

612
class NaiveForecaster(BaseForecaster):
7-
"""Naive forecaster that always predicts the last value seen in training."""
13+
"""
14+
Naive forecaster with multiple strategies and flexible horizon.
15+
16+
Parameters
17+
----------
18+
strategy : str, default="last"
19+
The forecasting strategy to use.
20+
Options: "last", "mean", "seasonal_last".
21+
- "last" predicts the last value of the input series for all horizon steps.
22+
- "mean": predicts the mean of the input series for all horizon steps.
23+
- "seasonal_last": predicts the last season value in the training series.
24+
Returns np.nan if the effective seasonal data is empty.
25+
seasonal_period : int, default=1
26+
The seasonal period to use for the "seasonal_last" strategy.
27+
E.g., 12 for monthly data with annual seasonality.
28+
horizon : int, default =1
29+
The number of time steps ahead to forecast. If horizon is one, the forecaster
30+
will learn to predict one point ahead.
31+
Only relevant for "seasonal_last".
32+
"""
833

9-
def __init__(self):
10-
"""Initialize NaiveForecaster."""
11-
self.last_value_ = None
12-
super().__init__(horizon=1, axis=1)
34+
def __init__(self, strategy="last", seasonal_period=1, horizon=1):
35+
self.strategy = strategy
36+
self.seasonal_period = seasonal_period
37+
38+
super().__init__(horizon=horizon, axis=1)
1339

1440
def _fit(self, y, exog=None):
15-
"""Fit Naive forecaster."""
16-
y = y.squeeze()
17-
self.last_value_ = y[-1]
41+
y_squeezed = y.squeeze()
42+
43+
if self.strategy == "last":
44+
self._fitted_scalar_value = y_squeezed[-1]
45+
elif self.strategy == "mean":
46+
self._fitted_scalar_value = np.mean(y_squeezed)
47+
elif self.strategy == "seasonal_last":
48+
self._fitted_last_season = y_squeezed[-self.seasonal_period :]
49+
else:
50+
raise ValueError(
51+
f"Unknown strategy: {self.strategy}. "
52+
"Valid strategies are 'last', 'mean', 'seasonal_last'."
53+
)
1854
return self
1955

2056
def _predict(self, y=None, exog=None):
21-
"""Predict using Naive forecaster."""
22-
return self.last_value_
57+
if y is None:
58+
if self.strategy == "last" or self.strategy == "mean":
59+
return self._fitted_scalar_value
60+
61+
# For "seasonal_last" strategy
62+
prediction_index = (self.horizon - 1) % self.seasonal_period
63+
return self._fitted_last_season[prediction_index]
64+
else:
65+
y_squeezed = y.squeeze()
2366

24-
def _forecast(self, y, exog=None):
25-
"""Forecast using dummy forecaster."""
26-
y = y.squeeze()
27-
return y[-1]
67+
if self.strategy == "last":
68+
return y_squeezed[-1]
69+
elif self.strategy == "mean":
70+
return np.mean(y_squeezed)
71+
elif self.strategy == "seasonal_last":
72+
period = y_squeezed[-self.seasonal_period :]
73+
idx = (self.horizon - 1) % self.seasonal_period
74+
return period[idx]
75+
else:
76+
raise ValueError(
77+
f"Unknown strategy: {self.strategy}. "
78+
"Valid strategies are 'last', 'mean', 'seasonal_last'."
79+
)

aeon/forecasting/tests/test_naive.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Test Naive Forecaster."""
2+
3+
import numpy as np
4+
5+
from aeon.forecasting import NaiveForecaster
6+
7+
8+
def test_naive_forecaster_last_strategy():
9+
"""Test NaiveForecaster with 'last' strategy."""
10+
sample_data = np.array([10, 20, 30, 40, 50])
11+
forecaster = NaiveForecaster(strategy="last", horizon=3)
12+
forecaster.fit(sample_data)
13+
predictions = forecaster.predict()
14+
expected = 50
15+
np.testing.assert_array_equal(predictions, expected)
16+
17+
18+
def test_naive_forecaster_mean_strategy():
19+
"""Test NaiveForecaster with 'mean' strategy."""
20+
sample_data = np.array([10, 20, 30, 40, 50])
21+
forecaster = NaiveForecaster(strategy="mean", horizon=2)
22+
forecaster.fit(sample_data)
23+
predictions = forecaster.predict()
24+
expected = 30 # Mean of [10, 20, 30, 40, 50] is 30
25+
np.testing.assert_array_equal(predictions, expected)
26+
27+
28+
def test_naive_forecaster_seasonal_last_strategy():
29+
"""Test NaiveForecaster with 'seasonal_last' strategy."""
30+
data = np.array([1, 2, 3, 4, 5, 6, 7, 8])
31+
32+
# Last season is [6, 7, 8] for seasonal_period = 3
33+
forecaster = NaiveForecaster(strategy="seasonal_last", seasonal_period=3, horizon=4)
34+
forecaster.fit(data)
35+
pred = forecaster.predict()
36+
pred2 = forecaster.predict(y=data)
37+
expected = 6 # predicts the 1-st element of the last season.
38+
np.testing.assert_array_equal(pred, expected)
39+
np.testing.assert_array_equal(pred2, expected)
40+
41+
# Test horizon within the season length
42+
forecaster = NaiveForecaster(strategy="seasonal_last", seasonal_period=3, horizon=2)
43+
forecaster.fit(data)
44+
pred = forecaster.predict()
45+
pred2 = forecaster.predict(y=data)
46+
expected = 7 # predicts the 2-nd element of the last season.
47+
np.testing.assert_array_equal(pred, expected)
48+
np.testing.assert_array_equal(pred2, expected)
49+
50+
# Test horizon wrapping around to a new season
51+
forecaster = NaiveForecaster(strategy="seasonal_last", seasonal_period=3, horizon=7)
52+
forecaster.fit(data)
53+
pred = forecaster.predict()
54+
pred2 = forecaster.predict(y=data)
55+
expected = 6 # predicts the 1-st element of the last season.
56+
np.testing.assert_array_equal(pred, expected)
57+
np.testing.assert_array_equal(pred2, expected)
58+
59+
# Last season is now [5, 6, 7, 8] with seasonal_period = 4
60+
forecaster = NaiveForecaster(strategy="seasonal_last", seasonal_period=4, horizon=6)
61+
forecaster.fit(data)
62+
pred = forecaster.predict()
63+
pred2 = forecaster.predict(y=data)
64+
expected = 6 # predicts the 2nd element of the new last season.
65+
np.testing.assert_array_equal(pred, expected)
66+
np.testing.assert_array_equal(pred2, expected)

docs/api_reference/forecasting.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
:toctree: auto_generated/
88
:template: class.rst
99
10-
DummyForecaster
1110
BaseForecaster
11+
NaiveForecaster
1212
RegressionForecaster
1313
ETSForecaster
1414
```

0 commit comments

Comments
 (0)