[ENH] Add Basic ARIMA model #2860

Open

wants to merge 39 commits into main from arb/base_arima

Changes from 15 commits

Commits (39)
d381d5e
arima first
TonyBagnall May 24, 2025
3a0552b
move utils
TonyBagnall May 24, 2025
0ac5380
make functions private
TonyBagnall May 24, 2025
44b36a7
Modularise SARIMA model
May 28, 2025
6d18de9
Add ARIMA forecaster to forecasting package
May 28, 2025
b7e6424
Add example to ARIMA forecaster, this also tests the forecaster is pr…
May 28, 2025
e33fa4d
Basic ARIMA model
May 28, 2025
f613f7e
Convert ARIMA to numba version
May 28, 2025
a6b708c
Merge branch 'main' into arb/base_arima
alexbanwell1 May 28, 2025
9eb00f6
Adjust parameters to allow modification in fit
May 28, 2025
d4ed4b1
Update example and return native python type
May 28, 2025
2893e1b
Fix examples for tests
May 28, 2025
9801e8b
Fix Nelder-Mead Optimisation Algorithm Example
May 28, 2025
2f928c7
Fix Nelder-Mead Optimisation Algorithm Example #2
May 28, 2025
94cd5b3
Remove Nelder-Mead Example due to issues with numba caching functions
May 28, 2025
0d0d63f
Fix return type issue
May 28, 2025
39a3ed2
Address PR Feedback
May 28, 2025
05a2785
Ignore small tolerances in floating point value in output of example
May 28, 2025
73966ab
Fix kpss_test example
May 28, 2025
a0f090d
Fix kpss_test example #2
May 28, 2025
6884703
Update documentation for ARIMAForecaster, change constant_term to be …
Jun 2, 2025
44a8647
Merge branch 'main' into arb/base_arima
alexbanwell1 Jun 2, 2025
9af3a56
Modify ARIMA to allow predicting multiple values by updating the stat…
Jun 8, 2025
4c63af5
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 9, 2025
e898f2f
Fix bug using self.d rather than self.d_
Jun 9, 2025
11c4987
Merge branch 'arb/base_arima' of https://github.com/aeon-toolkit/aeon…
Jun 9, 2025
6314a6f
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 11, 2025
72b7980
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 11, 2025
3c644a0
refactor ARIMA
TonyBagnall Jun 11, 2025
350252e
Merge branch 'main' into arb/base_arima
MatthewMiddlehurst Jun 16, 2025
1bd6a32
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 16, 2025
b91d135
docstring
TonyBagnall Jun 16, 2025
420cd72
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 18, 2025
061f286
Merge branch 'main' into arb/base_arima
TonyBagnall Jun 21, 2025
149c0ad
find forecast_ in fit
TonyBagnall Jun 21, 2025
745806e
Merge branch 'main' into arb/base_arima
MatthewMiddlehurst Jul 4, 2025
1d300a4
Merge branch 'main' into arb/base_arima
TonyBagnall Jul 10, 2025
9d8b24f
remove optional y
TonyBagnall Jul 10, 2025
d9b1e7a
add iterative
TonyBagnall Jul 10, 2025
2 changes: 2 additions & 0 deletions aeon/forecasting/__init__.py
@@ -5,8 +5,10 @@
"BaseForecaster",
"RegressionForecaster",
"ETSForecaster",
"ARIMAForecaster",
]

from aeon.forecasting._arima import ARIMAForecaster
from aeon.forecasting._ets import ETSForecaster
from aeon.forecasting._naive import NaiveForecaster
from aeon.forecasting._regression import RegressionForecaster
235 changes: 235 additions & 0 deletions aeon/forecasting/_arima.py
@@ -0,0 +1,235 @@
"""ARIMAForecaster.

An implementation of the ARIMA forecasting algorithm.
"""

__maintainer__ = ["alexbanwell1", "TonyBagnall"]
__all__ = ["ARIMAForecaster"]

from math import comb

import numpy as np
from numba import njit

from aeon.forecasting.base import BaseForecaster
from aeon.utils.optimisation._nelder_mead import nelder_mead

NOGIL = False
CACHE = True


class ARIMAForecaster(BaseForecaster):
"""AutoRegressive Integrated Moving Average (ARIMA) forecaster.

Forecasts using an ARIMA(p, d, q) model with user-specified orders. The model
coefficients (constant, autoregressive and moving average terms) are estimated
by minimising the Akaike Information Criterion (AIC) with the Nelder-Mead
optimisation algorithm.

Parameters
----------
p : int, default=1
Order of the autoregressive (AR) part of the model.
d : int, default=0
Order of differencing applied to the series before fitting.
q : int, default=1
Order of the moving average (MA) part of the model.
constant_term : int, default=0
Whether to include a constant/intercept term in the model (0 or 1).
horizon : int, default=1
The forecasting horizon, i.e., the number of steps ahead to predict.

Attributes
----------
data_ : list of float
Original training series values.
differenced_data_ : list of float
Differenced version of the training data used for stationarity.
residuals_ : list of float
Residual errors from the fitted model.
aic_ : float
Akaike Information Criterion of the fitted model.
p_, d_, q_ : int
Orders of the ARIMA model: autoregressive (p), differencing (d),
and moving average (q) terms.
constant_term_ : int
Whether a constant/intercept term is included in the fitted model.
c_ : float
Estimated constant/intercept term.
phi_ : array-like
Coefficients for the non-seasonal autoregressive terms.
theta_ : array-like
Coefficients for the non-seasonal moving average terms.

References
----------
.. [1] R. J. Hyndman and G. Athanasopoulos,
Forecasting: Principles and Practice. OTexts, 2014.
https://otexts.com/fpp3/

Examples
--------
>>> from aeon.forecasting import ARIMAForecaster
>>> from aeon.datasets import load_airline
>>> y = load_airline()
>>> forecaster = ARIMAForecaster(p=2, d=1)
>>> forecaster.fit(y)
ARIMAForecaster(d=1, p=2)
>>> forecaster.predict()
550.9147246631132
"""

def __init__(self, p=1, d=0, q=1, constant_term=0, horizon=1):
super().__init__(horizon=horizon, axis=1)
self.data_ = []
self.differenced_data_ = []
self.residuals_ = []
self.aic_ = 0
self.p = p
self.d = d
self.q = q
self.constant_term = constant_term
self.p_ = 0
self.d_ = 0
self.q_ = 0
self.constant_term_ = 0
self.model_ = []
self.c_ = 0
self.phi_ = 0
self.theta_ = 0
self.parameters_ = []

def _fit(self, y, exog=None):
"""Fit AutoARIMA forecaster to series y.

Fit a forecaster to predict self.horizon steps ahead using y.

Parameters
----------
y : np.ndarray
A time series on which to learn a forecaster to predict self.horizon steps ahead.
exog : np.ndarray, default=None
Optional exogenous time series data assumed to be aligned with y.

Returns
-------
self
Fitted ARIMAForecaster.
"""
self.p_ = self.p
self.d_ = self.d
self.q_ = self.q
self.constant_term_ = self.constant_term
self.data_ = np.array(y.squeeze(), dtype=np.float64)
self.model_ = np.array((self.constant_term, self.p, self.q), dtype=np.int32)
self.differenced_data_ = np.diff(self.data_, n=self.d)
(self.parameters_, self.aic_) = nelder_mead(
_arima_model_wrapper,
np.sum(self.model_[:3]),
self.data_,
self.model_,
)
(self.c_, self.phi_, self.theta_) = _extract_params(
self.parameters_, self.model_
)
(self.aic_, self.residuals_) = _arima_model(
self.parameters_, _calc_arima, self.differenced_data_, self.model_
)
return self

def _predict(self, y=None, exog=None):
"""
Predict the next horizon steps ahead.

Parameters
----------
y : np.ndarray, default=None
A time series to predict the next horizon value for. If None,
predict the next horizon value after the series seen in fit.
exog : np.ndarray, default=None
Optional exogenous time series data assumed to be aligned with y.

Returns
-------
float
single prediction self.horizon steps ahead of y.
"""
y = np.array(y, dtype=np.float64)
value = _calc_arima(
self.differenced_data_,
self.model_,
len(self.differenced_data_),
_extract_params(self.parameters_, self.model_),
self.residuals_,
)
history = self.data_[::-1]
Contributor: this should be y if it's not None

Contributor Author: Good point, currently it's only doing one step ahead, will need to adjust this!

# Undo ordinary differencing to map the forecast back to the original scale
for k in range(1, self.d_ + 1):
value += (-1) ** (k + 1) * comb(self.d_, k) * history[k - 1]
return value.item()


@njit(cache=True, fastmath=True)
def _aic(residuals, num_params):
"""Calculate the log-likelihood of a model."""
variance = np.mean(residuals**2)
liklihood = len(residuals) * (np.log(2 * np.pi) + np.log(variance) + 1)
return liklihood + 2 * num_params


@njit(fastmath=True)
def _arima_model_wrapper(params, data, model):
return _arima_model(params, _calc_arima, data, model)[0]


# Define the ARIMA(p, d, q) likelihood function
@njit(cache=True, fastmath=True)
def _arima_model(params, base_function, data, model):
"""Calculate the log-likelihood of an ARIMA model given the parameters."""
formatted_params = _extract_params(params, model) # Extract parameters

# Initialize residuals
n = len(data)
residuals = np.zeros(n)
for t in range(n):
y_hat = base_function(
data,
model,
t,
formatted_params,
residuals,
)
residuals[t] = data[t] - y_hat
return _aic(residuals, len(params)), residuals


@njit(cache=True, fastmath=True)
def _extract_params(params, model):
"""Extract ARIMA parameters from the parameter vector."""
if len(params) != np.sum(model):
previous_length = np.sum(model)
model = model[:-1] # Remove the seasonal period
if len(params) != np.sum(model):
raise ValueError(
f"Expected {previous_length} parameters for a non-seasonal model or \
{np.sum(model)} parameters for a seasonal model, got {len(params)}"
)
starts = np.cumsum(np.concatenate((np.zeros(1, dtype=np.int32), model[:-1])))
n = len(starts)
max_len = np.max(model)
result = np.full((n, max_len), np.nan, dtype=params.dtype)
for i in range(n):
length = model[i]
start = starts[i]
result[i, :length] = params[start : start + length]
return result


@njit(cache=True, fastmath=True)
def _calc_arima(data, model, t, formatted_params, residuals):
"""Calculate the ARIMA forecast for time t."""
if len(model) != 3:
raise ValueError("Model must be of the form (c, p, q)")
# AR part
p = model[1]
phi = formatted_params[1][:p]
ar_term = 0 if (t - p) < 0 else np.dot(phi, data[t - p : t][::-1])

# MA part
q = model[2]
theta = formatted_params[2][:q]
ma_term = 0 if (t - q) < 0 else np.dot(theta, residuals[t - q : t][::-1])

c = formatted_params[0][0] if model[0] else 0
y_hat = c + ar_term + ma_term
return y_hat
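
For readers less familiar with the model equations, here is a small, self-contained NumPy sketch of what _calc_arima and the differencing-inversion loop in _predict compute. The helper names (arima_one_step, undo_differencing) and the toy series are hypothetical, for illustration only, and are not part of this PR.

import numpy as np
from math import comb


def arima_one_step(diff_data, residuals, c, phi, theta, t):
    """One-step ARIMA forecast of the differenced series at time t.

    y_hat_t = c + sum_i phi_i * y_{t-i} + sum_j theta_j * e_{t-j}
    """
    p, q = len(phi), len(theta)
    ar_term = np.dot(phi, diff_data[t - p:t][::-1]) if t >= p else 0.0
    ma_term = np.dot(theta, residuals[t - q:t][::-1]) if t >= q else 0.0
    return c + ar_term + ma_term


def undo_differencing(forecast, history, d):
    """Map a forecast of the d-times differenced series back to the original scale.

    Mirrors the loop in _predict: value = forecast + sum_{k=1..d} (-1)^(k+1) * C(d, k) * y_{t-k},
    where history holds the most recent observations, latest first.
    """
    value = forecast
    for k in range(1, d + 1):
        value += (-1) ** (k + 1) * comb(d, k) * history[k - 1]
    return value


# With d=1 the inversion reduces to adding the forecast of the differenced
# series to the last observed value.
rng = np.random.default_rng(0)
y = np.cumsum(rng.normal(size=50))            # a toy random-walk series
diff_y = np.diff(y, n=1)
residuals = np.zeros_like(diff_y)             # residuals of a fitted model would go here
phi, theta, c = np.array([0.5]), np.array([0.1]), 0.0
f = arima_one_step(diff_y, residuals, c, phi, theta, len(diff_y))
print(undo_differencing(f, y[::-1], d=1))     # equals y[-1] + f
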
1 change: 1 addition & 0 deletions aeon/utils/forecasting/__init__.py
@@ -0,0 +1 @@
"""Forecasting utils."""
63 changes: 63 additions & 0 deletions aeon/utils/forecasting/_hypo_tests.py
@@ -0,0 +1,63 @@
import numpy as np


def kpss_test(y, regression="c", lags=None): # Test if time series is stationary
"""
Implement the KPSS test for stationarity.

Parameters
----------
y : array-like
Time series data.
regression : str, default="c"
'c' for constant, 'ct' for constant + trend.
lags : int, default=None
Number of lags for HAC variance estimation. If None, int(sqrt(n)) is used.

Returns
-------
kpss_stat : float
KPSS test statistic.
stationary : bool
Whether the series is stationary according to the test.
"""
y = np.asarray(y)
n = len(y)

# Step 1: Fit regression model to estimate residuals
if regression == "c": # Constant
X = np.ones((n, 1))
elif regression == "ct": # Constant + Trend
X = np.column_stack((np.ones(n), np.arange(1, n + 1)))
else:
raise ValueError("regression must be 'c' or 'ct'")

beta = np.linalg.lstsq(X, y, rcond=None)[0] # Estimate regression coefficients
residuals = y - X @ beta # Get residuals (u_t)

# Step 2: Compute cumulative sum of residuals (S_t)
S_t = np.cumsum(residuals)

# Step 3: Estimate long-run variance (HAC variance)
if lags is None:
# lags = int(12 * (n / 100)**(1/4)) # Default statsmodels lag length
lags = int(np.sqrt(n)) # Default lag length

gamma_0 = np.sum(residuals**2) / (n - X.shape[1]) # Lag-0 autocovariance
gamma = [np.sum(residuals[k:] * residuals[:-k]) / n for k in range(1, lags + 1)]

# Bartlett weights
weights = [1 - (k / (lags + 1)) for k in range(1, lags + 1)]

# Long-run variance
sigma_squared = gamma_0 + 2 * np.sum([w * g for w, g in zip(weights, gamma)])

# Step 4: Calculate the KPSS statistic
kpss_stat = np.sum(S_t**2) / (n**2 * sigma_squared)

if regression == "ct":
# p. 162 Kwiatkowski et al. (1992): y_t = beta * t + r_t + e_t,
# where beta is the trend, r_t a random walk and e_t a stationary
# error term.
crit = 0.146
else: # hypo == "c"
# special case of the model above, where beta = 0 (so the null
# hypothesis is that the data is stationary around r_0).
crit = 0.463

return kpss_stat, kpss_stat < crit
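
A brief, hypothetical usage sketch for kpss_test (not part of the diff): a stationary white-noise series should be reported as stationary, while a random walk should not; the exact statistic values depend on the seed.

import numpy as np

from aeon.utils.forecasting._hypo_tests import kpss_test

rng = np.random.default_rng(1)
white_noise = rng.normal(size=200)               # stationary around a constant level
random_walk = np.cumsum(rng.normal(size=200))    # non-stationary

stat_wn, stationary_wn = kpss_test(white_noise, regression="c")
stat_rw, stationary_rw = kpss_test(random_walk, regression="c")
print(stationary_wn)   # expected: True (statistic below the 0.463 critical value)
print(stationary_rw)   # expected: False for most seeds (statistic above 0.463)
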