Skip to content

Commit bff0e9b

Browse files
authored
Merge pull request #26 from manmay-nakhashi/main
added aligner loss
2 parents 7a856c3 + f5ce1b6 commit bff0e9b

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

naturalspeech2_pytorch/naturalspeech2_pytorch.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
from beartype.door import is_bearable
2929

3030
from naturalspeech2_pytorch.attend import Attend
31-
from naturalspeech2_pytorch.aligner import Aligner
31+
from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss
3232
from naturalspeech2_pytorch.utils.tokenizer import Tokenizer, ESpeak
3333
from naturalspeech2_pytorch.utils.utils import average_over_durations, create_mask
3434
from naturalspeech2_pytorch.version import __version__
@@ -1118,7 +1118,8 @@ def __init__(
11181118
audio_to_mel_kwargs: dict = dict(),
11191119
scale = 1., # this will be set to < 1. for better convergence when training on higher resolution images
11201120
duration_loss_weight = 1.,
1121-
pitch_loss_weight = 1.
1121+
pitch_loss_weight = 1.,
1122+
aligner_loss_weight = 1.
11221123
):
11231124
super().__init__()
11241125

@@ -1154,6 +1155,7 @@ def __init__(
11541155
self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
11551156
self.aligner = Aligner(dim_in=aligner_dim_in, dim_hidden=aligner_dim_hidden, attn_channels=aligner_attn_channels)
11561157
self.pitch_emb = nn.Embedding(pitch_emb_dim, pitch_emb_pp_hidden_dim)
1158+
self.aligner_loss = ForwardSumLoss()
11571159

11581160
# rest of ddpm
11591161

@@ -1207,6 +1209,7 @@ def __init__(
12071209

12081210
self.duration_loss_weight = duration_loss_weight
12091211
self.pitch_loss_weight = pitch_loss_weight
1212+
self.aligner_loss_weight = aligner_loss_weight
12101213

12111214
@property
12121215
def device(self):
@@ -1488,10 +1491,12 @@ def forward(
14881491

14891492
pitch = rearrange(pitch, 'b 1 d -> b d')
14901493
pitch_loss = F.l1_loss(pitch, pitch_pred)
1491-
1494+
align_loss = self.aligner_loss(aln_log , text_lens, mel_lens)
14921495
# weigh the losses
14931496

1494-
aux_loss = duration_loss * self.duration_loss_weight + pitch_loss + self.pitch_loss_weight
1497+
aux_loss = (duration_loss * self.duration_loss_weight) \
1498+
+ (pitch_loss * self.pitch_loss_weight) \
1499+
+ (align_loss * self.aligner_loss_weight)
14951500

14961501
# automatically encode raw audio to residual vq with codec
14971502

0 commit comments

Comments
 (0)