@@ -72,6 +72,23 @@ def prob_mask_like(shape, prob, device):
72
72
else :
73
73
return torch .zeros (shape , device = device ).float ().uniform_ (0 , 1 ) < prob
74
74
75
def generate_mask_from_lengths(lengths):
    """
    Turn per-token lengths into a boolean token-to-frame mask.

    For `lengths` of shape (..., n) the result has shape (..., n, t), where t is
    the largest per-row sum of the lengths. Entry [..., i, j] is True exactly
    when frame j lies in the half-open range
    [sum(lengths[..., :i]), sum(lengths[..., :i + 1])).
    Rows whose total is smaller than t stay all-False past their own total.

    Lengths are truncated to integers. Negative lengths are treated as 0:
    without that, the running sum stops being monotonic and a single frame
    can end up assigned to more than one token.

    Args:
        lengths: tensor of per-token lengths, last dimension indexes tokens.

    Returns:
        Boolean tensor of shape (..., n, t) on the same device as `lengths`.
    """
    counts = lengths.int().clamp(min = 0)

    # end of each token's span (inclusive cumsum) and its start (exclusive cumsum)
    ends = counts.cumsum(dim = -1)
    starts = ends - counts

    totals = counts.sum(dim = -1)

    # amax() raises on an empty tensor (e.g. a zero-sized batch), so guard it
    max_total = int(totals.amax().item()) if totals.numel() > 0 else 0

    frames = torch.arange(max_total, device = counts.device)

    # broadcast (..., n, 1) spans against the (t,) frame index
    mask = (frames >= starts.unsqueeze(-1)) & (frames < ends.unsqueeze(-1))
    return mask
91
+
75
92
# sinusoidal positional embeds
76
93
77
94
class LearnedSinusoidalPosEmb (nn .Module ):
@@ -1344,76 +1361,6 @@ def process_prompt(self, prompt = None):
1344
1361
1345
1362
return prompt
1346
1363
1347
def process_conditioning(
    self,
    *,
    prompt,
    audio = None,
    pitch = None,
    text = None,
    text_lens = None,
    mel = None,
    mel_lens = None
):
    """
    Build the conditioning inputs and the auxiliary duration / pitch loss.

    Encodes the prompt and the text, runs `self.aligner` on the text encoding
    and the mel, expands the text encoding with the resulting alignment and
    the duration-averaged pitch, and compares the outputs of
    `self.duration_pitch` against the aligner's hard alignment and that pitch.

    Args:
        prompt: prompt tensor; its first dimension is used as the batch size.
        audio: raw audio; must be a 2-D tensor whenever `pitch` or `mel` is omitted.
        pitch: optional precomputed pitch; computed from `audio` when omitted.
        text: required text tensor; its last dimension is the max text length.
        text_lens: optional per-sample text lengths; defaults to the max text length.
        mel: optional mel tensor; computed with `self.audio_to_mel(audio)` when omitted.
        mel_lens: optional per-sample mel lengths; defaults to the mel's last dimension.

    Returns:
        Tuple `(prompt_enc, cond, aux_loss)`.
    """
    batch = prompt.shape[0]

    assert exists(text)
    text_max_length = text.shape[-1]

    if not exists(text_lens):
        text_lens = torch.full((batch,), text_max_length, device = self.device, dtype = torch.long)

    text_mask = rearrange(create_mask(text_lens, text_max_length), 'b n -> b 1 n')

    prompt = self.process_prompt(prompt)
    prompt_enc = self.prompt_enc(prompt)
    phoneme_enc = self.phoneme_enc(text)

    # process pitch
    # NOTE(review): nesting rebuilt from a flattened listing - confirm the reshape
    # below and the mel truncation further down belong inside their `if` bodies

    if not exists(pitch):
        assert exists(audio) and audio.ndim == 2
        assert exists(self.target_sample_hz)

        pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
        pitch = rearrange(pitch, 'b n -> b 1 n')

    # process mel

    if not exists(mel):
        assert exists(audio) and audio.ndim == 2

        mel = self.audio_to_mel(audio)
        mel = mel[..., :text_max_length]

    mel_max_length = mel.shape[-1]

    if not exists(mel_lens):
        mel_lens = torch.full((batch,), mel_max_length, device = self.device, dtype = torch.long)

    mel_mask = rearrange(create_mask(mel_lens, mel_max_length), 'b n -> b 1 n')

    # alignment

    aln_hard, aln_soft, aln_log, aln_mas = self.aligner(phoneme_enc, text_mask, mel, mel_mask)
    duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)

    pitch = average_over_durations(pitch, aln_hard)
    cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mas, 'b n c -> b 1 n c'), pitch)

    # pitch and duration loss

    duration_loss = F.l1_loss(aln_hard, duration_pred)

    pitch = rearrange(pitch, 'b 1 d -> b d')
    pitch_loss = F.l1_loss(pitch, pitch_pred)

    # weigh the losses
    # each loss is scaled by its own weight; the pitch weight was previously
    # added as a constant, which left the pitch loss effectively unweighted

    aux_loss = duration_loss * self.duration_loss_weight + pitch_loss * self.pitch_loss_weight

    return prompt_enc, cond, aux_loss
