15
15
import torch
16
16
import torch .nn as nn
17
17
from torch import Tensor
18
+ from typing import Tuple
18
19
19
- from conformer .decoder import DecoderRNNT
20
20
from conformer .encoder import ConformerEncoder
21
21
from conformer .modules import Linear
22
22
@@ -31,17 +31,13 @@ class Conformer(nn.Module):
31
31
num_classes (int): Number of classification classes
32
32
input_dim (int, optional): Dimension of input vector
33
33
encoder_dim (int, optional): Dimension of conformer encoder
34
- decoder_dim (int, optional): Dimension of conformer decoder
35
34
num_encoder_layers (int, optional): Number of conformer blocks
36
- num_decoder_layers (int, optional): Number of decoder layers
37
- decoder_rnn_type (str, optional): type of RNN cell
38
35
num_attention_heads (int, optional): Number of attention heads
39
36
feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
40
37
conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
41
38
feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
42
39
attention_dropout_p (float, optional): Probability of attention module dropout
43
40
conv_dropout_p (float, optional): Probability of conformer convolution module dropout
44
- decoder_dropout_p (float, optional): Probability of conformer decoder dropout
45
41
conv_kernel_size (int or tuple, optional): Size of the convolving kernel
46
42
half_step_residual (bool): Flag indication whether to use half step residual or not
47
43
@@ -58,20 +54,16 @@ def __init__(
58
54
num_classes : int ,
59
55
input_dim : int = 80 ,
60
56
encoder_dim : int = 512 ,
61
- decoder_dim : int = 640 ,
62
57
num_encoder_layers : int = 17 ,
63
- num_decoder_layers : int = 1 ,
64
58
num_attention_heads : int = 8 ,
65
59
feed_forward_expansion_factor : int = 4 ,
66
60
conv_expansion_factor : int = 2 ,
67
61
input_dropout_p : float = 0.1 ,
68
62
feed_forward_dropout_p : float = 0.1 ,
69
63
attention_dropout_p : float = 0.1 ,
70
64
conv_dropout_p : float = 0.1 ,
71
- decoder_dropout_p : float = 0.1 ,
72
65
conv_kernel_size : int = 31 ,
73
66
half_step_residual : bool = True ,
74
- decoder_rnn_type : str = "lstm" ,
75
67
) -> None :
76
68
super (Conformer , self ).__init__ ()
77
69
self .encoder = ConformerEncoder (
@@ -88,137 +80,27 @@ def __init__(
88
80
conv_kernel_size = conv_kernel_size ,
89
81
half_step_residual = half_step_residual ,
90
82
)
91
- self .decoder = DecoderRNNT (
92
- num_classes = num_classes ,
93
- hidden_state_dim = decoder_dim ,
94
- output_dim = encoder_dim ,
95
- num_layers = num_decoder_layers ,
96
- rnn_type = decoder_rnn_type ,
97
- dropout_p = decoder_dropout_p ,
98
- )
99
83
self .fc = Linear (encoder_dim << 1 , num_classes , bias = False )
100
84
101
- def set_encoder (self , encoder ):
102
- """ Setter for encoder """
103
- self .encoder = encoder
104
-
105
- def set_decoder (self , decoder ):
106
- """ Setter for decoder """
107
- self .decoder = decoder
108
-
109
85
def count_parameters (self ) -> int :
110
86
""" Count parameters of encoder """
111
- num_encoder_parameters = self .encoder .count_parameters ()
112
- num_decoder_parameters = self .decoder .count_parameters ()
113
- return num_encoder_parameters + num_decoder_parameters
87
+ return self .encoder .count_parameters ()
114
88
115
89
def update_dropout (self , dropout_p ) -> None :
116
90
""" Update dropout probability of model """
117
91
self .encoder .update_dropout (dropout_p )
118
- self .decoder .update_dropout (dropout_p )
119
-
120
- def joint (self , encoder_outputs : Tensor , decoder_outputs : Tensor ) -> Tensor :
121
- """
122
- Joint `encoder_outputs` and `decoder_outputs`.
123
-
124
- Args:
125
- encoder_outputs (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
126
- ``(batch, seq_length, dimension)``
127
- decoder_outputs (torch.FloatTensor): A output sequence of decoder. `FloatTensor` of size
128
- ``(batch, seq_length, dimension)``
129
-
130
- Returns:
131
- * outputs (torch.FloatTensor): outputs of joint `encoder_outputs` and `decoder_outputs`..
132
- """
133
- if encoder_outputs .dim () == 3 and decoder_outputs .dim () == 3 :
134
- input_length = encoder_outputs .size (1 )
135
- target_length = decoder_outputs .size (1 )
136
92
137
- encoder_outputs = encoder_outputs .unsqueeze (2 )
138
- decoder_outputs = decoder_outputs .unsqueeze (1 )
139
-
140
- encoder_outputs = encoder_outputs .repeat ([1 , 1 , target_length , 1 ])
141
- decoder_outputs = decoder_outputs .repeat ([1 , input_length , 1 , 1 ])
142
-
143
- outputs = torch .cat ((encoder_outputs , decoder_outputs ), dim = - 1 )
144
- outputs = self .fc (outputs )
145
-
146
- return outputs
147
-
148
- def forward (
149
- self ,
150
- inputs : Tensor ,
151
- input_lengths : Tensor ,
152
- targets : Tensor ,
153
- target_lengths : Tensor
154
- ) -> Tensor :
93
+ def forward (self , inputs : Tensor , input_lengths : Tensor ) -> Tuple [Tensor , Tensor ]:
155
94
"""
156
95
Forward propagate a `inputs` and `targets` pair for training.
157
96
158
97
Args:
159
98
inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
160
99
`FloatTensor` of size ``(batch, seq_length, dimension)``.
161
100
input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
162
- targets (torch.LongTensr): A target sequence passed to decoder. `IntTensor` of size ``(batch, seq_length)``
163
- target_lengths (torch.LongTensor): The length of target tensor. ``(batch)``
164
101
165
102
Returns:
166
103
* predictions (torch.FloatTensor): Result of model predictions.
167
104
"""
168
- encoder_outputs , _ = self .encoder (inputs , input_lengths )
169
- decoder_outputs , _ = self .decoder (targets , target_lengths )
170
- outputs = self .joint (encoder_outputs , decoder_outputs )
171
- return outputs
172
-
173
- @torch .no_grad ()
174
- def decode (self , encoder_output : Tensor , max_length : int ) -> Tensor :
175
- """
176
- Decode `encoder_outputs`.
177
-
178
- Args:
179
- encoder_output (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
180
- ``(seq_length, dimension)``
181
- max_length (int): max decoding time step
182
-
183
- Returns:
184
- * predicted_log_probs (torch.FloatTensor): Log probability of model predictions.
185
- """
186
- pred_tokens , hidden_state = list (), None
187
- decoder_input = encoder_output .new_tensor ([[self .decoder .sos_id ]], dtype = torch .long )
188
-
189
- for t in range (max_length ):
190
- decoder_output , hidden_state = self .decoder (decoder_input , hidden_states = hidden_state )
191
- step_output = self .joint (encoder_output [t ].view (- 1 ), decoder_output .view (- 1 ))
192
- step_output = step_output .softmax (dim = 0 )
193
- pred_token = step_output .argmax (dim = 0 )
194
- pred_token = int (pred_token .item ())
195
- pred_tokens .append (pred_token )
196
- decoder_input = step_output .new_tensor ([[pred_token ]], dtype = torch .long )
197
-
198
- return torch .LongTensor (pred_tokens )
199
-
200
- @torch .no_grad ()
201
- def recognize (self , inputs : Tensor , input_lengths : Tensor ):
202
- """
203
- Recognize input speech. This method consists of the forward of the encoder and the decode() of the decoder.
204
-
205
- Args:
206
- inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
207
- `FloatTensor` of size ``(batch, seq_length, dimension)``.
208
- input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
209
-
210
- Returns:
211
- * predictions (torch.FloatTensor): Result of model predictions.
212
- """
213
- outputs = list ()
214
-
215
- encoder_outputs , output_lengths = self .encoder (inputs , input_lengths )
216
- max_length = encoder_outputs .size (1 )
217
-
218
- for encoder_output in encoder_outputs :
219
- decoded_seq = self .decode (encoder_output , max_length )
220
- outputs .append (decoded_seq )
221
-
222
- outputs = torch .stack (outputs , dim = 1 ).transpose (0 , 1 )
223
-
224
- return outputs
105
+ encoder_outputs , encoder_output_lengths = self .encoder (inputs , input_lengths )
106
+ return encoder_outputs , encoder_output_lengths
0 commit comments