
Commit c81bbb0

[ENH] Added RecurrentRegressor for time series regression (#2894)
* RNN net architecture added
* Fixed activation function for both str and list
* RNN test message corrected
* RNN regressor added
* init updated w rnn
* reverted workflow
* reverted workflow
* requested changes applied
* updated with latest recurrent network
* merged with main
1 parent f5e1a28 commit c81bbb0

File tree

2 files changed: +330 -0 lines changed


aeon/regression/deep_learning/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
     "EncoderRegressor",
     "MLPRegressor",
     "DisjointCNNRegressor",
+    "RecurrentRegressor",
 ]

 from aeon.regression.deep_learning._cnn import TimeCNNRegressor
@@ -28,4 +29,5 @@
 )
 from aeon.regression.deep_learning._mlp import MLPRegressor
 from aeon.regression.deep_learning._resnet import ResNetRegressor
+from aeon.regression.deep_learning._rnn import RecurrentRegressor
 from aeon.regression.deep_learning.base import BaseDeepRegressor
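
With this export in place, the estimator resolves from the subpackage namespace, e.g. (a minimal sketch; the optional tensorflow dependency is still required to actually build and fit the model):

from aeon.regression.deep_learning import RecurrentRegressor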

aeon/regression/deep_learning/_rnn.py

Lines changed: 328 additions & 0 deletions
@@ -0,0 +1,328 @@
"""Recurrent Neural Network (RNN) for regression."""

from __future__ import annotations

__maintainer__ = [""]
__all__ = ["RecurrentRegressor"]

import gc
import os
import time
from copy import deepcopy
from typing import TYPE_CHECKING, Any

import numpy as np
from sklearn.utils import check_random_state

from aeon.networks import RecurrentNetwork
from aeon.regression.deep_learning.base import BaseDeepRegressor

if TYPE_CHECKING:
    import tensorflow as tf
    from tensorflow.keras.callbacks import Callback


class RecurrentRegressor(BaseDeepRegressor):
    """
    Recurrent Neural Network (RNN) regressor.

    Adapted from the implementation used in sktime-dl for time series regression.

    Parameters
    ----------
    rnn_type : str, default = "lstm"
        Type of RNN layer to use. Options: "lstm", "gru", "simple_rnn"
    n_layers : int, default = 1
        Number of RNN layers
    n_units : int, default = 64
        Number of units in each RNN layer
    dropout_intermediate : float, default = 0.2
        Dropout rate applied after each intermediate RNN layer
    dropout_output : float, default = 0.2
        Dropout rate applied after the final RNN layer
    bidirectional : bool, default = False
        Whether to use bidirectional RNN layers
    activation : str, default = "tanh"
        Activation function for RNN layers
    return_sequence_last : bool or None, default = None
        Whether RNN layers should return sequences. If None, automatically
        determined
    n_epochs : int, default = 100
        Number of epochs to train the model
    batch_size : int, default = 32
        Number of samples per gradient update
    use_mini_batch_size : bool, default = False
        Whether to use the mini batch size formula
    callbacks : keras callback or list of callbacks, default = None
        If None, the default list of callbacks is set to ModelCheckpoint and
        ReduceLROnPlateau
    random_state : int, RandomState instance or None, default = None
        If `int`, random_state is the seed used by the random number generator;
        If `RandomState` instance, random_state is the random number generator;
        If `None`, the random number generator is the `RandomState` instance
        used by `np.random`.
    file_path : str, default = "./"
        File path for the models saved by the ModelCheckpoint callback
    save_best_model : bool, default = False
        Whether or not to save the best model
    save_last_model : bool, default = False
        Whether or not to save the last model
    save_init_model : bool, default = False
        Whether to save the initialization of the model
    best_file_name : str, default = "best_model"
        The name of the file of the best model
    last_file_name : str, default = "last_model"
        The name of the file of the last model
    init_file_name : str, default = "init_model"
        The name of the file of the init model
    verbose : bool, default = False
        Whether to output extra information
    loss : str, default = "mean_squared_error"
        The name of the keras training loss
    optimizer : keras.optimizer, default = None
        The keras optimizer used for training. If None, uses Adam with
        learning rate 0.001
    metrics : str or list[str], default = "mean_squared_error"
        The evaluation metrics to use during training
    output_activation : str, default = "linear"
        The output activation for the regressor

    Examples
    --------
    >>> from aeon.regression.deep_learning import RecurrentRegressor
    >>> from aeon.testing.data_generation import make_example_3d_numpy
    >>> X, y = make_example_3d_numpy(n_cases=10, n_channels=1, n_timepoints=12,
    ...                              return_y=True, regression_target=True,
    ...                              random_state=0)
    >>> rgs = RecurrentRegressor(n_epochs=20, batch_size=4)  # doctest: +SKIP
    >>> rgs.fit(X, y)  # doctest: +SKIP
    RecurrentRegressor(...)
    """

    def __init__(
        self,
        rnn_type: str = "lstm",
        n_layers: int = 1,
        n_units: int = 64,
        dropout_intermediate: float = 0.2,
        dropout_output: float = 0.2,
        bidirectional: bool = False,
        activation: str = "tanh",
        return_sequence_last: bool | None = None,
        n_epochs: int = 100,
        callbacks: Callback | list[Callback] | None = None,
        verbose: bool = False,
        loss: str = "mean_squared_error",
        output_activation: str = "linear",
        metrics: str | list[str] = "mean_squared_error",
        batch_size: int = 32,
        use_mini_batch_size: bool = False,
        random_state: int | np.random.RandomState | None = None,
        file_path: str = "./",
        save_best_model: bool = False,
        save_last_model: bool = False,
        save_init_model: bool = False,
        best_file_name: str = "best_model",
        last_file_name: str = "last_model",
        init_file_name: str = "init_model",
        optimizer: tf.keras.optimizers.Optimizer | None = None,
    ):
        self.rnn_type = rnn_type
        self.n_layers = n_layers
        self.n_units = n_units
        self.dropout_intermediate = dropout_intermediate
        self.dropout_output = dropout_output
        self.bidirectional = bidirectional
        self.activation = activation
        self.return_sequence_last = return_sequence_last
        self.n_epochs = n_epochs
        self.callbacks = callbacks
        self.verbose = verbose
        self.loss = loss
        self.metrics = metrics
        self.use_mini_batch_size = use_mini_batch_size
        self.random_state = random_state
        self.output_activation = output_activation
        self.file_path = file_path
        self.save_best_model = save_best_model
        self.save_last_model = save_last_model
        self.save_init_model = save_init_model
        self.best_file_name = best_file_name
        self.init_file_name = init_file_name
        self.optimizer = optimizer
        self.history = None

        super().__init__(batch_size=batch_size, last_file_name=last_file_name)

        self._network = RecurrentNetwork(
            rnn_type=self.rnn_type,
            n_layers=self.n_layers,
            n_units=self.n_units,
            dropout_intermediate=self.dropout_intermediate,
            dropout_output=self.dropout_output,
            bidirectional=self.bidirectional,
            activation=self.activation,
            return_sequence_last=self.return_sequence_last,
        )

    def build_model(
        self, input_shape: tuple[int, ...], **kwargs: Any
    ) -> tf.keras.Model:
        """
        Construct a compiled, un-trained, keras model that is ready for training.

        In aeon, time series are stored in numpy arrays of shape (d, m), where d
        is the number of dimensions and m is the series length. Keras/tensorflow
        assumes data is in shape (m, d). This method also assumes (m, d); the
        transpose happens in fit.

        Parameters
        ----------
        input_shape : tuple
            The shape of the data fed into the input layer, should be (m, d)

        Returns
        -------
        output : a compiled Keras Model
        """
        import tensorflow as tf

        self.optimizer_ = (
            tf.keras.optimizers.Adam(learning_rate=0.001)
            if self.optimizer is None
            else self.optimizer
        )

        rng = check_random_state(self.random_state)
        self.random_state_ = rng.randint(0, np.iinfo(np.int32).max)
        tf.keras.utils.set_random_seed(self.random_state_)

        input_layer, output_layer = self._network.build_network(input_shape, **kwargs)

        output_layer = tf.keras.layers.Dense(
            units=1,
            activation=self.output_activation,
        )(output_layer)

        model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)

        model.compile(
            loss=self.loss,
            optimizer=self.optimizer_,
            metrics=self._metrics,
        )

        return model

    def _fit(self, X: np.ndarray, y: np.ndarray) -> RecurrentRegressor:
        """
        Fit the regressor on the training set (X, y).

        Parameters
        ----------
        X : np.ndarray
            The training input samples of shape (n_cases, n_channels, n_timepoints).
        y : np.ndarray
            The training data target values of shape (n_cases,).

        Returns
        -------
        self : object
        """
        import tensorflow as tf

        # Transpose to conform to Keras input style.
        X = X.transpose(0, 2, 1)

        if isinstance(self.metrics, list):
            self._metrics = self.metrics
        elif isinstance(self.metrics, str):
            self._metrics = [self.metrics]

        self.input_shape = X.shape[1:]
        self.training_model_ = self.build_model(self.input_shape)

        if self.save_init_model:
            self.training_model_.save(self.file_path + self.init_file_name + ".keras")

        if self.verbose:
            self.training_model_.summary()

        self.file_name_ = (
            self.best_file_name if self.save_best_model else str(time.time_ns())
        )

        if self.callbacks is None:
            self.callbacks_ = [
                tf.keras.callbacks.ReduceLROnPlateau(
                    monitor="loss", factor=0.5, patience=50, min_lr=0.0001
                ),
                tf.keras.callbacks.ModelCheckpoint(
                    filepath=self.file_path + self.file_name_ + ".keras",
                    monitor="loss",
                    save_best_only=True,
                ),
            ]
        else:
            self.callbacks_ = self._get_model_checkpoint_callback(
                callbacks=self.callbacks,
                file_path=self.file_path,
                file_name=self.file_name_,
            )

        # Optionally cap the batch size at one tenth of the number of cases.
        if self.use_mini_batch_size:
            mini_batch_size = min(self.batch_size, X.shape[0] // 10)
        else:
            mini_batch_size = self.batch_size

        self.history = self.training_model_.fit(
            X,
            y,
            batch_size=mini_batch_size,
            epochs=self.n_epochs,
            verbose=self.verbose,
            callbacks=self.callbacks_,
        )

        # Reload the best checkpointed model; if no checkpoint was written,
        # fall back to the final training model.
        try:
            self.model_ = tf.keras.models.load_model(
                self.file_path + self.file_name_ + ".keras", compile=False
            )
            if not self.save_best_model:
                os.remove(self.file_path + self.file_name_ + ".keras")
        except FileNotFoundError:
            self.model_ = deepcopy(self.training_model_)

        if self.save_last_model:
            self.save_last_model_to_file(file_path=self.file_path)

        gc.collect()
        return self

    @classmethod
    def _get_test_params(
        cls, parameter_set: str = "default"
    ) -> dict[str, Any] | list[dict[str, Any]]:
        """
        Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return the `"default"`
            set. For regressors, a "default" set of parameters should be provided
            for general testing, and a "results_comparison" set for comparing
            against previously recorded results if the general set does not
            produce suitable probabilities to compare against.

        Returns
        -------
        params : dict or list of dict, default={}
            Parameters to create testing instances of the class.
            Each dict contains parameters to construct an "interesting" test
            instance, i.e., `MyClass(**params)` or `MyClass(**params[i])` creates
            a valid test instance.
        """
        param = {
            "n_epochs": 10,
            "batch_size": 4,
            "n_layers": 1,
            "n_units": 6,
            "rnn_type": "lstm",
        }
        return [param]
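
For context, a minimal usage sketch of the new estimator, adapted from the class docstring; the `rnn_type="gru"` choice and the small layer/epoch sizes here are illustrative, and TensorFlow must be installed to fit:

from aeon.regression.deep_learning import RecurrentRegressor
from aeon.testing.data_generation import make_example_3d_numpy

# Synthetic collection: 10 series, 1 channel, 12 time points, float targets.
X, y = make_example_3d_numpy(
    n_cases=10,
    n_channels=1,
    n_timepoints=12,
    return_y=True,
    regression_target=True,
    random_state=0,
)

# rnn_type accepts "lstm", "gru", or "simple_rnn".
rgs = RecurrentRegressor(rnn_type="gru", n_units=16, n_epochs=20, batch_size=4)
rgs.fit(X, y)
preds = rgs.predict(X)  # array of shape (10,)

Note that by default `_fit` checkpoints the best model by training loss and reloads it after training, so `predict` uses the best epoch's weights rather than the last epoch's.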
