|
28 | 28 | from beartype.door import is_bearable
|
29 | 29 |
|
30 | 30 | from naturalspeech2_pytorch.attend import Attend
|
31 |
| -from naturalspeech2_pytorch.aligner import Aligner |
| 31 | +from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss |
32 | 32 | from naturalspeech2_pytorch.utils.tokenizer import Tokenizer, ESpeak
|
33 | 33 | from naturalspeech2_pytorch.utils.utils import average_over_durations, create_mask
|
34 | 34 | from naturalspeech2_pytorch.version import __version__
|
@@ -1118,7 +1118,8 @@ def __init__(
|
1118 | 1118 | audio_to_mel_kwargs: dict = dict(),
|
1119 | 1119 | scale = 1., # this will be set to < 1. for better convergence when training on higher resolution images
|
1120 | 1120 | duration_loss_weight = 1.,
|
1121 |
| - pitch_loss_weight = 1. |
| 1121 | + pitch_loss_weight = 1., |
| 1122 | + aligner_loss_weight = 1. |
1122 | 1123 | ):
|
1123 | 1124 | super().__init__()
|
1124 | 1125 |
|
@@ -1154,6 +1155,7 @@ def __init__(
|
1154 | 1155 | self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
|
1155 | 1156 | self.aligner = Aligner(dim_in=aligner_dim_in, dim_hidden=aligner_dim_hidden, attn_channels=aligner_attn_channels)
|
1156 | 1157 | self.pitch_emb = nn.Embedding(pitch_emb_dim, pitch_emb_pp_hidden_dim)
|
| 1158 | + self.aligner_loss = ForwardSumLoss() |
1157 | 1159 |
|
1158 | 1160 | # rest of ddpm
|
1159 | 1161 |
|
@@ -1207,6 +1209,7 @@ def __init__(
|
1207 | 1209 |
|
1208 | 1210 | self.duration_loss_weight = duration_loss_weight
|
1209 | 1211 | self.pitch_loss_weight = pitch_loss_weight
|
| 1212 | + self.aligner_loss_weight = aligner_loss_weight |
1210 | 1213 |
|
1211 | 1214 | @property
|
1212 | 1215 | def device(self):
|
@@ -1488,10 +1491,12 @@ def forward(
|
1488 | 1491 |
|
1489 | 1492 | pitch = rearrange(pitch, 'b 1 d -> b d')
|
1490 | 1493 | pitch_loss = F.l1_loss(pitch, pitch_pred)
|
1491 |
| - |
| 1494 | + align_loss = self.aligner_loss(aln_log , text_lens, mel_lens) |
1492 | 1495 | # weigh the losses
|
1493 | 1496 |
|
1494 |
| - aux_loss = duration_loss * self.duration_loss_weight + pitch_loss + self.pitch_loss_weight |
| 1497 | + aux_loss = (duration_loss * self.duration_loss_weight) \ |
| 1498 | + + (pitch_loss * self.pitch_loss_weight) \ |
| 1499 | + + (align_loss * self.aligner_loss_weight) |
1495 | 1500 |
|
1496 | 1501 | # automatically encode raw audio to residual vq with codec
|
1497 | 1502 |
|
|
0 commit comments