From f2002fcadd1fba83b7b2403ee8ddead4ed4cdb39 Mon Sep 17 00:00:00 2001 From: MatthewMiddlehurst Date: Mon, 9 Jun 2025 21:06:56 +0100 Subject: [PATCH 1/7] butchering --- aeon/forecasting/_ets.py | 76 +++++++----------------------- aeon/forecasting/base.py | 6 +-- aeon/forecasting/tests/test_ets.py | 15 ++---- 3 files changed, 25 insertions(+), 72 deletions(-) diff --git a/aeon/forecasting/_ets.py b/aeon/forecasting/_ets.py index cb0fef4e0d..8a5947122e 100644 --- a/aeon/forecasting/_ets.py +++ b/aeon/forecasting/_ets.py @@ -94,6 +94,7 @@ class ETSForecaster(BaseForecaster): _tags = { "capability:horizon": False, + "fit_is_empty": True, } def __init__( @@ -108,46 +109,33 @@ def __init__( phi: float = 0.99, horizon: int = 1, ): - self.alpha = alpha - self.beta = beta - self.gamma = gamma - self.phi = phi - self.forecast_val_ = 0.0 - self.level_ = 0.0 - self.trend_ = 0.0 - self.seasonality_ = None - self._beta = beta - self._gamma = gamma self.error_type = error_type self.trend_type = trend_type self.seasonality_type = seasonality_type self.seasonal_period = seasonal_period - self._seasonal_period = seasonal_period - self.n_timepoints_ = 0 - self.avg_mean_sq_err_ = 0 - self.liklihood_ = 0 - self.k_ = 0 - self.aic_ = 0 - self.residuals_ = [] - self.fitted_values_ = [] - super().__init__(horizon=horizon, axis=1) + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.phi = phi - def _fit(self, y, exog=None): - """Fit Exponential Smoothing forecaster to series y. + super().__init__(horizon=horizon, axis=1) - Fit a forecaster to predict self.horizon steps ahead using y. + def _predict(self, y=None, exog=None): + """ + Predict the next horizon steps ahead. Parameters ---------- - y : np.ndarray - A time series on which to learn a forecaster to predict horizon ahead + y : np.ndarray, default = None + A time series to predict the next horizon value for. If None, + predict the next horizon value after series seen in fit. exog : np.ndarray, default =None Optional exogenous time series data assumed to be aligned with y Returns ------- - self - Fitted ETSForecaster. + float + single prediction self.horizon steps ahead of y. """ _validate_parameter(self.error_type, False) _validate_parameter(self.seasonality_type, True) @@ -163,6 +151,10 @@ def _get_int(x): return 2 return x + self._seasonal_period = self.seasonal_period + self._beta = self.beta + self._gamma = self.gamma + self._error_type = _get_int(self.error_type) self._seasonality_type = _get_int(self.seasonality_type) self._trend_type = _get_int(self.trend_type) @@ -198,25 +190,7 @@ def _get_int(x): self._gamma, self.phi, ) - return self - def _predict(self, y=None, exog=None): - """ - Predict the next horizon steps ahead. - - Parameters - ---------- - y : np.ndarray, default = None - A time series to predict the next horizon value for. If None, - predict the next horizon value after series seen in fit. - exog : np.ndarray, default =None - Optional exogenous time series data assumed to be aligned with y - - Returns - ------- - float - single prediction self.horizon steps ahead of y. - """ fitted_value = _predict( self._trend_type, self._seasonality_type, @@ -230,20 +204,6 @@ def _predict(self, y=None, exog=None): ) return fitted_value - def _initialise(self, data): - """ - Initialize level, trend, and seasonality values for the ETS model. - - Parameters - ---------- - data : array-like - The time series data - (should contain at least two full seasons if seasonality is specified) - """ - self.level_, self.trend_, self.seasonality_ = _initialise( - self._trend_type, self._seasonality_type, self._seasonal_period, data - ) - @njit(fastmath=True, cache=True) def _numba_fit( diff --git a/aeon/forecasting/base.py b/aeon/forecasting/base.py index 3a084b6d99..bcd66f9a4d 100644 --- a/aeon/forecasting/base.py +++ b/aeon/forecasting/base.py @@ -86,8 +86,8 @@ def fit(self, y, exog=None): self.is_fitted = True return self._fit(y, exog) - @abstractmethod - def _fit(self, y, exog=None): ... + def _fit(self, y, exog=None): + return self def predict(self, y=None, exog=None): """Predict the next horizon steps ahead. @@ -105,7 +105,7 @@ def predict(self, y=None, exog=None): float single prediction self.horizon steps ahead of y. """ - self._check_is_fitted() + # self._check_is_fitted() if y is not None: self._check_X(y, self.axis) y = self._convert_y(y, self.axis) diff --git a/aeon/forecasting/tests/test_ets.py b/aeon/forecasting/tests/test_ets.py index f4f5f86590..598a4565e2 100644 --- a/aeon/forecasting/tests/test_ets.py +++ b/aeon/forecasting/tests/test_ets.py @@ -1,8 +1,5 @@ """Test ETS.""" -__maintainer__ = [] -__all__ = [] - import numpy as np import pytest @@ -26,8 +23,7 @@ def test_ets_forecaster_additive(): seasonality_type="additive", seasonal_period=4, ) - forecaster.fit(data) - p = forecaster.predict() + p = forecaster.predict(data) assert np.isclose(p, 9.191190608800001) @@ -47,8 +43,7 @@ def test_ets_forecaster_mult_error(): seasonality_type="additive", seasonal_period=4, ) - forecaster.fit(data) - p = forecaster.predict() + p = forecaster.predict(data) assert np.isclose(p, 16.20176819429869) @@ -68,8 +63,7 @@ def test_ets_forecaster_mult_compnents(): seasonality_type="multiplicative", seasonal_period=4, ) - forecaster.fit(data) - p = forecaster.predict() + p = forecaster.predict(data) assert np.isclose(p, 12.301259229712382) @@ -89,8 +83,7 @@ def test_ets_forecaster_multiplicative(): seasonality_type="multiplicative", seasonal_period=4, ) - forecaster.fit(data) - p = forecaster.predict() + p = forecaster.predict(data) assert np.isclose(p, 16.811888294476528) From df062f4fd5af9859fa2cae7dc5aa54c8f1e085bd Mon Sep 17 00:00:00 2001 From: MatthewMiddlehurst Date: Mon, 30 Jun 2025 22:59:46 +0100 Subject: [PATCH 2/7] fixes and mixins --- aeon/forecasting/_ets.py | 4 +- aeon/forecasting/_regression.py | 10 +++- aeon/forecasting/base.py | 94 ++++++++++++++++++--------------- 3 files changed, 62 insertions(+), 46 deletions(-) diff --git a/aeon/forecasting/_ets.py b/aeon/forecasting/_ets.py index 3d27035d46..a0815cc92e 100644 --- a/aeon/forecasting/_ets.py +++ b/aeon/forecasting/_ets.py @@ -12,13 +12,13 @@ import numpy as np from numba import njit -from aeon.forecasting.base import BaseForecaster +from aeon.forecasting.base import BaseForecaster, DirectForecastingMixin ADDITIVE = "additive" MULTIPLICATIVE = "multiplicative" -class ETSForecaster(BaseForecaster): +class ETSForecaster(BaseForecaster, DirectForecastingMixin): """Exponential Smoothing (ETS) forecaster. Implements the ETS (Error, Trend, Seasonality) forecaster, supporting additive diff --git a/aeon/forecasting/_regression.py b/aeon/forecasting/_regression.py index 206d0ca05d..7f15f27d10 100644 --- a/aeon/forecasting/_regression.py +++ b/aeon/forecasting/_regression.py @@ -8,10 +8,16 @@ import numpy as np from sklearn.linear_model import LinearRegression -from aeon.forecasting.base import BaseForecaster +from aeon.forecasting.base import ( + BaseForecaster, + DirectForecastingMixin, + IterativeForecastingMixin, +) -class RegressionForecaster(BaseForecaster): +class RegressionForecaster( + BaseForecaster, DirectForecastingMixin, IterativeForecastingMixin +): """ Regression based forecasting. diff --git a/aeon/forecasting/base.py b/aeon/forecasting/base.py index 4d6cdb5574..788dab95e4 100644 --- a/aeon/forecasting/base.py +++ b/aeon/forecasting/base.py @@ -5,7 +5,7 @@ """ __maintainer__ = ["TonyBagnall"] -__all__ = ["BaseForecaster"] +__all__ = ["BaseForecaster", "DirectForecastingMixin", "IterativeForecastingMixin"] from abc import abstractmethod from typing import final @@ -44,7 +44,7 @@ class BaseForecaster(BaseSeriesEstimator): def __init__(self, horizon: int, axis: int): self.horizon = horizon - self.meta_ = None # Meta data related to y on the last fit + super().__init__(axis) @final @@ -92,11 +92,8 @@ def fit(self, y, exog=None): self.is_fitted = True return self._fit(y, exog) - def _fit(self, y, exog=None): - return self - @final - def predict(self, y=None, exog=None): + def predict(self, y, exog=None): """Predict the next horizon steps ahead. Parameters @@ -112,7 +109,9 @@ def predict(self, y=None, exog=None): float single prediction self.horizon steps ahead of y. """ - # self._check_is_fitted() + if not self.get_tag("fit_is_empty"): + self._check_is_fitted() + if y is not None: self._check_X(y, self.axis) y = self._convert_y(y, self.axis) @@ -121,9 +120,6 @@ def predict(self, y=None, exog=None): return self._predict(y, exog) - @abstractmethod - def _predict(self, y=None, exog=None): ... - @final def forecast(self, y, exog=None): """Forecast the next horizon steps ahead. @@ -149,11 +145,52 @@ def forecast(self, y, exog=None): exog = self._convert_y(exog, self.axis) return self._forecast(y, exog) - def _forecast(self, y, exog=None): + def _fit(self, y, exog): + return self + + @abstractmethod + def _predict(self, y, exog): ... + + def _forecast(self, y, exog): """Forecast values for time series X.""" self.fit(y, exog) return self._predict(y, exog) + def _convert_y(self, y: VALID_SERIES_INNER_TYPES, axis: int): + """Convert y to self.get_tag("y_inner_type").""" + if axis > 1 or axis < 0: + raise ValueError(f"Input axis should be 0 or 1, saw {axis}") + + inner_type = self.get_tag("y_inner_type") + if not isinstance(inner_type, list): + inner_type = [inner_type] + inner_names = [i.split(".")[-1] for i in inner_type] + + input = type(y).__name__ + if input not in inner_names: + if inner_names[0] == "ndarray": + y = y.to_numpy() + elif inner_names[0] == "DataFrame": + transpose = False + if y.ndim == 1 and axis == 1: + transpose = True + y = pd.DataFrame(y) + if transpose: + y = y.T + else: + raise ValueError( + f"Unsupported inner type {inner_names[0]} derived from {inner_type}" + ) + if y.ndim > 1 and self.axis != axis: + y = y.T + elif y.ndim == 1 and isinstance(y, np.ndarray): + y = y[np.newaxis, :] if self.axis == 1 else y[:, np.newaxis] + return y + + +class DirectForecastingMixin: + """Mixin class for direct forecasting.""" + @final def direct_forecast(self, y, prediction_horizon, exog=None): """ @@ -208,6 +245,10 @@ def direct_forecast(self, y, prediction_horizon, exog=None): preds[i] = self.forecast(y, exog) return preds + +class IterativeForecastingMixin: + """Mixin class for iterative forecasting.""" + def iterative_forecast(self, y, prediction_horizon): """ Forecast ``prediction_horizon`` prediction using a single model from `y`. @@ -254,34 +295,3 @@ def iterative_forecast(self, y, prediction_horizon): preds[i] = self.predict(y) y = np.append(y, preds[i]) return preds - - def _convert_y(self, y: VALID_SERIES_INNER_TYPES, axis: int): - """Convert y to self.get_tag("y_inner_type").""" - if axis > 1 or axis < 0: - raise ValueError(f"Input axis should be 0 or 1, saw {axis}") - - inner_type = self.get_tag("y_inner_type") - if not isinstance(inner_type, list): - inner_type = [inner_type] - inner_names = [i.split(".")[-1] for i in inner_type] - - input = type(y).__name__ - if input not in inner_names: - if inner_names[0] == "ndarray": - y = y.to_numpy() - elif inner_names[0] == "DataFrame": - transpose = False - if y.ndim == 1 and axis == 1: - transpose = True - y = pd.DataFrame(y) - if transpose: - y = y.T - else: - raise ValueError( - f"Unsupported inner type {inner_names[0]} derived from {inner_type}" - ) - if y.ndim > 1 and self.axis != axis: - y = y.T - elif y.ndim == 1 and isinstance(y, np.ndarray): - y = y[np.newaxis, :] if self.axis == 1 else y[:, np.newaxis] - return y From d4cf31bb2d07fb1506f4c0c48847b0742c519830 Mon Sep 17 00:00:00 2001 From: MatthewMiddlehurst Date: Mon, 30 Jun 2025 23:27:32 +0100 Subject: [PATCH 3/7] naive and testing --- aeon/forecasting/_naive.py | 40 +++++++--------------------- aeon/forecasting/tests/test_base.py | 4 +-- aeon/forecasting/tests/test_naive.py | 22 +++++---------- 3 files changed, 18 insertions(+), 48 deletions(-) diff --git a/aeon/forecasting/_naive.py b/aeon/forecasting/_naive.py index da242018e2..706abd6219 100644 --- a/aeon/forecasting/_naive.py +++ b/aeon/forecasting/_naive.py @@ -31,49 +31,29 @@ class NaiveForecaster(BaseForecaster): Only relevant for "seasonal_last". """ + _tags = { + "fit_is_empty": True, + } + def __init__(self, strategy="last", seasonal_period=1, horizon=1): self.strategy = strategy self.seasonal_period = seasonal_period super().__init__(horizon=horizon, axis=1) - def _fit(self, y, exog=None): + def _predict(self, y, exog=None): y_squeezed = y.squeeze() if self.strategy == "last": - self._fitted_scalar_value = y_squeezed[-1] + return y_squeezed[-1] elif self.strategy == "mean": - self._fitted_scalar_value = np.mean(y_squeezed) + return np.mean(y_squeezed) elif self.strategy == "seasonal_last": - self._fitted_last_season = y_squeezed[-self.seasonal_period :] + period = y_squeezed[-self.seasonal_period :] + idx = (self.horizon - 1) % self.seasonal_period + return period[idx] else: raise ValueError( f"Unknown strategy: {self.strategy}. " "Valid strategies are 'last', 'mean', 'seasonal_last'." ) - return self - - def _predict(self, y=None, exog=None): - if y is None: - if self.strategy == "last" or self.strategy == "mean": - return self._fitted_scalar_value - - # For "seasonal_last" strategy - prediction_index = (self.horizon - 1) % self.seasonal_period - return self._fitted_last_season[prediction_index] - else: - y_squeezed = y.squeeze() - - if self.strategy == "last": - return y_squeezed[-1] - elif self.strategy == "mean": - return np.mean(y_squeezed) - elif self.strategy == "seasonal_last": - period = y_squeezed[-self.seasonal_period :] - idx = (self.horizon - 1) % self.seasonal_period - return period[idx] - else: - raise ValueError( - f"Unknown strategy: {self.strategy}. " - "Valid strategies are 'last', 'mean', 'seasonal_last'." - ) diff --git a/aeon/forecasting/tests/test_base.py b/aeon/forecasting/tests/test_base.py index e6b729c62b..7159f163d3 100644 --- a/aeon/forecasting/tests/test_base.py +++ b/aeon/forecasting/tests/test_base.py @@ -12,10 +12,10 @@ def test_base_forecaster(): f = NaiveForecaster() y = np.random.rand(50) f.fit(y) - p1 = f.predict() + p1 = f.predict(y) assert p1 == y[-1] p2 = f.forecast(y) - p3 = f._forecast(y) + p3 = f._forecast(y, None) assert p2 == p1 assert p3 == p2 with pytest.raises(ValueError, match="Exogenous variables passed"): diff --git a/aeon/forecasting/tests/test_naive.py b/aeon/forecasting/tests/test_naive.py index c0f9a98bd2..35603f2e35 100644 --- a/aeon/forecasting/tests/test_naive.py +++ b/aeon/forecasting/tests/test_naive.py @@ -9,8 +9,7 @@ def test_naive_forecaster_last_strategy(): """Test NaiveForecaster with 'last' strategy.""" sample_data = np.array([10, 20, 30, 40, 50]) forecaster = NaiveForecaster(strategy="last", horizon=3) - forecaster.fit(sample_data) - predictions = forecaster.predict() + predictions = forecaster.predict(sample_data) expected = 50 np.testing.assert_array_equal(predictions, expected) @@ -19,8 +18,7 @@ def test_naive_forecaster_mean_strategy(): """Test NaiveForecaster with 'mean' strategy.""" sample_data = np.array([10, 20, 30, 40, 50]) forecaster = NaiveForecaster(strategy="mean", horizon=2) - forecaster.fit(sample_data) - predictions = forecaster.predict() + predictions = forecaster.predict(sample_data) expected = 30 # Mean of [10, 20, 30, 40, 50] is 30 np.testing.assert_array_equal(predictions, expected) @@ -32,35 +30,27 @@ def test_naive_forecaster_seasonal_last_strategy(): # Last season is [6, 7, 8] for seasonal_period = 3 forecaster = NaiveForecaster(strategy="seasonal_last", seasonal_period=3, horizon=4) forecaster.fit(data) - pred = forecaster.predict() - pred2 = forecaster.predict(y=data) + pred = forecaster.predict(data) expected = 6 # predicts the 1-st element of the last season. np.testing.assert_array_equal(pred, expected) - np.testing.assert_array_equal(pred2, expected) # Test horizon within the season length forecaster = NaiveForecaster(strategy="seasonal_last", seasonal_period=3, horizon=2) forecaster.fit(data) - pred = forecaster.predict() - pred2 = forecaster.predict(y=data) + pred = forecaster.predict(data) expected = 7 # predicts the 2-nd element of the last season. np.testing.assert_array_equal(pred, expected) - np.testing.assert_array_equal(pred2, expected) # Test horizon wrapping around to a new season forecaster = NaiveForecaster(strategy="seasonal_last", seasonal_period=3, horizon=7) forecaster.fit(data) - pred = forecaster.predict() - pred2 = forecaster.predict(y=data) + pred = forecaster.predict(data) expected = 6 # predicts the 1-st element of the last season. np.testing.assert_array_equal(pred, expected) - np.testing.assert_array_equal(pred2, expected) # Last season is now [5, 6, 7, 8] with seasonal_period = 4 forecaster = NaiveForecaster(strategy="seasonal_last", seasonal_period=4, horizon=6) forecaster.fit(data) - pred = forecaster.predict() - pred2 = forecaster.predict(y=data) + pred = forecaster.predict(data) expected = 6 # predicts the 2nd element of the new last season. np.testing.assert_array_equal(pred, expected) - np.testing.assert_array_equal(pred2, expected) From f1a3da16ae72dfed0330fb81bd792b6ba1705895 Mon Sep 17 00:00:00 2001 From: MatthewMiddlehurst Date: Mon, 30 Jun 2025 23:40:02 +0100 Subject: [PATCH 4/7] test errors --- aeon/forecasting/_ets.py | 4 +--- aeon/forecasting/tests/test_regressor.py | 20 +++++++++----------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/aeon/forecasting/_ets.py b/aeon/forecasting/_ets.py index a0815cc92e..18ccf000ed 100644 --- a/aeon/forecasting/_ets.py +++ b/aeon/forecasting/_ets.py @@ -84,9 +84,7 @@ class ETSForecaster(BaseForecaster, DirectForecastingMixin): ... error_type='additive', trend_type='multiplicative', ... seasonality_type='multiplicative', seasonal_period=4 ... ) - >>> forecaster.fit(y) - ETSForecaster(...) - >>> forecaster.predict() + >>> forecaster.predict(y) 366.90200486015596 """ diff --git a/aeon/forecasting/tests/test_regressor.py b/aeon/forecasting/tests/test_regressor.py index f50519d1c9..5e63db2a92 100644 --- a/aeon/forecasting/tests/test_regressor.py +++ b/aeon/forecasting/tests/test_regressor.py @@ -13,18 +13,16 @@ def test_regression_forecaster(): y = np.random.rand(100) f = RegressionForecaster(window=10) f.fit(y) - p = f.predict() - p2 = f.predict(y) + p = f.predict(y) + p2 = f.forecast(y) assert p == p2 - p3 = f.forecast(y) - assert p == p3 f2 = RegressionForecaster(regressor=LinearRegression(), window=10) f2.fit(y) - p2 = f2.predict() + p2 = f2.predict(y) assert p == p2 f2 = RegressionForecaster(regressor=DummyRegressor(), window=10) f2.fit(y) - f2.predict() + f2.predict(y) with pytest.raises(ValueError): f = RegressionForecaster(window=-1) @@ -46,14 +44,14 @@ def test_regression_forecaster_with_exog(): # Test fit and predict with exog f.fit(y, exog=exog) - p1 = f.predict() - assert isinstance(p1, float) + p = f.predict(y, exog=exog) + assert isinstance(p, float) # Test that exog variable has an impact exog_zeros = np.zeros(n_samples) f.fit(y, exog=exog_zeros) - p2 = f.predict() - assert p1 != p2 + p2 = f.predict(y, exog=exog) + assert p != p2 # Test that forecast method works and is equivalent to fit+predict y_new = np.arange(50, 150) @@ -61,7 +59,7 @@ def test_regression_forecaster_with_exog(): # Manual fit + predict f.fit(y=y_new, exog=exog_new) - p_manual = f.predict() + p_manual = f.predict(y_new, exog=exog_new) # forecast() method p_forecast = f.forecast(y=y_new, exog=exog_new) From de68f61756b64f89b3fe1fc218d8731ece0231e3 Mon Sep 17 00:00:00 2001 From: MatthewMiddlehurst Date: Thu, 3 Jul 2025 12:36:04 +0100 Subject: [PATCH 5/7] tests --- aeon/forecasting/_ets.py | 93 +++++++------------ aeon/forecasting/base.py | 20 +++- .../_yield_forecasting_checks.py | 8 +- 3 files changed, 61 insertions(+), 60 deletions(-) diff --git a/aeon/forecasting/_ets.py b/aeon/forecasting/_ets.py index 18ccf000ed..38d6b6164e 100644 --- a/aeon/forecasting/_ets.py +++ b/aeon/forecasting/_ets.py @@ -45,29 +45,6 @@ class ETSForecaster(BaseForecaster, DirectForecastingMixin): phi : float, default=0.99 Trend damping parameter (used only for damped trend models). - Attributes - ---------- - forecast_val_ : float - Forecast value for the given horizon. - level_ : float - Estimated level component. - trend_ : float - Estimated trend component. - seasonality_ : array-like or None - Estimated seasonal components. - aic_ : float - Akaike Information Criterion of the fitted model. - avg_mean_sq_err_ : float - Average mean squared error of the fitted model. - residuals_ : list of float - Residuals from the fitted model. - fitted_values_ : list of float - Fitted values for the training data. - liklihood_ : float - Log-likelihood of the fitted model. - n_timepoints_ : int - Number of time points in the training series. - References ---------- .. [1] R. J. Hyndman and G. Athanasopoulos, @@ -146,56 +123,58 @@ def _get_int(x): return 2 return x - self._seasonal_period = self.seasonal_period - self._beta = self.beta - self._gamma = self.gamma + error_type = _get_int(self.error_type) + seasonality_type = _get_int(self.seasonality_type) + trend_type = _get_int(self.trend_type) - self._error_type = _get_int(self.error_type) - self._seasonality_type = _get_int(self.seasonality_type) - self._trend_type = _get_int(self.trend_type) - if self._seasonal_period < 1 or self._seasonality_type == 0: - self._seasonal_period = 1 + seasonal_period = self.seasonal_period + if self.seasonal_period < 1 or seasonality_type == 0: + seasonal_period = 1 + beta = self.beta if self._trend_type == 0: # Required for the equations in _update_states to work correctly - self._beta = 0 - if self._seasonality_type == 0: + beta = 0 + + gamma = self.gamma + if seasonality_type == 0: # Required for the equations in _update_states to work correctly - self._gamma = 0 + gamma = 0 + data = y.squeeze() ( - self.level_, - self.trend_, - self.seasonality_, - self.n_timepoints_, - self.residuals_, - self.fitted_values_, - self.avg_mean_sq_err_, - self.liklihood_, - self.k_, - self.aic_, + level_, + trend_, + seasonality_, + n_timepoints_, + residuals_, + fitted_values_, + avg_mean_sq_err_, + liklihood_, + k_, + aic_, ) = _numba_fit( data, - self._error_type, - self._trend_type, - self._seasonality_type, - self._seasonal_period, + error_type, + trend_type, + seasonality_type, + seasonal_period, self.alpha, - self._beta, - self._gamma, + beta, + gamma, self.phi, ) fitted_value = _predict( - self._trend_type, - self._seasonality_type, - self.level_, - self.trend_, - self.seasonality_, + trend_type, + seasonality_type, + level_, + trend_, + seasonality_, self.phi, self.horizon, - self.n_timepoints_, - self._seasonal_period, + n_timepoints_, + seasonal_period, ) return fitted_value diff --git a/aeon/forecasting/base.py b/aeon/forecasting/base.py index 788dab95e4..d69ed1c113 100644 --- a/aeon/forecasting/base.py +++ b/aeon/forecasting/base.py @@ -89,8 +89,11 @@ def fit(self, y, exog=None): if exog is not None: exog = self._convert_y(exog, self.axis) + self._fit(y, exog) + + # this should happen last self.is_fitted = True - return self._fit(y, exog) + return self @final def predict(self, y, exog=None): @@ -112,9 +115,24 @@ def predict(self, y, exog=None): if not self.get_tag("fit_is_empty"): self._check_is_fitted() + horizon = self.get_tag("capability:horizon") + if not horizon and self.horizon > 1: + raise ValueError( + f"Horizon is set >1, but {self.__class__.__name__} cannot handle a " + f"horizon greater than 1" + ) + + exog_tag = self.get_tag("capability:exogenous") + if not exog_tag and exog is not None: + raise ValueError( + f"Exogenous variables passed but {self.__class__.__name__} cannot " + "handle exogenous variables" + ) + if y is not None: self._check_X(y, self.axis) y = self._convert_y(y, self.axis) + if exog is not None: exog = self._convert_y(exog, self.axis) diff --git a/aeon/testing/estimator_checking/_yield_forecasting_checks.py b/aeon/testing/estimator_checking/_yield_forecasting_checks.py index 0a2fc3bea2..ef37e5c0d0 100644 --- a/aeon/testing/estimator_checking/_yield_forecasting_checks.py +++ b/aeon/testing/estimator_checking/_yield_forecasting_checks.py @@ -37,7 +37,9 @@ def check_forecaster_overrides_and_tags(estimator_class): # Test that all forecasters implement abstract predict. assert "_predict" in estimator_class.__dict__ - # todo decide what to do with "fit_is_empty" and abstract "_fit" + # Test that fit_is_empty is correctly set + fit_is_empty = estimator_class.get_class_tag(tag_name="fit_is_empty") + assert fit_is_empty == ("_fit" not in estimator_class.__dict__) # Test valid tag for X_inner_type X_inner_type = estimator_class.get_class_tag(tag_name="X_inner_type") @@ -62,7 +64,9 @@ def check_forecaster_output(estimator, datatype): estimator.fit( FULL_TEST_DATA_DICT[datatype]["train"][0], ) - y_pred = estimator.predict() + y_pred = estimator.predict( + FULL_TEST_DATA_DICT[datatype]["train"][0], + ) assert isinstance(y_pred, float), ( f"predict() output should be float, got" f" {type(y_pred)}" ) From 61c6d5fe1d81a84690ef73eb9d7147930be2023e Mon Sep 17 00:00:00 2001 From: MatthewMiddlehurst Date: Thu, 3 Jul 2025 12:45:56 +0100 Subject: [PATCH 6/7] tests --- aeon/forecasting/_ets.py | 2 +- aeon/forecasting/base.py | 27 +++++++++++++++++++++++---- aeon/forecasting/tests/test_base.py | 2 +- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/aeon/forecasting/_ets.py b/aeon/forecasting/_ets.py index 38d6b6164e..17e8b057a9 100644 --- a/aeon/forecasting/_ets.py +++ b/aeon/forecasting/_ets.py @@ -132,7 +132,7 @@ def _get_int(x): seasonal_period = 1 beta = self.beta - if self._trend_type == 0: + if trend_type == 0: # Required for the equations in _update_states to work correctly beta = 0 diff --git a/aeon/forecasting/base.py b/aeon/forecasting/base.py index d69ed1c113..2be4942ba8 100644 --- a/aeon/forecasting/base.py +++ b/aeon/forecasting/base.py @@ -129,9 +129,8 @@ def predict(self, y, exog=None): "handle exogenous variables" ) - if y is not None: - self._check_X(y, self.axis) - y = self._convert_y(y, self.axis) + self._check_X(y, self.axis) + y = self._convert_y(y, self.axis) if exog is not None: exog = self._convert_y(exog, self.axis) @@ -157,11 +156,31 @@ def forecast(self, y, exog=None): float single prediction self.horizon steps ahead of y. """ + horizon = self.get_tag("capability:horizon") + if not horizon and self.horizon > 1: + raise ValueError( + f"Horizon is set >1, but {self.__class__.__name__} cannot handle a " + f"horizon greater than 1" + ) + + exog_tag = self.get_tag("capability:exogenous") + if not exog_tag and exog is not None: + raise ValueError( + f"Exogenous variables passed but {self.__class__.__name__} cannot " + "handle exogenous variables" + ) + self._check_X(y, self.axis) y = self._convert_y(y, self.axis) + if exog is not None: exog = self._convert_y(exog, self.axis) - return self._forecast(y, exog) + + y_pred = self._forecast(y, exog) + + # this should happen last + self.is_fitted = True + return y_pred def _fit(self, y, exog): return self diff --git a/aeon/forecasting/tests/test_base.py b/aeon/forecasting/tests/test_base.py index 7159f163d3..7fb03370aa 100644 --- a/aeon/forecasting/tests/test_base.py +++ b/aeon/forecasting/tests/test_base.py @@ -19,7 +19,7 @@ def test_base_forecaster(): assert p2 == p1 assert p3 == p2 with pytest.raises(ValueError, match="Exogenous variables passed"): - f.fit(y, exog=y) + f.predict(y, exog=y) def test_convert_y(): From d4340e0cc222b99f0cb455dba6393dc3a2e5461e Mon Sep 17 00:00:00 2001 From: MatthewMiddlehurst Date: Thu, 3 Jul 2025 13:02:05 +0100 Subject: [PATCH 7/7] notebook --- examples/forecasting/forecasting.ipynb | 234 ++++++++++++++++--------- examples/forecasting/iterative.ipynb | 8 - examples/forecasting/regression.ipynb | 14 +- 3 files changed, 159 insertions(+), 97 deletions(-) diff --git a/examples/forecasting/forecasting.ipynb b/examples/forecasting/forecasting.ipynb index aa7f8f3a04..5e0f51d7b3 100644 --- a/examples/forecasting/forecasting.ipynb +++ b/examples/forecasting/forecasting.ipynb @@ -58,9 +58,12 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-03T12:01:17.776541Z", + "start_time": "2025-07-03T12:01:16.833749Z" + } + }, "source": [ "import inspect\n", "\n", @@ -73,7 +76,17 @@ " if not func[0].startswith(\"_\")\n", "]\n", "print(public_methods)" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['clone', 'fit', 'forecast', 'get_fitted_params', 'get_metadata_routing', 'get_params', 'get_tag', 'get_tags', 'predict', 'reset', 'set_params', 'set_tags']\n" + ] + } + ], + "execution_count": 1 }, { "cell_type": "markdown", @@ -89,15 +102,30 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-03T12:01:17.785953Z", + "start_time": "2025-07-03T12:01:17.783053Z" + } + }, "source": [ "from aeon.utils.data_types import SERIES_DATA_TYPES\n", "\n", "print(\" Possible data structures for input to forecaster \", SERIES_DATA_TYPES)\n", "print(\"\\n Tags for BaseForecaster: \", BaseForecaster.get_class_tags())" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Possible data structures for input to forecaster ['pd.Series', 'pd.DataFrame', 'np.ndarray']\n", + "\n", + " Tags for BaseForecaster: {'python_version': None, 'python_dependencies': None, 'cant_pickle': False, 'non_deterministic': False, 'algorithm_type': None, 'capability:missing_values': False, 'capability:multithreading': False, 'capability:univariate': True, 'capability:multivariate': False, 'X_inner_type': 'np.ndarray', 'capability:horizon': True, 'capability:exogenous': False, 'fit_is_empty': False, 'y_inner_type': 'np.ndarray'}\n" + ] + } + ], + "execution_count": 2 }, { "cell_type": "markdown", @@ -109,9 +137,12 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-03T12:01:17.831267Z", + "start_time": "2025-07-03T12:01:17.798471Z" + } + }, "source": [ "import pandas as pd\n", "\n", @@ -121,7 +152,17 @@ "print(type(y))\n", "y2 = pd.Series(y)\n", "y3 = pd.DataFrame(y)" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "execution_count": 3 }, { "cell_type": "markdown", @@ -137,41 +178,55 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'y' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 5\u001b[39m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01maeon\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mforecasting\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m NaiveForecaster\n\u001b[32m 4\u001b[39m d = NaiveForecaster(strategy=\u001b[33m\"\u001b[39m\u001b[33mlast\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m d.fit(\u001b[43my\u001b[49m)\n\u001b[32m 6\u001b[39m p = d.predict()\n\u001b[32m 7\u001b[39m \u001b[38;5;28mprint\u001b[39m(p)\n", - "\u001b[31mNameError\u001b[39m: name 'y' is not defined" - ] + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-03T12:01:18.145102Z", + "start_time": "2025-07-03T12:01:18.141759Z" } - ], + }, "source": [ "# Fit then predict\n", "from aeon.forecasting import NaiveForecaster\n", "\n", "d = NaiveForecaster(strategy=\"last\")\n", "d.fit(y)\n", - "p = d.predict()\n", + "p = d.predict(y)\n", "print(p)" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "432.0\n" + ] + } + ], + "execution_count": 4 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-03T12:01:18.156822Z", + "start_time": "2025-07-03T12:01:18.153674Z" + } + }, "source": [ "# forecast is equivalent to fit_predict in other estimators\n", "p2 = d.forecast(y)\n", "print(p2)" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "432.0\n" + ] + } + ], + "execution_count": 5 }, { "cell_type": "markdown", @@ -182,16 +237,30 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-03T12:01:18.179989Z", + "start_time": "2025-07-03T12:01:18.176785Z" + } + }, "source": [ "s = NaiveForecaster(strategy=\"seasonal_last\", horizon=2, seasonal_period=4)\n", "s.fit(y)\n", - "p = s.predict()\n", + "p = s.predict(y)\n", "print(f\"Last season: {y[-4:]}\")\n", "print(f\"Forecast: {p}\")" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Last season: [508. 461. 390. 432.]\n", + "Forecast: 461.0\n" + ] + } + ], + "execution_count": 6 }, { "cell_type": "markdown", @@ -209,21 +278,35 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-03T12:01:18.262330Z", + "start_time": "2025-07-03T12:01:18.256990Z" + } + }, "source": [ "from aeon.forecasting import RegressionForecaster\n", "\n", "r = RegressionForecaster(window=20)\n", "r.fit(y)\n", - "p = r.predict()\n", + "p = r.predict(y)\n", "print(p)\n", "r2 = RegressionForecaster(window=10, horizon=5)\n", "r2.fit(y)\n", "p = r2.predict(y)\n", "print(p)" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "451.6754197093645\n", + "527.3689709356747\n" + ] + } + ], + "execution_count": 7 }, { "cell_type": "markdown", @@ -236,14 +319,28 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2025-07-03T12:01:18.284759Z", + "start_time": "2025-07-03T12:01:18.279541Z" + } + }, "source": [ "p1 = r.forecast(y)\n", "p2 = r2.forecast(y)\n", "print(p1, \",\\n\", p2)" - ] + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "451.6754197093645 ,\n", + " 527.3689709356747\n" + ] + } + ], + "execution_count": 8 }, { "cell_type": "markdown", @@ -257,14 +354,20 @@ }, { "cell_type": "code", - "execution_count": 20, "metadata": { + "collapsed": false, "ExecuteTime": { - "end_time": "2024-11-16T19:21:26.225501Z", - "start_time": "2024-11-16T19:21:26.204872Z" - }, - "collapsed": false + "end_time": "2025-07-03T12:01:19.715565Z", + "start_time": "2025-07-03T12:01:18.294767Z" + } }, + "source": [ + "from aeon.forecasting import ETSForecaster\n", + "\n", + "ets = ETSForecaster()\n", + "ets.fit(y)\n", + "ets.predict(y)" + ], "outputs": [ { "data": { @@ -272,37 +375,12 @@ "460.302772481884" ] }, - "execution_count": 20, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "from aeon.forecasting import ETSForecaster\n", - "\n", - "ets = ETSForecaster()\n", - "ets.fit(y)\n", - "ets.predict()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-16T19:21:27.095665Z", - "start_time": "2024-11-16T19:21:27.077715Z" - } - }, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] + "execution_count": 9 } ], "metadata": { diff --git a/examples/forecasting/iterative.ipynb b/examples/forecasting/iterative.ipynb index e98d588656..133907f97f 100644 --- a/examples/forecasting/iterative.ipynb +++ b/examples/forecasting/iterative.ipynb @@ -304,14 +304,6 @@ } ], "execution_count": 27 - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "", - "id": "5e283827ebb7141b" } ], "metadata": { diff --git a/examples/forecasting/regression.ipynb b/examples/forecasting/regression.ipynb index b442c6b46f..cacf597def 100644 --- a/examples/forecasting/regression.ipynb +++ b/examples/forecasting/regression.ipynb @@ -220,11 +220,11 @@ "airline = load_airline()\n", "rf = RegressionForecaster(window=50)\n", "rf.fit(airline)\n", - "p3 = rf.predict()\n", + "p3 = rf.predict(airline)\n", "print(f\" Forecast for airline with linear regression = {p1} and {p3}\")\n", "rf2 = RegressionForecaster(regressor=DrCIFRegressor(n_estimators=10), window=50)\n", "rf2.fit(airline)\n", - "p4 = rf.predict()\n", + "p4 = rf.predict(airline)\n", "print(f\" Forecast for airline with DrCIF = {p1} and {p3}\")" ], "id": "bdc35a7a671ee254", @@ -332,17 +332,9 @@ "source": [ "All aeon forecasters predict a single value. If you want to forecast a range of \n", "values ahead you should use the functions `recursive_foreacast` or `direct_forecast` \n", - "(notebooks coming soon). " + "(notebooks coming soon)." ], "id": "939c3a82b06e1b95" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "", - "id": "4bf0708606952c16" } ], "metadata": {