1416
-
1417
1364
def expand_encodings (self , phoneme_enc , attn , pitch ):
1418
1365
expanded_dur = einsum ('k l m n, k j m -> k j n' , attn , phoneme_enc )
1419
1366
pitch_emb = self .pitch_emb (rearrange (f0_to_coarse (pitch ), 'b 1 t -> b t' ))
@@ -1430,29 +1377,25 @@ def sample(
1430
1377
prompt = None ,
1431
1378
batch_size = 1 ,
1432
1379
cond_scale = 1. ,
1433
- pitch = None ,
1434
1380
text = None ,
1435
1381
text_lens = None ,
1436
- mel = None ,
1437
- mel_lens = None ,
1438
1382
):
1439
1383
sample_fn = self .ddpm_sample if not self .use_ddim else self .ddim_sample
1440
1384
1441
- prompt = self .process_prompt (prompt )
1442
-
1443
1385
prompt_enc = cond = None
1444
1386
1445
1387
if self .conditional :
1446
- assert exists (mel )
1447
-
1448
- prompt_enc , cond , _ = self .process_conditioning (
1449
- prompt = prompt ,
1450
- text = text ,
1451
- pitch = pitch ,
1452
- mel = mel ,
1453
- text_lens = text_lens ,
1454
- mel_lens = mel_lens
1455
- )
1388
+ assert exists (prompt ) and exists (text )
1389
+ prompt = self .process_prompt (prompt )
1390
+ prompt_enc = self .prompt_enc (prompt )
1391
+ phoneme_enc = self .phoneme_enc (text )
1392
+
1393
+ duration , pitch = self .duration_pitch (phoneme_enc , prompt_enc )
1394
+ pitch = rearrange (pitch , 'b n -> b 1 n' )
1395
+
1396
+ aln_mask = generate_mask_from_lengths (duration ).float ()
1397
+
1398
+ cond = self .expand_encodings (rearrange (phoneme_enc , 'b n d -> b d n' ), rearrange (aln_mask , 'b n c -> b 1 n c' ), pitch )
1456
1399
1457
1400
if exists (prompt ):
1458
1401
batch_size = prompt .shape [0 ]
@@ -1494,15 +1437,61 @@ def forward(
1494
1437
duration_pitch_loss = 0.
1495
1438
1496
1439
if self .conditional :
1497
- prompt_enc , cond , duration_pitch_loss = self .process_conditioning (
1498
- audio = audio ,
1499
- prompt = prompt ,
1500
- text = text ,
1501
- pitch = pitch ,
1502
- mel = mel ,
1503
- text_lens = text_lens ,
1504
- mel_lens = mel_lens
1505
- )
1440
+ batch = prompt .shape [0 ]
1441
+
1442
+ assert exists (text )
1443
+ text_max_length = text .shape [- 1 ]
1444
+
1445
+ if not exists (text_lens ):
1446
+ text_lens = torch .full ((batch ,), text_max_length , device = self .device , dtype = torch .long )
1447
+
1448
+ text_mask = rearrange (create_mask (text_lens , text_max_length ), 'b n -> b 1 n' )
1449
+
1450
+ prompt = self .process_prompt (prompt )
1451
+ prompt_enc = self .prompt_enc (prompt )
1452
+ phoneme_enc = self .phoneme_enc (text )
1453
+
1454
+ # process pitch
1455
+
1456
+ if not exists (pitch ):
1457
+ assert exists (audio ) and audio .ndim == 2
1458
+ assert exists (self .target_sample_hz )
1459
+
1460
+ pitch = compute_pitch_pytorch (audio , self .target_sample_hz )
1461
+ pitch = rearrange (pitch , 'b n -> b 1 n' )
1462
+
1463
+ # process mel
1464
+
1465
+ if not exists (mel ):
1466
+ assert exists (audio ) and audio .ndim == 2
1467
+ mel = self .audio_to_mel (audio )
1468
+ mel = mel [..., :pitch .shape [- 1 ]]
1469
+
1470
+ mel_max_length = mel .shape [- 1 ]
1471
+
1472
+ if not exists (mel_lens ):
1473
+ mel_lens = torch .full ((batch ,), mel_max_length , device = self .device , dtype = torch .long )
1474
+
1475
+ mel_mask = rearrange (create_mask (mel_lens , mel_max_length ), 'b n -> b 1 n' )
1476
+
1477
+ # alignment
1478
+
1479
+ aln_hard , aln_soft , aln_log , aln_mas = self .aligner (phoneme_enc , text_mask , mel , mel_mask )
1480
+ duration_pred , pitch_pred = self .duration_pitch (phoneme_enc , prompt_enc )
1481
+
1482
+ pitch = average_over_durations (pitch , aln_hard )
1483
+ cond = self .expand_encodings (rearrange (phoneme_enc , 'b n d -> b d n' ), rearrange (aln_mas , 'b n c -> b 1 n c' ), pitch )
1484
+
1485
+ # pitch and duration loss
1486
+
1487
+ duration_loss = F .l1_loss (aln_hard , duration_pred )
1488
+
1489
+ pitch = rearrange (pitch , 'b 1 d -> b d' )
1490
+ pitch_loss = F .l1_loss (pitch , pitch_pred )
1491
+
1492
+ # weigh the losses
1493
+
1494
+ aux_loss = duration_loss * self .duration_loss_weight + pitch_loss + self .pitch_loss_weight
1506
1495
1507
1496
# automatically encode raw audio to residual vq with codec
1508
1497
0 commit comments