 import numpy as np

 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.nn import Module
 import torch.nn.functional as F

-from einops import rearrange
+from einops import rearrange, repeat
+
+from beartype import beartype
+from beartype.typing import Optional

 def exists(val):
     return val is not None
@@ -22,7 +25,6 @@ def __init__(
     ):
         super().__init__()
         self.temperature = temperature
-        self.softmax = torch.nn.Softmax(dim = 3)

         self.key_layers = nn.ModuleList([
             nn.Conv1d(
@@ -50,7 +52,13 @@ def __init__(
             nn.Conv1d(dim_in, attn_channels, kernel_size = 1, padding = 0, bias = True)
         ])

-    def forward(self, queries: torch.Tensor, keys: torch.Tensor, mask: torch.Tensor = None):
+    @beartype
+    def forward(
+        self,
+        queries: Tensor,
+        keys: Tensor,
+        mask: Optional[Tensor] = None
+    ):
         key_out = keys
         for layer in self.key_layers:
             key_out = layer(key_out)
@@ -61,12 +69,15 @@ def forward(self, queries: torch.Tensor, keys: torch.Tensor, mask: torch.Tensor

         key_out = rearrange(key_out, 'b c t -> b t c')
         query_out = rearrange(query_out, 'b c t -> b t c')
-        attn_logp = torch.cdist(query_out, key_out).unsqueeze(1)
+
+        attn_logp = torch.cdist(query_out, key_out)
+        attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...')

         if exists(mask):
-            attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
+            mask = rearrange(mask.bool(), '... c -> ... 1 c')
+            attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max)

-        attn = self.softmax(attn_logp)
+        attn = attn_logp.softmax(dim = -1)
         return attn, attn_logp

 def pad_tensor(input, pad, value = 0):
@@ -110,34 +121,38 @@ def maximum_path(value, mask, const=None):
     path = path.to(dtype = dtype)
     return path

-class ForwardSumLoss():
-    def __init__(self, blank_logprob = -1):
+class ForwardSumLoss(Module):
+    def __init__(
+        self,
+        blank_logprob = -1
+    ):
         super().__init__()
-        self.log_softmax = torch.nn.LogSoftmax(dim = -1)
-        self.ctc_loss = torch.nn.CTCLoss(zero_infinity = True)
         self.blank_logprob = blank_logprob

-    def forward(self, attn_logprob, in_lens, out_lens):
-        key_lens = in_lens
-        query_lens = out_lens
+        self.ctc_loss = torch.nn.CTCLoss(
+            blank = 0, # check this value
+            zero_infinity = True
+        )
+
+    def forward(self, attn_logprob, key_lens, query_lens):
+        device, blank_logprob = attn_logprob.device, self.blank_logprob
         max_key_len = attn_logprob.size(-1)

         # Reorder input to [query_len, batch_size, key_len]
-        attn_logprob = rearrange(attn_logprob, 'b c t -> c b t')
+        attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')

         # Add blank label
-        attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), self.blank_logprob)
+        attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), value = blank_logprob)

         # Convert to log probabilities
         # Note: Mask out probs beyond key_len
-        device = attn_logprob.device
         attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device = device, dtype = torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), -1e15)

-        attn_logprob = self.log_softmax(attn_logprob)
+        attn_logprob = attn_logprob.log_softmax(dim = -1)

         # Target sequences
-        target_seqs = torch.arange(1, max_key_len + 1, device = device, dtype = torch.long).unsqueeze(0)
-        target_seqs = target_seqs.repeat(key_lens.numel(), 1)
+        target_seqs = torch.arange(1, max_key_len + 1, device = device, dtype = torch.long)
+        target_seqs = repeat(target_seqs, 'n -> b n', b = key_lens.numel())

         # Evaluate CTC loss
         cost = self.ctc_loss(attn_logprob, target_seqs, query_lens, key_lens)
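For reference, a minimal shape-check sketch of `ForwardSumLoss` as it stands after this diff. The batch size, lengths, random input, and module name below are illustrative assumptions, not part of the commit:

```python
import torch

# assumes the file above is importable, e.g.
# from aligner import ForwardSumLoss  # module name is an assumption

loss_fn = ForwardSumLoss()

batch, query_len, key_len = 2, 100, 40

# stand-in for the aligner's attn_logp output: (batch, 1, query frames, key tokens)
attn_logprob = torch.randn(batch, 1, query_len, key_len)

key_lens = torch.tensor([40, 32])     # valid key tokens per sample
query_lens = torch.tensor([100, 96])  # valid query frames per sample

# internally rearranged to (query_len, batch, key_len + 1 blank) for CTC
cost = loss_fn(attn_logprob, key_lens, query_lens)
print(cost)  # scalar alignment loss
```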