
Commit acb9d9f

Soohwan Kim authored and committed
Update model & README.md
1 parent aead2f2 commit acb9d9f

3 files changed (+14, -257 lines)

README.md

Lines changed: 9 additions & 6 deletions
````diff
@@ -62,22 +62,25 @@ batch_size, sequence_length, dim = 3, 12345, 80
 cuda = torch.cuda.is_available()
 device = torch.device('cuda' if cuda else 'cpu')

+criterion = nn.CTCLoss()
+
 inputs = torch.rand(batch_size, sequence_length, dim).to(device)
 input_lengths = torch.IntTensor([12345, 12300, 12000])
 targets = torch.LongTensor([[1, 3, 3, 3, 3, 3, 4, 5, 6, 2],
                             [1, 3, 3, 3, 3, 3, 4, 5, 2, 0],
                             [1, 3, 3, 3, 3, 3, 4, 2, 0, 0]]).to(device)
 target_lengths = torch.LongTensor([9, 8, 7])

-model = nn.DataParallel(Conformer(num_classes=10, input_dim=dim,
-                                  encoder_dim=32, num_encoder_layers=3,
-                                  decoder_dim=32)).to(device)
+model = Conformer(num_classes=10,
+                  input_dim=dim,
+                  encoder_dim=32,
+                  num_encoder_layers=3)

 # Forward propagate
-outputs = model(inputs, input_lengths, targets, target_lengths)
+outputs, output_lengths = model(inputs, input_lengths)

-# Recognize input speech
-outputs = model.module.recognize(inputs, input_lengths)
+# Calculate CTC Loss
+loss = criterion(outputs.transpose(0, 1), targets, output_lengths, target_lengths)

 ```

 ## Troubleshoots and Contributing
````
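
For reference, here is the updated README example assembled into one self-contained, runnable snippet. It mirrors the diff above; the `from conformer import Conformer` import and the trailing `.to(device)` on the model are assumptions added to make the sketch run end-to-end, not lines from the diff.

```python
import torch
import torch.nn as nn
from conformer import Conformer  # assumed import path for this repo

batch_size, sequence_length, dim = 3, 12345, 80

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')

criterion = nn.CTCLoss()

inputs = torch.rand(batch_size, sequence_length, dim).to(device)
input_lengths = torch.IntTensor([12345, 12300, 12000])
targets = torch.LongTensor([[1, 3, 3, 3, 3, 3, 4, 5, 6, 2],
                            [1, 3, 3, 3, 3, 3, 4, 5, 2, 0],
                            [1, 3, 3, 3, 3, 3, 4, 2, 0, 0]]).to(device)
target_lengths = torch.LongTensor([9, 8, 7])

model = Conformer(num_classes=10,
                  input_dim=dim,
                  encoder_dim=32,
                  num_encoder_layers=3).to(device)  # .to(device) added so model and inputs agree

# Forward propagate: the model now returns (outputs, output_lengths)
outputs, output_lengths = model(inputs, input_lengths)

# nn.CTCLoss expects inputs of shape (seq_length, batch, num_classes), hence the transpose
loss = criterion(outputs.transpose(0, 1), targets, output_lengths, target_lengths)
```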

conformer/decoder.py

Lines changed: 0 additions & 128 deletions
This file was deleted.

conformer/model.py

Lines changed: 5 additions & 123 deletions
```diff
@@ -15,8 +15,8 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
+from typing import Tuple

-from conformer.decoder import DecoderRNNT
 from conformer.encoder import ConformerEncoder
 from conformer.modules import Linear

@@ -31,17 +31,13 @@ class Conformer(nn.Module):
         num_classes (int): Number of classification classes
         input_dim (int, optional): Dimension of input vector
         encoder_dim (int, optional): Dimension of conformer encoder
-        decoder_dim (int, optional): Dimension of conformer decoder
         num_encoder_layers (int, optional): Number of conformer blocks
-        num_decoder_layers (int, optional): Number of decoder layers
-        decoder_rnn_type (str, optional): type of RNN cell
         num_attention_heads (int, optional): Number of attention heads
         feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
         conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
         feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
         attention_dropout_p (float, optional): Probability of attention module dropout
         conv_dropout_p (float, optional): Probability of conformer convolution module dropout
-        decoder_dropout_p (float, optional): Probability of conformer decoder dropout
         conv_kernel_size (int or tuple, optional): Size of the convolving kernel
         half_step_residual (bool): Flag indication whether to use half step residual or not

@@ -58,20 +54,16 @@ def __init__(
             num_classes: int,
             input_dim: int = 80,
             encoder_dim: int = 512,
-            decoder_dim: int = 640,
             num_encoder_layers: int = 17,
-            num_decoder_layers: int = 1,
             num_attention_heads: int = 8,
             feed_forward_expansion_factor: int = 4,
             conv_expansion_factor: int = 2,
             input_dropout_p: float = 0.1,
             feed_forward_dropout_p: float = 0.1,
             attention_dropout_p: float = 0.1,
             conv_dropout_p: float = 0.1,
-            decoder_dropout_p: float = 0.1,
             conv_kernel_size: int = 31,
             half_step_residual: bool = True,
-            decoder_rnn_type: str = "lstm",
     ) -> None:
         super(Conformer, self).__init__()
         self.encoder = ConformerEncoder(
@@ -88,137 +80,27 @@ def __init__(
             conv_kernel_size=conv_kernel_size,
             half_step_residual=half_step_residual,
         )
-        self.decoder = DecoderRNNT(
-            num_classes=num_classes,
-            hidden_state_dim=decoder_dim,
-            output_dim=encoder_dim,
-            num_layers=num_decoder_layers,
-            rnn_type=decoder_rnn_type,
-            dropout_p=decoder_dropout_p,
-        )
         self.fc = Linear(encoder_dim << 1, num_classes, bias=False)

-    def set_encoder(self, encoder):
-        """ Setter for encoder """
-        self.encoder = encoder
-
-    def set_decoder(self, decoder):
-        """ Setter for decoder """
-        self.decoder = decoder
-
     def count_parameters(self) -> int:
         """ Count parameters of encoder """
-        num_encoder_parameters = self.encoder.count_parameters()
-        num_decoder_parameters = self.decoder.count_parameters()
-        return num_encoder_parameters + num_decoder_parameters
+        return self.encoder.count_parameters()

     def update_dropout(self, dropout_p) -> None:
         """ Update dropout probability of model """
         self.encoder.update_dropout(dropout_p)
-        self.decoder.update_dropout(dropout_p)
-
-    def joint(self, encoder_outputs: Tensor, decoder_outputs: Tensor) -> Tensor:
-        """
-        Joint `encoder_outputs` and `decoder_outputs`.
-
-        Args:
-            encoder_outputs (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
-                ``(batch, seq_length, dimension)``
-            decoder_outputs (torch.FloatTensor): A output sequence of decoder. `FloatTensor` of size
-                ``(batch, seq_length, dimension)``
-
-        Returns:
-            * outputs (torch.FloatTensor): outputs of joint `encoder_outputs` and `decoder_outputs`..
-        """
-        if encoder_outputs.dim() == 3 and decoder_outputs.dim() == 3:
-            input_length = encoder_outputs.size(1)
-            target_length = decoder_outputs.size(1)

-            encoder_outputs = encoder_outputs.unsqueeze(2)
-            decoder_outputs = decoder_outputs.unsqueeze(1)
-
-            encoder_outputs = encoder_outputs.repeat([1, 1, target_length, 1])
-            decoder_outputs = decoder_outputs.repeat([1, input_length, 1, 1])
-
-        outputs = torch.cat((encoder_outputs, decoder_outputs), dim=-1)
-        outputs = self.fc(outputs)
-
-        return outputs
-
-    def forward(
-            self,
-            inputs: Tensor,
-            input_lengths: Tensor,
-            targets: Tensor,
-            target_lengths: Tensor
-    ) -> Tensor:
+    def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Forward propagate a `inputs` and `targets` pair for training.

         Args:
             inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
                 `FloatTensor` of size ``(batch, seq_length, dimension)``.
             input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
-            targets (torch.LongTensr): A target sequence passed to decoder. `IntTensor` of size ``(batch, seq_length)``
-            target_lengths (torch.LongTensor): The length of target tensor. ``(batch)``

         Returns:
             * predictions (torch.FloatTensor): Result of model predictions.
         """
-        encoder_outputs, _ = self.encoder(inputs, input_lengths)
-        decoder_outputs, _ = self.decoder(targets, target_lengths)
-        outputs = self.joint(encoder_outputs, decoder_outputs)
-        return outputs
-
-    @torch.no_grad()
-    def decode(self, encoder_output: Tensor, max_length: int) -> Tensor:
-        """
-        Decode `encoder_outputs`.
-
-        Args:
-            encoder_output (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
-                ``(seq_length, dimension)``
-            max_length (int): max decoding time step
-
-        Returns:
-            * predicted_log_probs (torch.FloatTensor): Log probability of model predictions.
-        """
-        pred_tokens, hidden_state = list(), None
-        decoder_input = encoder_output.new_tensor([[self.decoder.sos_id]], dtype=torch.long)
-
-        for t in range(max_length):
-            decoder_output, hidden_state = self.decoder(decoder_input, hidden_states=hidden_state)
-            step_output = self.joint(encoder_output[t].view(-1), decoder_output.view(-1))
-            step_output = step_output.softmax(dim=0)
-            pred_token = step_output.argmax(dim=0)
-            pred_token = int(pred_token.item())
-            pred_tokens.append(pred_token)
-            decoder_input = step_output.new_tensor([[pred_token]], dtype=torch.long)
-
-        return torch.LongTensor(pred_tokens)
-
-    @torch.no_grad()
-    def recognize(self, inputs: Tensor, input_lengths: Tensor):
-        """
-        Recognize input speech. This method consists of the forward of the encoder and the decode() of the decoder.
-
-        Args:
-            inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
-                `FloatTensor` of size ``(batch, seq_length, dimension)``.
-            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
-
-        Returns:
-            * predictions (torch.FloatTensor): Result of model predictions.
-        """
-        outputs = list()
-
-        encoder_outputs, output_lengths = self.encoder(inputs, input_lengths)
-        max_length = encoder_outputs.size(1)
-
-        for encoder_output in encoder_outputs:
-            decoded_seq = self.decode(encoder_output, max_length)
-            outputs.append(decoded_seq)
-
-        outputs = torch.stack(outputs, dim=1).transpose(0, 1)
-
-        return outputs
+        encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
+        return encoder_outputs, encoder_output_lengths
```
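
With the RNN-T decoder and the `recognize()`/`decode()` path removed, `forward` now returns only the encoder outputs and their lengths, so inference-time transcription is left to the caller. One common option for a CTC-trained model is greedy decoding; below is a minimal sketch under that assumption (the `greedy_ctc_decode` helper and `blank_id=0` are illustrative, not part of this commit):

```python
import torch

def greedy_ctc_decode(outputs: torch.Tensor, output_lengths: torch.Tensor, blank_id: int = 0):
    """Greedy CTC decoding: argmax per frame, collapse repeats, drop blanks.

    outputs:        (batch, seq_length, num_classes) scores from the model
    output_lengths: (batch,) valid length of each output sequence
    """
    frame_preds = outputs.argmax(dim=-1)  # (batch, seq_length)
    results = []
    for preds, length in zip(frame_preds, output_lengths):
        tokens, prev = [], blank_id
        for token in preds[: int(length)].tolist():
            # emit a token only when it differs from the previous frame and is not blank
            if token != prev and token != blank_id:
                tokens.append(token)
            prev = token
        results.append(tokens)
    return results

# usage with the model above:
# outputs, output_lengths = model(inputs, input_lengths)
# hypotheses = greedy_ctc_decode(outputs, output_lengths)
```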
