-from torch import nn
-
-from tsl.nn.base.attention import MultiHeadAttention
-from tsl.nn.layers.norm import LayerNorm
-from tsl.nn import utils
from functools import partial
+from typing import Optional

import torch.nn.functional as F
+from torch import nn, Tensor
+
+from tsl.nn import utils
+from tsl.nn.base.attention import MultiHeadAttention
+from tsl.nn.layers.norm import LayerNorm


class TransformerLayer(nn.Module):
-    r"""
-    A TransformerLayer which can be instantiated to attent the temporal or spatial dimension.
+    r"""A Transformer layer from the paper `"Attention Is All You Need"
+    <https://arxiv.org/abs/1706.03762>`_ (Vaswani et al., NeurIPS 2017).
+
+    This layer can be instantiated to attend the temporal or spatial dimension.

    Args:
        input_size (int): Input size.
        hidden_size (int): Dimension of the learned representations.
        ff_size (int): Units in the MLP after self attention.
        n_heads (int, optional): Number of parallel attention heads.
-        axis (str, optional): Dimension on which to apply attention to update the representations.
-        causal (bool, optional): Whether to causally mask the attention scores (can be `True` only if `axis` is `steps`).
+        axis (str, optional): Dimension on which to apply attention to update
+            the representations. Can be either 'time' or 'nodes'.
+            (default: :obj:`'time'`)
+        causal (bool, optional): If :obj:`True`, then causally mask attention
+            scores in temporal attention (has an effect only if :attr:`axis` is
+            :obj:`'time'`). (default: :obj:`True`)
        activation (str, optional): Activation function.
        dropout (float, optional): Dropout probability.
    """
+
    def __init__(self,
                 input_size,
                 hidden_size,
                 ff_size=None,
                 n_heads=1,
-                 axis='steps',
+                 axis='time',
                 causal=True,
                 activation='elu',
                 dropout=0.):
@@ -60,27 +68,32 @@ def __init__(self,

        self.activation = utils.get_functional_activation(activation)

-    def forward(self, x, mask=None):
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None):
        """"""
        # x: [batch, steps, nodes, features]
-        x = self.skip_conn(x) + self.dropout(self.att(self.norm1(x), attn_mask=mask)[0])
+        x = self.skip_conn(x) + self.dropout(
+            self.att(self.norm1(x), attn_mask=mask)[0])
        x = x + self.mlp(x)
        return x


class SpatioTemporalTransformerLayer(nn.Module):
-    r"""
-    A TransformerLayer which attend both the spatial and temporal dimensions by stacking two `MultiHeadAttention` layers.
+    r"""A :class:`~tsl.nn.blocks.encoders.TransformerLayer` which attends both
+    the spatial and temporal dimensions by stacking two
+    :class:`~tsl.nn.base.MultiHeadAttention` layers.

    Args:
        input_size (int): Input size.
        hidden_size (int): Dimension of the learned representations.
        ff_size (int): Units in the MLP after self attention.
        n_heads (int, optional): Number of parallel attention heads.
-        causal (bool, optional): Whether to causally mask the attention scores (can be `True` only if `axis` is `steps`).
+        causal (bool, optional): If :obj:`True`, then causally mask attention
+            scores in temporal attention.
+            (default: :obj:`True`)
        activation (str, optional): Activation function.
        dropout (float, optional): Dropout probability.
    """
+
    def __init__(self,
                 input_size,
                 hidden_size,
@@ -95,7 +108,7 @@ def __init__(self,
                                               kdim=input_size,
                                               vdim=input_size,
                                               heads=n_heads,
-                                               axis='steps',
+                                               axis='time',
                                               causal=causal)

        self.spatial_att = MultiHeadAttention(embed_dim=hidden_size,
@@ -122,19 +135,18 @@ def __init__(self,

        self.dropout = nn.Dropout(dropout)

-    def forward(self, x, mask=None):
+    def forward(self, x: Tensor, mask: Optional[Tensor] = None):
        """"""
        # x: [batch, steps, nodes, features]
-
-        x = self.skip_conn(x) + self.dropout(self.temporal_att(self.norm1(x), attn_mask=mask)[0])
+        x = self.skip_conn(x) + self.dropout(
+            self.temporal_att(self.norm1(x), attn_mask=mask)[0])
        x = x + self.dropout(self.spatial_att(self.norm2(x), attn_mask=mask)[0])
        x = x + self.mlp(x)
        return x


class Transformer(nn.Module):
-    r"""
-    A stack of Transformer layers.
+    r"""A stack of Transformer layers.

    Args:
        input_size (int): Input size.
@@ -143,19 +155,25 @@ class Transformer(nn.Module):
        output_size (int, optional): Size of an optional linear readout.
        n_layers (int, optional): Number of Transformer layers.
        n_heads (int, optional): Number of parallel attention heads.
-        axis (str, optional): Dimension on which to apply attention to update the representations.
-        causal (bool, optional): Whether to causally mask the attention scores (can be `True` only if `axis` is `steps`).
+        axis (str, optional): Dimension on which to apply attention to update
+            the representations. Can be either 'time', 'nodes', or 'both'.
+            (default: :obj:`'time'`)
+        causal (bool, optional): If :obj:`True`, then causally mask attention
+            scores in temporal attention (has an effect only if :attr:`axis` is
+            :obj:`'time'` or :obj:`'both'`).
+            (default: :obj:`True`)
        activation (str, optional): Activation function.
        dropout (float, optional): Dropout probability.
    """
+
    def __init__(self,
                 input_size,
                 hidden_size,
                 ff_size=None,
                 output_size=None,
                 n_layers=1,
                 n_heads=1,
-                 axis='steps',
+                 axis='time',
                 causal=True,
                 activation='elu',
                 dropout=0.):
@@ -165,7 +183,7 @@ def __init__(self,
        if ff_size is None:
            ff_size = hidden_size

-        if axis in ['steps', 'nodes']:
+        if axis in ['time', 'nodes']:
            transformer_layer = partial(TransformerLayer, axis=axis)
        elif axis == 'both':
            transformer_layer = SpatioTemporalTransformerLayer
@@ -174,13 +192,14 @@ def __init__(self,

        layers = []
        for i in range(n_layers):
-            layers.append(transformer_layer(input_size=input_size if i == 0 else hidden_size,
-                                            hidden_size=hidden_size,
-                                            ff_size=ff_size,
-                                            n_heads=n_heads,
-                                            causal=causal,
-                                            activation=activation,
-                                            dropout=dropout))
+            layers.append(transformer_layer(
+                input_size=input_size if i == 0 else hidden_size,
+                hidden_size=hidden_size,
+                ff_size=ff_size,
+                n_heads=n_heads,
+                causal=causal,
+                activation=activation,
+                dropout=dropout))

        self.net = nn.Sequential(*layers)

@@ -189,7 +208,7 @@ def __init__(self,
        else:
            self.register_parameter('readout', None)

-    def forward(self, x):
+    def forward(self, x: Tensor):
        """"""
        x = self.net(x)
        if self.readout is not None:
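For reference, below is a minimal usage sketch of the API as changed in this diff. It assumes the classes are importable from tsl.nn.blocks.encoders.transformer (the module edited here), follows the [batch, steps, nodes, features] shape convention noted in the forward() comments, and uses illustrative sizes only.

import torch

from tsl.nn.blocks.encoders.transformer import Transformer, TransformerLayer

# Toy input: 8 samples, 12 time steps, 20 nodes, 16 features per node.
x = torch.randn(8, 12, 20, 16)

# Attention over the temporal axis only ('time' replaces the old 'steps').
temporal = TransformerLayer(input_size=16, hidden_size=32, axis='time')
h = temporal(x)  # expected shape: [8, 12, 20, 32]

# A stack attending both axes (axis='both' selects
# SpatioTemporalTransformerLayer) with a linear readout to 4 features.
model = Transformer(input_size=16,
                    hidden_size=32,
                    output_size=4,
                    n_layers=2,
                    axis='both')
y = model(x)  # expected shape: [8, 12, 20, 4]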