
Commit e2ef4f1

Merge branch 'main' into ajb/ets_rework

2 parents 3022d7b + 76abbdd commit e2ef4f1

File tree

5 files changed: +382 −27 lines changed

aeon/forecasting/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -5,9 +5,11 @@
     "BaseForecaster",
     "RegressionForecaster",
     "ETSForecaster",
+    "TVPForecaster",
 ]
 
 from aeon.forecasting._ets import ETSForecaster
 from aeon.forecasting._naive import NaiveForecaster
 from aeon.forecasting._regression import RegressionForecaster
+from aeon.forecasting._tvp import TVPForecaster
 from aeon.forecasting.base import BaseForecaster
```

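With the export above, the new forecaster is importable from the package root. A minimal usage sketch, mirroring the call pattern used in the new test file below; the series here is illustrative, any univariate 1D numpy array works:

```python
import numpy as np

from aeon.forecasting import TVPForecaster

# Illustrative noisy sine series.
rng = np.random.default_rng(0)
y = np.sin(np.linspace(0, 10, 200)) + 0.1 * rng.normal(size=200)

forecaster = TVPForecaster(window=5, horizon=1, var=0.01, beta_var=0.01)
p = forecaster.forecast(y)  # fit on y and return the one-step-ahead forecast
print(p)
```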
aeon/forecasting/_tvp.py

Lines changed: 115 additions & 0 deletions
```python
"""Time-Varying Parameter (TVP) Forecaster using Kalman filter."""

import numpy as np

from aeon.forecasting.base import BaseForecaster


class TVPForecaster(BaseForecaster):
    r"""Time-Varying Parameter (TVP) Forecaster using Kalman filter, as in [1]_.

    This forecaster models the target series using a time-varying linear
    autoregression:

    .. math::

        \hat{y}_t = \beta_{0,t} + \beta_{1,t} y_{t-1} + \dots + \beta_{k,t} y_{t-k}

    where the coefficients :math:`\beta_t` evolve based on the observations
    :math:`y_t`. At each step, a weight vector is calculated based on the latest
    residual. This is used to adjust the :math:`\beta` parameter values and the
    estimate of the parameter variance.

    TVP is related to stochastic gradient descent (SGD) regression, with the
    update weight being the dynamically calculated Kalman gain, based on the
    covariance of the parameters, rather than a fixed learning rate.

    Parameters
    ----------
    window : int
        Number of autoregressive lags to use, called ``window`` for consistency
        with ``RegressionForecaster``.
    var : float, default=0.01
        Observation noise variance. ``var`` controls the influence of recency in
        the update. A small ``var`` (such as the default 0.01) means the
        parameters will be more affected by recent values. A large ``var``
        (e.g., 1.0 or more) means the observations are treated as noisy, so the
        filter adjusts the parameters less to match recent values.
    beta_var : float, default=0.01
        State evolution noise variance, applied to all coefficients at each
        step. A small ``beta_var`` leads to slowly evolving parameters.

    References
    ----------
    .. [1] Durbin & Koopman, Time Series Analysis by State Space Methods,
       Oxford University Press, 2nd Edition, 2012.
    """

    def __init__(self, window, horizon=1, var=0.01, beta_var=0.01):
        self.window = window
        self.var = var
        self.beta_var = beta_var
        super().__init__(axis=1, horizon=horizon)

    def _fit(self, y, exog=None):
        y = y.squeeze()

        # Create autoregressive design matrix
        X = np.lib.stride_tricks.sliding_window_view(y, window_shape=self.window)
        X = X[: -self.horizon]
        ones = np.ones((X.shape[0], 1))
        X = np.hstack([ones, X])  # Add intercept column

        y_train = y[self.window + self.horizon - 1 :]

        # Kalman filter initialisation
        k = X.shape[1]  # number of coefficients (lags + intercept)
        beta = np.zeros(k)
        beta_covariance = np.eye(k)
        beta_var = self.beta_var * np.eye(k)

        for t in range(len(y_train)):
            x_t = X[t]
            y_t = y_train[t]

            # Predict covariance
            beta_covariance = beta_covariance + beta_var

            # Forecast error and Kalman gain
            error_t = y_t - x_t @ beta
            total_variance = x_t @ beta_covariance @ x_t + self.var
            kalman_weight = beta_covariance @ x_t / total_variance

            # Update beta parameters with Kalman weights times error.
            beta = beta + kalman_weight * error_t
            beta_covariance = (
                beta_covariance - np.outer(kalman_weight, x_t) @ beta_covariance
            )

        self._beta = beta
        self._last_window = y[-self.window :]
        self.forecast_ = (
            np.insert(self._last_window, 0, 1.0) @ self._beta
        )  # include intercept
        return self

    def _predict(self, y, exog=None):
        y = y.squeeze()
        x_t = np.insert(y[-self.window :], 0, 1.0)  # include intercept term
        y_hat = x_t @ self._beta
        return y_hat

    @classmethod
    def _get_test_params(cls, parameter_set: str = "default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default='default'
            Name of the parameter set to return.

        Returns
        -------
        dict
            Dictionary of testing parameter settings.
        """
        return {"window": 4}
```
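For reference, the loop in `_fit` is one pass of the standard Kalman filter for a random-walk coefficient model. A sketch of the recursion it implements, writing `P` for `beta_covariance`, `Q` for the `beta_var` matrix and `σ²` for `var`:

```latex
\begin{aligned}
P_{t\mid t-1} &= P_{t-1} + Q                      && \text{predict coefficient covariance} \\
e_t &= y_t - x_t^\top \beta_{t-1}                 && \text{one-step forecast error} \\
S_t &= x_t^\top P_{t\mid t-1}\, x_t + \sigma^2    && \text{total forecast variance} \\
K_t &= P_{t\mid t-1}\, x_t / S_t                  && \text{Kalman gain} \\
\beta_t &= \beta_{t-1} + K_t\, e_t                && \text{coefficient update} \\
P_t &= \left(I - K_t x_t^\top\right) P_{t\mid t-1} && \text{covariance update}
\end{aligned}
```

The last line matches `beta_covariance - np.outer(kalman_weight, x_t) @ beta_covariance` in the code, since `K_t x_tᵀ P` equals `outer(K_t, x_t) @ P`.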

aeon/forecasting/tests/test_tvp.py

Lines changed: 90 additions & 0 deletions
```python
"""Test TVP forecaster.

Tests include convergence properties described in Durbin & Koopman, 2012.
"""

import numpy as np

from aeon.forecasting._tvp import TVPForecaster


def test_direct():
    """Test that forecast matches the first step of direct_forecast."""
    expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
    tvp = TVPForecaster(window=5, horizon=1, var=0.01, beta_var=0.01)
    p = tvp.forecast(expected)
    p2 = tvp.direct_forecast(expected, prediction_horizon=5)
    assert p == p2[0]


def test_static_ar1_convergence_to_ols():
    """Test TVPForecaster converges to the OLS solution for a static AR(1) process."""
    # Simulate AR(1) data with constant parameters
    rng = np.random.RandomState(0)
    true_phi = 0.6
    true_intercept = 2.0
    noise_std = 0.5
    n = 500
    y = np.zeros(n)
    # Initialise y[0] near the steady-state mean to avoid startup bias
    y[0] = true_intercept / (1 - true_phi)
    for t in range(1, n):
        y[t] = true_intercept + true_phi * y[t - 1] + rng.normal(0, noise_std)
    # Fit with beta_var=0 (no parameter drift) and observation variance = noise_var
    forecaster = TVPForecaster(window=1, horizon=1, var=noise_std**2, beta_var=0.0)
    forecaster.fit(y)
    beta_est = forecaster._beta  # [intercept, phi] estimated
    # Compute static OLS estimates for comparison
    X = np.vstack(
        [np.ones(n - 1), y[: n - 1]]
    ).T  # regress y[t] on [1, y[t-1]] for t=1..n-1
    y_resp = y[1:]
    beta_ols, *_ = np.linalg.lstsq(X, y_resp, rcond=None)
    # The TVP forecaster (with no drift) should converge to OLS estimates
    assert beta_est.shape == (2,)
    # Check that estimated parameters are close to OLS solution
    assert np.allclose(beta_est, beta_ols, atol=0.1)
    # Also check they are close to true parameters
    assert abs(beta_est[0] - true_intercept) < 0.2
    assert abs(beta_est[1] - true_phi) < 0.1


def test_tvp_adapts_to_changing_coefficient():
    """Test TVP adapts its parameters when the true AR(1) coefficient changes."""
    rng = np.random.RandomState(42)
    # Piecewise AR(1): phi changes from 0.2 to 0.8 at t=100, intercept remains 1.0
    n = 200
    phi1, phi2 = 0.2, 0.8
    intercept = 1.0
    noise_std = 0.05
    y = np.zeros(n)
    # Start near the mean of the first regime
    y[0] = intercept / (1 - phi1)
    # First half (t=1 to 99) with phi1
    for t in range(1, 100):
        y[t] = intercept + phi1 * y[t - 1] + rng.normal(0, noise_std)
    # Second half (t=100 to 199) with phi2
    for t in range(100, n):
        y[t] = intercept + phi2 * y[t - 1] + rng.normal(0, noise_std)
    # Fit TVPForecaster with nonzero beta_var to allow parameter drift
    forecaster = TVPForecaster(window=1, horizon=1, var=noise_std**2, beta_var=0.1)
    forecaster.fit(y)
    beta_final = forecaster._beta
    # Compute OLS on first and second half segments for reference
    X1 = np.vstack([np.ones(99), y[:99]]).T
    y1 = y[1:100]
    beta1_ols, *_ = np.linalg.lstsq(X1, y1, rcond=None)
    # use points 100..198 to predict 101..199
    X2 = np.vstack([np.ones(n - 101), y[100 : n - 1]]).T
    y2 = y[101:n]
    beta2_ols, *_ = np.linalg.lstsq(X2, y2, rcond=None)
    # The final estimated phi should be much closer to phi2 than phi1
    estimated_intercept, estimated_phi = beta_final[0], beta_final[1]
    # Validate that the phi coefficient increased towards phi2
    assert estimated_phi > 0.5  # moved well above the initial ~0.2
    assert abs(estimated_phi - phi2) < 0.1  # close to the new true phi
    # Validate intercept remains reasonable (around true intercept)
    assert abs(estimated_intercept - intercept) < 0.5
    # Check that final phi is closer to second-half OLS estimate than first-half
    assert abs(estimated_phi - beta2_ols[1]) < abs(estimated_phi - beta1_ols[1])
```
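The first convergence test leans on a standard identity: with `beta_var=0` the coefficients are modelled as constant, and the Kalman recursion reduces to recursive least squares. A sketch of why the estimate approaches OLS, under the priors used in `_fit` (`β₀ = 0`, `P₀ = I`):

```latex
% With Q = 0, the Kalman posterior mean after n observations is the
% regularised least-squares estimate
\beta_n = \left(P_0^{-1} + \sigma^{-2} X^\top X\right)^{-1} \sigma^{-2} X^\top y .
% As n grows, \sigma^{-2} X^\top X dominates the fixed prior term, so
\beta_n \longrightarrow \left(X^\top X\right)^{-1} X^\top y = \hat{\beta}_{\mathrm{OLS}} ,
% which is why allclose(beta_est, beta_ols, atol=0.1) holds at n = 500.
```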

docs/api_reference/forecasting.md

Lines changed: 1 addition & 0 deletions
````diff
@@ -11,4 +11,5 @@
     NaiveForecaster
     RegressionForecaster
     ETSForecaster
+    TVPForecaster
 ```
````
