Commit fb9e1d5

make sure forward sum loss actually runs, cleanup
1 parent 5cbf6e9 commit fb9e1d5

4 files changed: +41 -22 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -108,7 +108,7 @@ raw_audio = torch.randn(4, 327680)
 prompt = torch.randn(4, 32768) # they randomly excised a range on the audio for the prompt during training, eventually will take care of this auto-magically

 text = torch.randint(0, 100, (4, 100))
-text_lens = torch.tensor([100, 50 , 80, 120])
+text_lens = torch.tensor([100, 50 , 80, 100])

 # forwards and backwards
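This one-line fix is the point of the commit title: text_lens feeds the aligner's forward sum loss, and the old value 120 exceeded the 100 tokens allocated in text above, so the example could not run. A minimal sketch of the invariant the fix restores (shapes taken from the README snippet itself):

import torch

text = torch.randint(0, 100, (4, 100))        # (batch, text_max_length)
text_lens = torch.tensor([100, 50, 80, 120])  # old value: 120 > 100

# no per-sample length may exceed the padded text width
assert not (text_lens <= text.shape[-1]).all()  # old lens violate the invariant

text_lens = torch.tensor([100, 50, 80, 100])    # corrected value from this commit
assert (text_lens <= text.shape[-1]).all()      # now satisfied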

naturalspeech2_pytorch/aligner.py

Lines changed: 35 additions & 20 deletions

@@ -2,11 +2,14 @@
 import numpy as np

 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.nn import Module
 import torch.nn.functional as F

-from einops import rearrange
+from einops import rearrange, repeat
+
+from beartype import beartype
+from beartype.typing import Optional

 def exists(val):
     return val is not None
@@ -22,7 +25,6 @@ def __init__(
     ):
         super().__init__()
         self.temperature = temperature
-        self.softmax = torch.nn.Softmax(dim=3)

         self.key_layers = nn.ModuleList([
             nn.Conv1d(
@@ -50,7 +52,13 @@ def __init__(
             nn.Conv1d(dim_in, attn_channels, kernel_size=1, padding=0, bias=True)
         ])

-    def forward(self, queries: torch.Tensor, keys: torch.Tensor, mask: torch.Tensor = None):
+    @beartype
+    def forward(
+        self,
+        queries: Tensor,
+        keys: Tensor,
+        mask: Optional[Tensor] = None
+    ):
         key_out = keys
         for layer in self.key_layers:
             key_out = layer(key_out)
@@ -61,12 +69,15 @@ def forward(self, queries: torch.Tensor, keys: torch.Tensor, mask: torch.Tensor

         key_out = rearrange(key_out, 'b c t -> b t c')
         query_out = rearrange(query_out, 'b c t -> b t c')
-        attn_logp = torch.cdist(query_out, key_out).unsqueeze(1)
+
+        attn_logp = torch.cdist(query_out, key_out)
+        attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...')

         if exists(mask):
-            attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
+            mask = rearrange(mask.bool(), '... c -> ... 1 c')
+            attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max)

-        attn = self.softmax(attn_logp)
+        attn = attn_logp.softmax(dim = -1)
         return attn, attn_logp

 def pad_tensor(input, pad, value=0):
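A note on the masking cleanup above: filling with -float("inf") makes any fully masked key row collapse to NaN under softmax, while the most negative finite value for the dtype keeps the output finite. A standalone sketch of the difference (the shapes here are illustrative, not the module's):

import torch

logits = torch.zeros(1, 1, 2, 3)
mask = torch.tensor([[False, False, False]])  # a fully masked key axis

# -inf everywhere in the row: softmax yields NaN for every position
print(logits.masked_fill(~mask, -float('inf')).softmax(dim = -1))

# most negative finite value instead: softmax stays finite (uniform here)
print(logits.masked_fill(~mask, -torch.finfo(logits.dtype).max).softmax(dim = -1))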
@@ -110,34 +121,38 @@ def maximum_path(value, mask, const=None):
     path = path.to(dtype=dtype)
     return path

-class ForwardSumLoss():
-    def __init__(self, blank_logprob=-1):
+class ForwardSumLoss(Module):
+    def __init__(
+        self,
+        blank_logprob = -1
+    ):
         super().__init__()
-        self.log_softmax = torch.nn.LogSoftmax(dim=-1)
-        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
         self.blank_logprob = blank_logprob

-    def forward(self, attn_logprob, in_lens, out_lens):
-        key_lens = in_lens
-        query_lens = out_lens
+        self.ctc_loss = torch.nn.CTCLoss(
+            blank = 0, # check this value
+            zero_infinity = True
+        )
+
+    def forward(self, attn_logprob, key_lens, query_lens):
+        device, blank_logprob = attn_logprob.device, self.blank_logprob
         max_key_len = attn_logprob.size(-1)

         # Reorder input to [query_len, batch_size, key_len]
-        attn_logprob = rearrange(attn_logprob, 'b c t -> c b t')
+        attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')

         # Add blank label
-        attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), self.blank_logprob)
+        attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), value = blank_logprob)

         # Convert to log probabilities
         # Note: Mask out probs beyond key_len
-        device = attn_logprob.device
         attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), -1e15)

-        attn_logprob = self.log_softmax(attn_logprob)
+        attn_logprob = attn_logprob.log_softmax(dim = -1)

         # Target sequences
-        target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long).unsqueeze(0)
-        target_seqs = target_seqs.repeat(key_lens.numel(), 1)
+        target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long)
+        target_seqs = repeat(target_seqs, 'n -> b n', b = key_lens.numel())

         # Evaluate CTC loss
         cost = self.ctc_loss(attn_logprob, target_seqs, query_lens, key_lens)
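With ForwardSumLoss now subclassing Module and taking key_lens/query_lens directly, it can be invoked like any other loss. A hedged usage sketch: the (batch, 1, query_len, key_len) input shape is inferred from the new 'b 1 c t -> c b t' rearrange, and the forward is assumed to return the cost computed on its last line:

import torch
from naturalspeech2_pytorch.aligner import ForwardSumLoss

# e.g. 80 mel frames attending over 20 text tokens, batch of 4
attn_logprob = torch.randn(4, 1, 80, 20, requires_grad = True)

key_lens = torch.tensor([20, 15, 20, 12])    # text token counts per sample
query_lens = torch.tensor([80, 72, 80, 64])  # mel frame counts per sample

loss = ForwardSumLoss()(attn_logprob, key_lens = key_lens, query_lens = query_lens)
loss.backward()  # backpropagates now that the class is a proper nn.Module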

naturalspeech2_pytorch/naturalspeech2_pytorch.py

Lines changed: 4 additions & 0 deletions

@@ -1448,6 +1448,8 @@ def forward(
         if not exists(text_lens):
             text_lens = torch.full((batch,), text_max_length, device = self.device, dtype = torch.long)

+        text_lens.clamp_(max = text_max_length)
+
         text_mask = rearrange(create_mask(text_lens, text_max_length), 'b n -> b 1 n')

         prompt = self.process_prompt(prompt)
@@ -1475,6 +1477,8 @@ def forward(
         if not exists(mel_lens):
             mel_lens = torch.full((batch,), mel_max_length, device = self.device, dtype = torch.long)

+        mel_lens.clamp_(max = mel_max_length)
+
         mel_mask = rearrange(create_mask(mel_lens, mel_max_length), 'b n -> b 1 n')

         # alignment
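These two clamp_ guards make out-of-range user lengths harmless before the masks are built, the same failure the README fix above addressed at the call site. A minimal sketch of the guard, with an illustrative stand-in for the package's create_mask helper (the real helper lives in the library):

import torch

def create_mask(lens, max_length):
    # illustrative stand-in: True where a position is within the sample's length
    seq = torch.arange(max_length, device = lens.device)
    return seq[None, :] < lens[:, None]

text_max_length = 100
text_lens = torch.tensor([100, 50, 80, 120])

text_lens.clamp_(max = text_max_length)  # the added guard: 120 -> 100
text_mask = create_mask(text_lens, text_max_length)
assert text_mask.shape == (4, 100)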

naturalspeech2_pytorch/version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = '0.1.0'
+__version__ = '0.1.1'
