diff --git a/aeon/forecasting/__init__.py b/aeon/forecasting/__init__.py
index 0b134857dd..5b54dd2a19 100644
--- a/aeon/forecasting/__init__.py
+++ b/aeon/forecasting/__init__.py
@@ -1,13 +1,15 @@
 """Forecasters."""
 
 __all__ = [
-    "NaiveForecaster",
     "BaseForecaster",
+    "NaiveForecaster",
     "RegressionForecaster",
     "ETSForecaster",
     "TVPForecaster",
+    "ARIMAForecaster",
 ]
 
+from aeon.forecasting._arima import ARIMAForecaster
 from aeon.forecasting._ets import ETSForecaster
 from aeon.forecasting._naive import NaiveForecaster
 from aeon.forecasting._regression import RegressionForecaster
diff --git a/aeon/forecasting/_arima.py b/aeon/forecasting/_arima.py
new file mode 100644
index 0000000000..f19445c02e
--- /dev/null
+++ b/aeon/forecasting/_arima.py
@@ -0,0 +1,282 @@
+"""ARIMAForecaster.
+
+An implementation of the ARIMA forecasting algorithm.
+"""
+
+__maintainer__ = ["alexbanwell1", "TonyBagnall"]
+__all__ = ["ARIMAForecaster"]
+
+import numpy as np
+from numba import njit
+
+from aeon.forecasting.base import BaseForecaster
+from aeon.utils.optimisation._nelder_mead import nelder_mead
+
+
+class ARIMAForecaster(BaseForecaster):
+    """AutoRegressive Integrated Moving Average (ARIMA) forecaster.
+
+    ARIMA with a fixed model structure, with parameters fitted by a
+    Nelder-Mead optimiser that minimises the AIC.
+
+    Parameters
+    ----------
+    p : int, default=1
+        Autoregressive (p) order of the ARIMA model.
+    d : int, default=0
+        Differencing (d) order of the ARIMA model.
+    q : int, default=1
+        Moving average (q) order of the ARIMA model.
+    use_constant : bool, default=False
+        Whether to include a constant/intercept term in the model.
+
+    Attributes
+    ----------
+    residuals_ : np.ndarray
+        Residual errors from the fitted model.
+    aic_ : float
+        Akaike Information Criterion for the fitted model.
+    c_ : float, default=0
+        Intercept term.
+    phi_ : np.ndarray
+        Coefficients for the autoregressive terms (length p).
+    theta_ : np.ndarray
+        Coefficients for the moving average terms (length q).
+
+    References
+    ----------
+    .. [1] R. J. Hyndman and G. Athanasopoulos,
+       Forecasting: Principles and Practice. OTexts, 2014.
+       https://otexts.com/fpp3/
+
+    Examples
+    --------
+    >>> from aeon.forecasting import ARIMAForecaster
+    >>> from aeon.datasets import load_airline
+    >>> y = load_airline()
+    >>> forecaster = ARIMAForecaster(p=2, d=1)
+    >>> forecaster.forecast(y)
+    474.49449...
+    """
+
+    _tags = {
+        "capability:horizon": False,  # cannot fit to a horizon other than 1
+    }
+
+    def __init__(self, p: int = 1, d: int = 0, q: int = 1, use_constant: bool = False):
+        self.p = p
+        self.d = d
+        self.q = q
+        self.use_constant = use_constant
+        self.phi_ = 0
+        self.theta_ = 0
+        self.c_ = 0
+        self._series = []
+        self._differenced_series = []
+        self.residuals_ = []
+        self.fitted_values_ = []
+        self.aic_ = 0
+        self._model = []
+        self._parameters = []
+        super().__init__(horizon=1, axis=1)
+
+    def _fit(self, y, exog=None):
+        """Fit ARIMA forecaster to series y to predict one step ahead using y.
+
+        Parameters
+        ----------
+        y : np.ndarray
+            A time series on which to learn a forecaster to predict horizon ahead.
+        exog : np.ndarray, default=None
+            Not allowed for this forecaster.
+
+        Returns
+        -------
+        self
+            Fitted ARIMAForecaster.
+ """ + self._series = np.array(y.squeeze(), dtype=np.float64) + self._model = np.array( + (1 if self.use_constant else 0, self.p, self.q), dtype=np.int32 + ) + self._differenced_series = np.diff(self._series, n=self.d) + + (self._parameters, self.aic_) = nelder_mead( + _arima_model_wrapper, + np.sum(self._model[:3]), + self._differenced_series, + self._model, + ) + (self.c_, self.phi_, self.theta_) = _extract_params( + self._parameters, self._model + ) + (self.aic_, self.residuals_, self.fitted_values_) = _arima_model( + self._parameters, + _calc_arma, + self._differenced_series, + self._model, + np.empty(0), + ) + self.forecast_ = _calc_arma( + self._differenced_series, + self._model, + len(y), + self._parameters, + self.residuals_, + ) + + return self + + def _predict(self, y, exog=None): + """ + Predict the next step ahead for training data or y. + + Parameters + ---------- + y : np.ndarray, default = None + A time series to predict the next horizon value for. If None, + predict the next horizon value after series seen in fit. + exog : np.ndarray, default =None + Optional exogenous time series data assumed to be aligned with y + + Returns + ------- + float + Prediction 1 step ahead of the data seen in fit or passed as y. + """ + series = y.squeeze() + # Difference the series using numpy + differenced_series = np.diff(self._series, n=self.d) + pred = _single_forecast(differenced_series, self.c_, self.phi_, self.theta_) + forecast = pred + series[-self.d :].sum() if self.d > 0 else pred + # Need to undifference it! + return forecast + + def _forecast(self, y, exog=None): + """Forecast one ahead for time series y.""" + self.fit(y, exog) + return self.forecast_ + + def iterative_forecast(self, y, prediction_horizon): + self.fit(y) + preds = np.zeros(prediction_horizon) + preds[0] = self.forecast_ + differenced_series = np.diff(self._series, n=self.d) + for i in range(1, prediction_horizon): + differenced_series = np.append(differenced_series, preds[i - 1]) + preds[i] = _single_forecast( + differenced_series, self.c_, self.phi_, self.theta_ + ) + return preds + + +@njit(cache=True, fastmath=True) +def _aic(residuals, num_params): + """Calculate the log-likelihood of a model.""" + variance = np.mean(residuals**2) + liklihood = len(residuals) * (np.log(2 * np.pi) + np.log(variance) + 1) + return liklihood + 2 * num_params + + +@njit(fastmath=True) +def _arima_model_wrapper(params, data, model): + return _arima_model(params, _calc_arma, data, model, np.empty(0))[0] + + +# Define the ARIMA(p, d, q) likelihood function +@njit(cache=True, fastmath=True) +def _arima_model(params, base_function, data, model, residuals): + """Calculate the log-likelihood of an ARIMA model given the parameters.""" + formatted_params = _extract_params(params, model) # Extract parameters + + # Initialize residuals + n = len(data) + m = len(residuals) + num_predictions = n - m + 1 + residuals = np.concatenate((residuals, np.zeros(num_predictions - 1))) + expect_full_history = m > 0 # I.e. 
+    fitted_values = np.zeros(num_predictions)
+    for t in range(num_predictions):
+        fitted_values[t] = base_function(
+            data,
+            model,
+            m + t,
+            formatted_params,
+            residuals,
+            expect_full_history,
+        )
+        if t != num_predictions - 1:
+            # Only calculate residuals for the predictions we have data for
+            residuals[m + t] = data[m + t] - fitted_values[t]
+    return _aic(residuals, len(params)), residuals, fitted_values
+
+
+@njit(cache=True, fastmath=True)
+def _extract_params(params, model):
+    """Extract ARIMA parameters from the parameter vector."""
+    if len(params) != np.sum(model):
+        previous_length = np.sum(model)
+        model = model[:-1]  # Remove the seasonal period
+        if len(params) != np.sum(model):
+            raise ValueError(
+                f"Expected {previous_length} parameters for a non-seasonal model "
+                f"or {np.sum(model)} parameters for a seasonal model, "
+                f"got {len(params)}"
+            )
+    starts = np.cumsum(np.concatenate((np.zeros(1, dtype=np.int32), model[:-1])))
+    n = len(starts)
+    max_len = np.max(model)
+    result = np.full((n, max_len), np.nan, dtype=params.dtype)
+    for i in range(n):
+        length = model[i]
+        start = starts[i]
+        result[i, :length] = params[start : start + length]
+    return result
+
+
+@njit(cache=True, fastmath=True)
+def _calc_arma(data, model, t, formatted_params, residuals, expect_full_history=False):
+    """Calculate the ARMA forecast for time t."""
+    if len(model) != 3:
+        raise ValueError("Model must be of the form (c, p, q)")
+    p = model[1]
+    q = model[2]
+    if expect_full_history and (t - p < 0 or t - q < 0):
+        raise ValueError(
+            f"Insufficient data for ARIMA model at time {t}. "
+            f"Expected at least {p} past values for AR and {q} for MA."
+        )
+    # AR part
+    phi = formatted_params[1][:p]
+    ar_term = 0 if (t - p) < 0 else np.dot(phi, data[t - p : t][::-1])
+
+    # MA part
+    theta = formatted_params[2][:q]
+    ma_term = 0 if (t - q) < 0 else np.dot(theta, residuals[t - q : t][::-1])
+
+    c = formatted_params[0][0] if model[0] else 0
+    y_hat = c + ar_term + ma_term
+    return y_hat
+
+
+@njit(cache=True, fastmath=True)
+def _single_forecast(series, c, phi, theta):
+    """Calculate the one-step ARMA forecast for a fixed model.
+
+    This is equivalent to filtering in statsmodels. Assumes the series has
+    already been differenced where necessary.
+    """
+    p = len(phi)
+    q = len(theta)
+    n = len(series)
+    residuals = np.zeros(n)
+    max_lag = max(p, q)
+    # Compute in-sample residuals
+    for t in range(max_lag, n):
+        ar_part = np.dot(phi, series[t - np.arange(1, p + 1)]) if p > 0 else 0.0
+        ma_part = np.dot(theta, residuals[t - np.arange(1, q + 1)]) if q > 0 else 0.0
+        pred = c + ar_part + ma_part
+        residuals[t] = series[t] - pred
+    # Forecast the next value using the most recent p values and q residuals
+    ar_forecast = np.dot(phi, series[-p:][::-1]) if p > 0 else 0.0
+    ma_forecast = np.dot(theta, residuals[-q:][::-1]) if q > 0 else 0.0
+    f = c + ar_forecast + ma_forecast
+    return f
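A minimal usage sketch of the new public API, mirroring the class docstring example; the q order and the twelve-step horizon here are illustrative choices, not part of the diff:

```python
from aeon.datasets import load_airline
from aeon.forecasting import ARIMAForecaster

y = load_airline()

# Fit an ARIMA(2, 1, 1) model and produce a single one-step-ahead forecast
forecaster = ARIMAForecaster(p=2, d=1, q=1)
one_ahead = forecaster.forecast(y)

# Roll the model forward twelve steps, feeding each forecast back in
preds = forecaster.iterative_forecast(y, prediction_horizon=12)
```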
diff --git a/aeon/utils/forecasting/__init__.py b/aeon/utils/forecasting/__init__.py
new file mode 100644
index 0000000000..a168fa0f11
--- /dev/null
+++ b/aeon/utils/forecasting/__init__.py
@@ -0,0 +1 @@
+"""Forecasting utils."""
diff --git a/aeon/utils/forecasting/_hypo_tests.py b/aeon/utils/forecasting/_hypo_tests.py
new file mode 100644
index 0000000000..2d581e971e
--- /dev/null
+++ b/aeon/utils/forecasting/_hypo_tests.py
@@ -0,0 +1,102 @@
+import numpy as np
+
+
+def kpss_test(y, regression="c", lags=None):  # Test if time series is stationary
+    """
+    Perform the KPSS (Kwiatkowski-Phillips-Schmidt-Shin) test for stationarity.
+
+    The KPSS test evaluates the null hypothesis that a time series is
+    (trend or level) stationary against the alternative of a unit root
+    (non-stationarity). It can test for either stationarity around a
+    constant (level stationarity) or around a deterministic trend
+    (trend stationarity).
+
+    Parameters
+    ----------
+    y : array-like
+        Time series data to test for stationarity.
+    regression : str, default="c"
+        Indicates the null hypothesis for stationarity:
+        - "c" : Stationary around a constant (level stationarity)
+        - "ct" : Stationary around a constant and linear trend (trend stationarity)
+    lags : int or None, optional
+        Number of lags to use for the
+        HAC (heteroskedasticity and autocorrelation consistent) variance estimator.
+        If None, defaults to sqrt(n), where n is the sample size.
+
+    Returns
+    -------
+    kpss_stat : float
+        The KPSS test statistic.
+    stationary : bool
+        True if the series is judged stationary at the 5% significance level
+        (i.e., the test statistic is below the critical value); False otherwise.
+
+    Notes
+    -----
+    - Uses asymptotic 5% critical values from Kwiatkowski et al. (1992): 0.463 for
+      level stationarity, 0.146 for trend stationarity.
+    - Returns True for stationary if the test statistic is below the 5% critical
+      value.
+
+    References
+    ----------
+    Kwiatkowski, D., Phillips, P.C.B., Schmidt, P., & Shin, Y. (1992).
+    "Testing the null hypothesis of stationarity against the alternative
+    of a unit root."
+    Journal of Econometrics, 54(1–3), 159–178.
+    https://doi.org/10.1016/0304-4076(92)90104-Y
+
+    Examples
+    --------
+    >>> from aeon.utils.forecasting._hypo_tests import kpss_test
+    >>> from aeon.datasets import load_airline
+    >>> y = load_airline()
+    >>> kpss_test(y)
+    (np.float64(1.1966313813...), np.False_)
+    """
+    y = np.asarray(y)
+    n = len(y)
+
+    # Step 1: Fit a regression model to estimate the residuals
+    if regression == "c":  # Constant
+        X = np.ones((n, 1))
+    elif regression == "ct":  # Constant + Trend
+        X = np.column_stack((np.ones(n), np.arange(1, n + 1)))
+    else:
+        raise ValueError("regression must be 'c' or 'ct'")
+
+    beta = np.linalg.lstsq(X, y, rcond=None)[0]  # Estimate regression coefficients
+    residuals = y - X @ beta  # Get residuals (u_t)
+
+    # Step 2: Compute the cumulative sum of residuals (S_t)
+    S_t = np.cumsum(residuals)
+
+    # Step 3: Estimate the long-run variance (HAC variance)
+    if lags is None:
+        # lags = int(12 * (n / 100) ** (1 / 4))  # Default statsmodels lag length
+        lags = int(np.sqrt(n))  # Default lag length
+
+    gamma_0 = np.sum(residuals**2) / (n - X.shape[1])  # Lag-0 autocovariance
+    gamma = [np.sum(residuals[k:] * residuals[:-k]) / n for k in range(1, lags + 1)]
+
+    # Bartlett weights
+    weights = [1 - (k / (lags + 1)) for k in range(1, lags + 1)]
+
+    # Long-run variance
+    sigma_squared = gamma_0 + 2 * np.sum([w * g for w, g in zip(weights, gamma)])
+
+    # Step 4: Calculate the KPSS statistic
+    kpss_stat = np.sum(S_t**2) / (n**2 * sigma_squared)
+
+    # 5% critical values for the KPSS test
+    if regression == "ct":
+        # p. 162 Kwiatkowski et al. (1992): y_t = beta * t + r_t + e_t,
+        # where beta is the trend, r_t a random walk and e_t a stationary
+        # error term.
+        crit = 0.146
+    else:  # regression == "c"
+        # Special case of the model above, where beta = 0 (so the null
+        # hypothesis is that the data is stationary around r_0).
+        crit = 0.463
+
+    return kpss_stat, kpss_stat < crit
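A short sketch of the intended call pattern for `kpss_test`, on synthetic series; the seed and lengths are arbitrary, and the outcomes described in the comments are what the test should typically produce rather than guaranteed results:

```python
import numpy as np

from aeon.utils.forecasting._hypo_tests import kpss_test

rng = np.random.default_rng(0)

# Level-stationary white noise: the statistic should typically fall below
# the 0.463 critical value, so `stationary` is True
stat, stationary = kpss_test(rng.normal(size=200))

# A random walk has a unit root, so the test should typically judge it
# non-stationary (`stationary_rw` is False)
stat_rw, stationary_rw = kpss_test(np.cumsum(rng.normal(size=200)))
```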
diff --git a/aeon/utils/forecasting/_seasonality.py b/aeon/utils/forecasting/_seasonality.py
new file mode 100644
index 0000000000..356b1a40d2
--- /dev/null
+++ b/aeon/utils/forecasting/_seasonality.py
@@ -0,0 +1,101 @@
+"""Seasonality Tools.
+
+Includes the autocorrelation function (ACF) and seasonal period estimation.
+"""
+
+import numpy as np
+from numba import njit
+
+
+@njit(cache=True, fastmath=True)
+def acf(X, max_lag):
+    """
+    Compute the sample autocorrelation function (ACF) of a time series.
+
+    Correlations are computed for lags 1 up to a specified maximum lag.
+
+    The autocorrelation at lag k is defined as the Pearson correlation
+    coefficient between the series and a lagged version of itself.
+    If both segments at a given lag have zero variance, the function
+    returns 1 for that lag. If only one segment has zero variance,
+    the function returns 0.
+
+    Parameters
+    ----------
+    X : array-like, shape (n_samples,)
+        The input time series data.
+    max_lag : int
+        The maximum lag (number of steps) for which to
+        compute the autocorrelation.
+
+    Returns
+    -------
+    acf_values : np.ndarray, shape (max_lag,)
+        The autocorrelation values for lags 1 through `max_lag`.
+
+    Notes
+    -----
+    The function handles cases where the lagged segments have zero
+    variance to avoid division by zero.
+    The returned values correspond to
+    lags 1, 2, ..., `max_lag` (not including lag 0).
+    """
+    length = len(X)
+    X_t = np.zeros(max_lag, dtype=float)
+    for lag in range(1, max_lag + 1):
+        lag_length = length - lag
+        x1 = X[:-lag]
+        x2 = X[lag:]
+        s1 = np.sum(x1)
+        s2 = np.sum(x2)
+        m1 = s1 / lag_length
+        m2 = s2 / lag_length
+        ss1 = np.sum(x1 * x1)
+        ss2 = np.sum(x2 * x2)
+        v1 = ss1 - s1 * m1
+        v2 = ss2 - s2 * m2
+        v1_is_zero, v2_is_zero = v1 <= 1e-9, v2 <= 1e-9
+        if v1_is_zero and v2_is_zero:
+            # Both segments have zero variance, so they are perfectly correlated
+            X_t[lag - 1] = 1
+        elif v1_is_zero or v2_is_zero:
+            # One segment has zero variance, the other does not
+            X_t[lag - 1] = 0
+        else:
+            X_t[lag - 1] = np.sum((x1 - m1) * (x2 - m2)) / np.sqrt(v1 * v2)
+    return X_t
+
+
+@njit(cache=True, fastmath=True)
+def calc_seasonal_period(data):
+    """
+    Estimate the seasonal period of a time series using autocorrelation analysis.
+
+    This function computes the autocorrelation function (ACF) of
+    the input series up to lag 24. It then identifies peaks in the
+    ACF above the mean value, treating the first such peak
+    as the estimated seasonal period. If no peak is found,
+    a period of 1 is returned.
+
+    Parameters
+    ----------
+    data : array-like, shape (n_samples,)
+        The input time series data.
+
+    Returns
+    -------
+    period : int
+        The estimated seasonal period (lag) of the series. Returns 1 if no
+        significant peak is detected in the autocorrelation.
+    """
+    lags = acf(data, 24)
+    lags = np.concatenate((np.array([1.0]), lags))
+    peaks = []
+    mean_lags = np.mean(lags)
+    for i in range(1, len(lags) - 1):  # Skip the first (lag 0) and last elements
+        if lags[i] >= lags[i - 1] and lags[i] >= lags[i + 1] and lags[i] > mean_lags:
+            peaks.append(i)
+    if not peaks:
+        return 1
+    return peaks[0]
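A small sketch of how the two seasonality helpers fit together, using a synthetic sinusoid; the period and series length are arbitrary choices for illustration:

```python
import numpy as np

from aeon.utils.forecasting._seasonality import acf, calc_seasonal_period

# A clean sinusoid with period 12 produces an ACF peak at lag 12
t = np.arange(120)
y = np.sin(2 * np.pi * t / 12)

correlations = acf(y, max_lag=24)  # autocorrelations for lags 1..24
period = calc_seasonal_period(y)  # first ACF peak above the mean, expected 12
```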
+ """ + lags = acf(data, 24) + lags = np.concatenate((np.array([1.0]), lags)) + peaks = [] + mean_lags = np.mean(lags) + for i in range(1, len(lags) - 1): # Skip the first (lag 0) and last elements + if lags[i] >= lags[i - 1] and lags[i] >= lags[i + 1] and lags[i] > mean_lags: + peaks.append(i) + if not peaks: + return 1 + else: + return peaks[0] diff --git a/aeon/utils/optimisation/__init__.py b/aeon/utils/optimisation/__init__.py new file mode 100644 index 0000000000..11eddea791 --- /dev/null +++ b/aeon/utils/optimisation/__init__.py @@ -0,0 +1 @@ +"""Optimisation utils.""" diff --git a/aeon/utils/optimisation/_nelder_mead.py b/aeon/utils/optimisation/_nelder_mead.py new file mode 100644 index 0000000000..e59a70c5dd --- /dev/null +++ b/aeon/utils/optimisation/_nelder_mead.py @@ -0,0 +1,119 @@ +"""Optimisation algorithms for automatic parameter tuning.""" + +import numpy as np +from numba import njit + + +@njit(fastmath=True) +def nelder_mead( + loss_function, + num_params, + data, + model, + tol=1e-6, + max_iter=500, +): + """ + Perform optimisation using the Nelder–Mead simplex algorithm. + + This function minimises a given loss (objective) function using the Nelder–Mead + algorithm, a derivative-free method that iteratively refines a simplex of candidate + solutions. The implementation supports unconstrained minimisation of functions + with a fixed number of parameters. + + Parameters + ---------- + loss_function : callable + The objective function to minimise. Should accept a 1D NumPy array of length + `num_params` and return a scalar value. + num_params : int + The number of parameters (dimensions) in the optimisation problem. + data : np.ndarray + The input data used by the loss function. The shape and content depend on the + specific loss function being minimised. + model : np.ndarray + The model or context in which the loss function operates. This could be any + other object that the `loss_function` requires to compute its value. + The exact type and structure of `model` should be compatible with the + `loss_function`. + tol : float, optional (default=1e-6) + Tolerance for convergence. The algorithm stops when the maximum difference + between function values at simplex vertices is less than `tol`. + max_iter : int, optional (default=500) + Maximum number of iterations to perform. + + Returns + ------- + best_params : np.ndarray, shape (`num_params`,) + The parameter vector that minimises the loss function. + best_value : float + The value of the loss function at the optimal parameter vector. + + Notes + ----- + - The initial simplex is constructed by setting each parameter to 0.5, + with one additional point per dimension at 0.6 for that dimension. + - This implementation does not support constraints or bounds on the parameters. + - The algorithm does not guarantee finding a global minimum. + + References + ---------- + .. [1] Nelder, J. A. and Mead, R. (1965). + A Simplex Method for Function Minimization. + The Computer Journal, 7(4), 308–313. 
+       https://doi.org/10.1093/comjnl/7.4.308
+    """
+    points = np.full((num_params + 1, num_params), 0.5)
+    for i in range(num_params):
+        points[i + 1][i] = 0.6
+    values = np.array([loss_function(v, data, model) for v in points])
+    for _iteration in range(max_iter):
+        # Order the simplex by function value
+        order = np.argsort(values)
+        points = points[order]
+        values = values[order]
+
+        # Centroid of the best n points
+        centre_point = points[:-1].sum(axis=0) / len(points[:-1])
+
+        # Reflection:
+        # centre + distance between the centre and the worst point
+        reflected_point = centre_point + (centre_point - points[-1])
+        reflected_value = loss_function(reflected_point, data, model)
+        # If between the best and second-worst values, keep the reflected point
+        if len(values) > 1 and values[0] <= reflected_value < values[-2]:
+            points[-1] = reflected_point
+            values[-1] = reflected_value
+            continue
+        # Expansion:
+        # otherwise, if it is better than the best value
+        if reflected_value < values[0]:
+            expanded_point = centre_point + 2 * (reflected_point - centre_point)
+            expanded_value = loss_function(expanded_point, data, model)
+            # Keep the expanded point if it improves on the reflected point,
+            # otherwise fall back to the reflected point
+            if expanded_value < reflected_value:
+                points[-1] = expanded_point
+                values[-1] = expanded_value
+            else:
+                points[-1] = reflected_point
+                values[-1] = reflected_value
+            continue
+        # Contraction:
+        # otherwise, the reflection is no better than the second-worst value
+        contracted_point = centre_point - 0.5 * (centre_point - points[-1])
+        contracted_value = loss_function(contracted_point, data, model)
+        # If the contraction improves on the worst point, keep it;
+        # otherwise move to shrinkage
+        if contracted_value < values[-1]:
+            points[-1] = contracted_point
+            values[-1] = contracted_value
+            continue
+
+        # Shrinkage: move every point halfway towards the best point
+        for i in range(1, len(points)):
+            points[i] = points[0] - 0.5 * (points[0] - points[i])
+            values[i] = loss_function(points[i], data, model)
+
+        # Convergence check
+        if np.max(np.abs(values - values[0])) < tol:
+            break
+    return points[0], values[0]
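To illustrate the expected `(params, data, model)` loss-function signature, a minimal sketch minimising a toy least-squares objective; `squared_error` and `target` are illustrative names, and the unused `model` argument is passed as an empty array. Note the loss must itself be njit-compiled so it can be called from inside the jitted optimiser:

```python
import numpy as np
from numba import njit

from aeon.utils.optimisation._nelder_mead import nelder_mead


@njit(fastmath=True)
def squared_error(params, data, model):
    # Toy objective: squared distance from the target vector in `data`;
    # `model` is unused but required by the expected signature
    return np.sum((params - data) ** 2)


target = np.array([0.2, 0.8])
best_params, best_value = nelder_mead(squared_error, 2, target, np.empty(0))
# best_params should converge towards [0.2, 0.8] and best_value towards 0
```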