
Commit 235e9d0

[ENH] Added RNN in networks (#2875)
* RNN net architecture added
* Fixed activation function for both str and list
* RNN test message corrected
* reverted workflow
* reverted workflow
* requested changes applied
* convo resolved
1 parent ace4b9c commit 235e9d0

File tree

3 files changed: +795 -0 lines changed


aeon/networks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
     "AEDRNNNetwork",
     "AEBiGRUNetwork",
     "DisjointCNNNetwork",
+    "RecurrentNetwork",
 ]
 from aeon.networks._ae_abgru import AEAttentionBiGRUNetwork
 from aeon.networks._ae_bgru import AEBiGRUNetwork
@@ -34,4 +35,5 @@
 from aeon.networks._lite import LITENetwork
 from aeon.networks._mlp import MLPNetwork
 from aeon.networks._resnet import ResNetNetwork
+from aeon.networks._rnn import RecurrentNetwork
 from aeon.networks.base import BaseDeepLearningNetwork
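
With the export above, the new class can be imported straight from the package namespace. A minimal sketch (assuming this commit is installed and TensorFlow is available, as declared in the class's python_dependencies):

from aeon.networks import RecurrentNetwork

# Instantiation only stores the configuration; no TensorFlow graph is created
# until build_network is called.
net = RecurrentNetwork(rnn_type="gru", n_layers=2, n_units=[64, 32])
print(type(net).__name__)  # RecurrentNetwork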

aeon/networks/_rnn.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
"""Implements a Recurrent Neural Network (RNN) for time series forecasting."""

__maintainer__ = []

from aeon.networks.base import BaseDeepLearningNetwork


class RecurrentNetwork(BaseDeepLearningNetwork):
    """
    Implements a Recurrent Neural Network (RNN) for time series forecasting.

    This implementation provides a flexible RNN architecture that can be configured
    to use different types of recurrent cells including Simple RNN, Long Short-Term
    Memory (LSTM) [1], and Gated Recurrent Unit (GRU) [2]. The network supports
    multiple layers, bidirectional processing, and various dropout configurations
    for regularization.

    Parameters
    ----------
    rnn_type : str, default='simple'
        Type of RNN cell to use ('lstm', 'gru', or 'simple').
    n_layers : int, default=1
        Number of recurrent layers.
    n_units : list or int, default=64
        Number of units in each recurrent layer. If an int, the same number
        of units is used in each layer. If a list, specifies the number of
        units for each layer and must match the number of layers.
    dropout_intermediate : float, default=0.0
        Dropout rate applied after each intermediate recurrent layer (not the last).
    dropout_output : float, default=0.0
        Dropout rate applied after the last recurrent layer.
    bidirectional : bool, default=False
        Whether to use bidirectional recurrent layers.
    activation : str or list of str, default='tanh'
        Activation function(s) for the recurrent layers. If a string, the same
        activation is used for all layers. If a list, specifies the activation
        for each layer and must match the number of layers.
    return_sequence_last : bool, default=False
        Whether the last recurrent layer returns the full sequence (True)
        or just the last output (False).

    References
    ----------
    .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
       Neural computation, 9(8), 1735-1780.
    .. [2] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F.,
       Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using
       RNN encoder-decoder for statistical machine translation.
       arXiv preprint arXiv:1406.1078.
    """

    _config = {
        "python_dependencies": ["tensorflow"],
        "python_version": "<3.13",
        "structure": "encoder",
    }

    def __init__(
        self,
        rnn_type="simple",
        n_layers=1,
        n_units=64,
        dropout_intermediate=0.0,
        dropout_output=0.0,
        bidirectional=False,
        activation="tanh",
        return_sequence_last=False,
    ):
        super().__init__()
        self.rnn_type = rnn_type.lower()
        self.n_layers = n_layers
        self.n_units = n_units
        self.dropout_intermediate = dropout_intermediate
        self.dropout_output = dropout_output
        self.bidirectional = bidirectional
        self.activation = activation
        self.return_sequence_last = return_sequence_last
        self._rnn_cell = None

    def build_network(self, input_shape, **kwargs):
        """Construct a network and return its input and output layers.

        Parameters
        ----------
        input_shape : tuple
            The shape of the data fed into the input layer (n_timepoints, n_features).
        kwargs : dict
            Additional keyword arguments to be passed to the network.

        Returns
        -------
        input_layer : a keras layer
        output_layer : a keras layer
        """
        import tensorflow as tf

        # Validate parameters
        if self.rnn_type not in ["lstm", "gru", "simple"]:
            raise ValueError(
                f"Unknown RNN type: {self.rnn_type}. "
                "Should be 'lstm', 'gru' or 'simple'."
            )

        # Process n_units to a list
        if isinstance(self.n_units, list):
            if len(self.n_units) != self.n_layers:
                raise ValueError(
                    f"Number of units {len(self.n_units)} should match"
                    f" the number of layers {self.n_layers}."
                )
            self._n_units = self.n_units
        else:
            self._n_units = [self.n_units] * self.n_layers

        # Process activation to a list
        if isinstance(self.activation, list):
            if len(self.activation) != self.n_layers:
                raise ValueError(
                    f"Number of activations {len(self.activation)} should match"
                    f" the number of layers {self.n_layers}."
                )
            self._activation = self.activation
        else:
            self._activation = [self.activation] * self.n_layers

        # Select RNN cell type
        if self.rnn_type == "lstm":
            self._rnn_cell = tf.keras.layers.LSTM
        elif self.rnn_type == "gru":
            self._rnn_cell = tf.keras.layers.GRU
        else:  # simple
            self._rnn_cell = tf.keras.layers.SimpleRNN

        # Create input layer
        input_layer = tf.keras.layers.Input(shape=input_shape)
        x = input_layer

        # Build RNN layers
        for i in range(self.n_layers):
            # Determine return_sequences for the current layer:
            # all layers except the last must return sequences for stacking;
            # the last layer uses the return_sequence_last parameter.
            is_last_layer = i == (self.n_layers - 1)
            return_sequences = (not is_last_layer) or self.return_sequence_last

            # Create the recurrent layer
            if self.bidirectional:
                x = tf.keras.layers.Bidirectional(
                    self._rnn_cell(
                        units=self._n_units[i],
                        activation=self._activation[i],
                        return_sequences=return_sequences,
                        name=f"{self.rnn_type}_{i+1}",
                    )
                )(x)
            else:
                x = self._rnn_cell(
                    units=self._n_units[i],
                    activation=self._activation[i],
                    return_sequences=return_sequences,
                    name=f"{self.rnn_type}_{i+1}",
                )(x)

            # Add the appropriate dropout based on layer position
            if is_last_layer:
                # Apply output dropout after the last layer
                if self.dropout_output > 0:
                    x = tf.keras.layers.Dropout(
                        self.dropout_output, name="dropout_output"
                    )(x)
            else:
                # Apply intermediate dropout after all layers except the last
                if self.dropout_intermediate > 0:
                    x = tf.keras.layers.Dropout(
                        self.dropout_intermediate, name=f"dropout_intermediate_{i+1}"
                    )(x)

        # Return input and output layers
        return input_layer, x
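
For context, a short usage sketch of the class added here. It is illustrative only: the input shape, layer sizes and the tf.keras.Model wrapper are not part of this commit, and it assumes TensorFlow is installed.

import tensorflow as tf

from aeon.networks import RecurrentNetwork

# Two-layer bidirectional GRU encoder: the first layer returns sequences so it
# can be stacked, the last returns only its final output (return_sequence_last
# defaults to False).
network = RecurrentNetwork(
    rnn_type="gru",
    n_layers=2,
    n_units=[64, 32],
    dropout_intermediate=0.2,
    bidirectional=True,
)

# build_network returns the Keras input and output layers for data shaped
# (n_timepoints, n_features); here 100 timepoints with 2 features.
input_layer, output_layer = network.build_network(input_shape=(100, 2))

# Wrap into a model for inspection; downstream estimators would add their own head.
model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
model.summary()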
