Skip to content

Commit 44d1a1f

Browse files
committed
at inference time, the alignment mask is derived from the duration. improvise a get_mask_from_lengths function, consult with someone in the field later
1 parent 4ff58db commit 44d1a1f

File tree

4 files changed

+93
-114
lines changed

4 files changed

+93
-114
lines changed

README.md

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -107,23 +107,15 @@ diffusion = NaturalSpeech2(
107107
raw_audio = torch.randn(4, 327680)
108108
prompt = torch.randn(4, 32768) # they randomly excised a range on the audio for the prompt during training, eventually will take care of this auto-magically
109109

110-
mel_lens = torch.tensor([120, 60 , 80, 70])
111-
mel = torch.randn((4, 80, 120))
112-
113110
text = torch.randint(0, 100, (4, 100))
114111
text_lens = torch.tensor([100, 50 , 80, 120])
115112

116-
pitch = torch.randn(4, 1, 120)
117-
118113
# forwards and backwards
119114

120115
loss = diffusion(
121116
audio = raw_audio,
122117
text = text,
123118
text_lens = text_lens,
124-
mel = mel,
125-
mel_lens = mel_lens,
126-
pitch = pitch,
127119
prompt = prompt
128120
)
129121

@@ -134,9 +126,6 @@ loss.backward()
134126
generated_audio = diffusion.sample(
135127
length = 1024,
136128
text = text,
137-
mel = mel,
138-
mel_lens = mel_lens,
139-
pitch = pitch,
140129
prompt = prompt
141130
) # (1, 327680)
142131
```

naturalspeech2_pytorch/aligner.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,18 @@ def forward(
171171
y,
172172
y_mask
173173
):
174-
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
175174
alignment_soft, alignment_logprob = self.aligner(y, rearrange(x, 'b d t -> b t d'), x_mask)
176-
alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c')
177175

178-
alignment_mas = maximum_path(
179-
alignment_soft.contiguous(),
180-
rearrange(attn_mask, 'b 1 c t -> b c t').contiguous()
181-
)
176+
x_mask = rearrange(x_mask, '... i -> ... i 1')
177+
y_mask = rearrange(y_mask, '... j -> ... 1 j')
178+
attn_mask = x_mask * y_mask
179+
attn_mask = rearrange(attn_mask, 'b 1 i j -> b i j')
180+
181+
alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c')
182+
alignment_mask = maximum_path(alignment_soft, attn_mask)
182183

183-
alignment_hard = torch.sum(alignment_mas, -1).int()
184-
return alignment_hard, alignment_soft, alignment_logprob, alignment_mas
184+
alignment_hard = torch.sum(alignment_mask, -1).int()
185+
return alignment_hard, alignment_soft, alignment_logprob, alignment_mask
185186

186187
if __name__ == '__main__':
187188
batch_size = 10

naturalspeech2_pytorch/naturalspeech2_pytorch.py

Lines changed: 83 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,23 @@ def prob_mask_like(shape, prob, device):
7272
else:
7373
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
7474

75+
def generate_mask_from_lengths(lengths):
76+
src = lengths.int()
77+
device = src.device
78+
tgt_length = src.sum(dim = -1).amax().item()
79+
80+
cumsum = src.cumsum(dim = -1)
81+
cumsum_exclusive = F.pad(cumsum, (1, -1), value = 0.)
82+
83+
tgt_arange = torch.arange(tgt_length, device = device)
84+
tgt_arange = repeat(tgt_arange, '... j -> ... i j', i = src.shape[-1])
85+
86+
cumsum = rearrange(cumsum, '... i -> ... i 1')
87+
cumsum_exclusive = rearrange(cumsum_exclusive, '... i -> ... i 1')
88+
89+
mask = (tgt_arange < cumsum) & (tgt_arange >= cumsum_exclusive)
90+
return mask
91+
7592
# sinusoidal positional embeds
7693

7794
class LearnedSinusoidalPosEmb(nn.Module):
@@ -1344,76 +1361,6 @@ def process_prompt(self, prompt = None):
13441361

13451362
return prompt
13461363

1347-
def process_conditioning(
1348-
self,
1349-
*,
1350-
prompt,
1351-
audio = None,
1352-
pitch = None,
1353-
text = None,
1354-
text_lens = None,
1355-
mel = None,
1356-
mel_lens = None
1357-
):
1358-
batch = prompt.shape[0]
1359-
1360-
assert exists(text)
1361-
text_max_length = text.shape[-1]
1362-
1363-
if not exists(text_lens):
1364-
text_lens = torch.full((batch,), text_max_length, device = self.device, dtype = torch.long)
1365-
1366-
text_mask = rearrange(create_mask(text_lens, text_max_length), 'b n -> b 1 n')
1367-
1368-
prompt = self.process_prompt(prompt)
1369-
prompt_enc = self.prompt_enc(prompt)
1370-
phoneme_enc = self.phoneme_enc(text)
1371-
1372-
# process pitch
1373-
1374-
if not exists(pitch):
1375-
assert exists(audio) and audio.ndim == 2
1376-
assert exists(self.target_sample_hz)
1377-
1378-
pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
1379-
pitch = rearrange(pitch, 'b n -> b 1 n')
1380-
1381-
# process mel
1382-
1383-
if not exists(mel):
1384-
assert exists(audio) and audio.ndim == 2
1385-
1386-
mel = self.audio_to_mel(audio)
1387-
mel = mel[..., :text_max_length]
1388-
1389-
mel_max_length = mel.shape[-1]
1390-
1391-
if not exists(mel_lens):
1392-
mel_lens = torch.full((batch,), mel_max_length, device = self.device, dtype = torch.long)
1393-
1394-
mel_mask = rearrange(create_mask(mel_lens, mel_max_length), 'b n -> b 1 n')
1395-
1396-
# alignment
1397-
1398-
aln_hard, aln_soft, aln_log, aln_mas = self.aligner(phoneme_enc, text_mask, mel, mel_mask)
1399-
duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)
1400-
1401-
pitch = average_over_durations(pitch, aln_hard)
1402-
cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mas, 'b n c -> b 1 n c'), pitch)
1403-
1404-
# pitch and duration loss
1405-
1406-
duration_loss = F.l1_loss(aln_hard, duration_pred)
1407-
1408-
pitch = rearrange(pitch, 'b 1 d -> b d')
1409-
pitch_loss = F.l1_loss(pitch, pitch_pred)
1410-
1411-
# weigh the losses
1412-
1413-
aux_loss = duration_loss * self.duration_loss_weight + pitch_loss + self.pitch_loss_weight
1414-
1415-
return prompt_enc, cond, aux_loss
1416-
14171364
def expand_encodings(self, phoneme_enc, attn, pitch):
14181365
expanded_dur = einsum('k l m n, k j m -> k j n', attn, phoneme_enc)
14191366
pitch_emb = self.pitch_emb(rearrange(f0_to_coarse(pitch), 'b 1 t -> b t'))
@@ -1430,29 +1377,25 @@ def sample(
14301377
prompt = None,
14311378
batch_size = 1,
14321379
cond_scale = 1.,
1433-
pitch = None,
14341380
text = None,
14351381
text_lens = None,
1436-
mel = None,
1437-
mel_lens = None,
14381382
):
14391383
sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
14401384

1441-
prompt = self.process_prompt(prompt)
1442-
14431385
prompt_enc = cond = None
14441386

14451387
if self.conditional:
1446-
assert exists(mel)
1447-
1448-
prompt_enc, cond, _ = self.process_conditioning(
1449-
prompt = prompt,
1450-
text = text,
1451-
pitch = pitch,
1452-
mel = mel,
1453-
text_lens = text_lens,
1454-
mel_lens = mel_lens
1455-
)
1388+
assert exists(prompt) and exists(text)
1389+
prompt = self.process_prompt(prompt)
1390+
prompt_enc = self.prompt_enc(prompt)
1391+
phoneme_enc = self.phoneme_enc(text)
1392+
1393+
duration, pitch = self.duration_pitch(phoneme_enc, prompt_enc)
1394+
pitch = rearrange(pitch, 'b n -> b 1 n')
1395+
1396+
aln_mask = generate_mask_from_lengths(duration).float()
1397+
1398+
cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)
14561399

14571400
if exists(prompt):
14581401
batch_size = prompt.shape[0]
@@ -1494,15 +1437,61 @@ def forward(
14941437
duration_pitch_loss = 0.
14951438

14961439
if self.conditional:
1497-
prompt_enc, cond, duration_pitch_loss = self.process_conditioning(
1498-
audio = audio,
1499-
prompt = prompt,
1500-
text = text,
1501-
pitch = pitch,
1502-
mel = mel,
1503-
text_lens = text_lens,
1504-
mel_lens = mel_lens
1505-
)
1440+
batch = prompt.shape[0]
1441+
1442+
assert exists(text)
1443+
text_max_length = text.shape[-1]
1444+
1445+
if not exists(text_lens):
1446+
text_lens = torch.full((batch,), text_max_length, device = self.device, dtype = torch.long)
1447+
1448+
text_mask = rearrange(create_mask(text_lens, text_max_length), 'b n -> b 1 n')
1449+
1450+
prompt = self.process_prompt(prompt)
1451+
prompt_enc = self.prompt_enc(prompt)
1452+
phoneme_enc = self.phoneme_enc(text)
1453+
1454+
# process pitch
1455+
1456+
if not exists(pitch):
1457+
assert exists(audio) and audio.ndim == 2
1458+
assert exists(self.target_sample_hz)
1459+
1460+
pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
1461+
pitch = rearrange(pitch, 'b n -> b 1 n')
1462+
1463+
# process mel
1464+
1465+
if not exists(mel):
1466+
assert exists(audio) and audio.ndim == 2
1467+
mel = self.audio_to_mel(audio)
1468+
mel = mel[..., :pitch.shape[-1]]
1469+
1470+
mel_max_length = mel.shape[-1]
1471+
1472+
if not exists(mel_lens):
1473+
mel_lens = torch.full((batch,), mel_max_length, device = self.device, dtype = torch.long)
1474+
1475+
mel_mask = rearrange(create_mask(mel_lens, mel_max_length), 'b n -> b 1 n')
1476+
1477+
# alignment
1478+
1479+
aln_hard, aln_soft, aln_log, aln_mas = self.aligner(phoneme_enc, text_mask, mel, mel_mask)
1480+
duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)
1481+
1482+
pitch = average_over_durations(pitch, aln_hard)
1483+
cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mas, 'b n c -> b 1 n c'), pitch)
1484+
1485+
# pitch and duration loss
1486+
1487+
duration_loss = F.l1_loss(aln_hard, duration_pred)
1488+
1489+
pitch = rearrange(pitch, 'b 1 d -> b d')
1490+
pitch_loss = F.l1_loss(pitch, pitch_pred)
1491+
1492+
# weigh the losses
1493+
1494+
aux_loss = duration_loss * self.duration_loss_weight + pitch_loss + self.pitch_loss_weight
15061495

15071496
# automatically encode raw audio to residual vq with codec
15081497

naturalspeech2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.0.48'
1+
__version__ = '0.0.49'

0 commit comments

Comments
 (0)