Commit 659bec7

address #29
1 parent 091e603 commit 659bec7

3 files changed: +31 -5 lines changed

naturalspeech2_pytorch/aligner.py

Lines changed: 18 additions & 1 deletion

@@ -146,7 +146,8 @@ def forward(self, attn_logprob, key_lens, query_lens):
 
         # Convert to log probabilities
         # Note: Mask out probs beyond key_len
-        attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), -1e15)
+        mask_value = -torch.finfo(attn_logprob.dtype).max
+        attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value)
 
         attn_logprob = attn_logprob.log_softmax(dim = -1)
 
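The masking change above replaces the hard-coded -1e15 fill with a value derived from the tensor's dtype. As a standalone illustration (not part of the commit): -1e15 is far outside the float16 range, so under mixed precision the old constant overflows to -inf, whereas torch.finfo keeps the mask at the largest finite negative value for whichever dtype attn_logprob actually has.

import torch

# -1e15 cannot be represented in float16 (max magnitude is 65504), so it
# overflows to -inf; finfo yields a large but finite value for any float dtype.
print(torch.tensor(-1e15, dtype = torch.float16))   # tensor(-inf, dtype=torch.float16)
print(-torch.finfo(torch.float16).max)              # -65504.0
print(-torch.finfo(torch.float32).max)              # approx -3.4e38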
@@ -159,6 +160,22 @@ def forward(self, attn_logprob, key_lens, query_lens):
 
         return cost
 
+class BinLoss(Module):
+    def forward(self, attn_hard, attn_logprob, key_lens):
+        batch, device = attn_logprob.shape[0], attn_logprob.device
+        max_key_len = attn_logprob.size(-1)
+
+        # Reorder input to [query_len, batch_size, key_len]
+        attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')
+        attn_hard = rearrange(attn_hard, 'b t c -> c b t')
+
+        mask_value = -torch.finfo(attn_logprob.dtype).max
+
+        attn_logprob.masked_fill_(torch.arange(max_key_len, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value)
+        attn_logprob = attn_logprob.log_softmax(dim = -1)
+
+        return (attn_hard * attn_logprob).sum() / batch
+
 class Aligner(Module):
     def __init__(
         self,
naturalspeech2_pytorch/naturalspeech2_pytorch.py

Lines changed: 12 additions & 3 deletions

@@ -28,7 +28,7 @@
 from beartype.door import is_bearable
 
 from naturalspeech2_pytorch.attend import Attend
-from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss
+from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss, BinLoss
 from naturalspeech2_pytorch.utils.tokenizer import Tokenizer, ESpeak
 from naturalspeech2_pytorch.utils.utils import average_over_durations, create_mask
 from naturalspeech2_pytorch.version import __version__
@@ -1192,7 +1192,8 @@ def __init__(
         scale = 1., # this will be set to < 1. for better convergence when training on higher resolution images
         duration_loss_weight = 1.,
         pitch_loss_weight = 1.,
-        aligner_loss_weight = 1.
+        aligner_loss_weight = 1.,
+        aligner_bin_loss_weight = 0.
     ):
         super().__init__()
 
@@ -1233,7 +1234,10 @@ def __init__(
         self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
         self.aligner = Aligner(dim_in=aligner_dim_in, dim_hidden=aligner_dim_hidden, attn_channels=aligner_attn_channels)
         self.pitch_emb = nn.Embedding(pitch_emb_dim, pitch_emb_pp_hidden_dim)
+
         self.aligner_loss = ForwardSumLoss()
+        self.bin_loss = BinLoss()
+        self.aligner_bin_loss_weight = aligner_bin_loss_weight
 
         # rest of ddpm
 
@@ -1584,7 +1588,12 @@ def forward(
 
         pitch = rearrange(pitch, 'b 1 d -> b d')
         pitch_loss = F.l1_loss(pitch, pitch_pred)
-        align_loss = self.aligner_loss(aln_log , text_lens, mel_lens)
+
+        align_loss = self.aligner_loss(aln_log, text_lens, mel_lens)
+
+        if self.aligner_bin_loss_weight > 0.:
+            align_bin_loss = self.bin_loss(aln_mask, aln_log, text_lens) * self.aligner_bin_loss_weight
+            align_loss = align_loss + align_bin_loss
 
         # weigh the losses
 
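The new aligner_bin_loss_weight defaults to 0., so the extra term is opt-in. A hedged usage sketch follows; the surrounding constructor arguments are placeholders based on the project README rather than this diff, so check the actual signature before copying.

from naturalspeech2_pytorch import NaturalSpeech2

diffusion = NaturalSpeech2(
    model = model,                    # assumed: the denoiser Model instance, built as in the README
    timesteps = 1000,
    aligner_bin_loss_weight = 0.1     # any value > 0. adds the weighted BinLoss term to the aligner loss
)

As the last hunk above shows, when the weight is positive the bin loss is scaled by it and added to the ForwardSumLoss result before the overall loss weighting.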
naturalspeech2_pytorch/version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = '0.1.7'
+__version__ = '0.1.8'
