@@ -75,21 +75,23 @@ def prob_mask_like(shape, prob, device):
     else:
         return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
 
-def generate_mask_from_lengths(lengths):
-    src = lengths.int()
-    device = src.device
-    tgt_length = src.sum(dim = -1).amax().item()
+def generate_mask_from_repeats(repeats):
+    repeats = repeats.int()
+    device = repeats.device
 
-    cumsum = src.cumsum(dim = -1)
+    lengths = repeats.sum(dim = -1)
+    max_length = lengths.amax().item()
+    cumsum = repeats.cumsum(dim = -1)
     cumsum_exclusive = F.pad(cumsum, (1, -1), value = 0.)
 
-    tgt_arange = torch.arange(tgt_length, device = device)
-    tgt_arange = repeat(tgt_arange, '... j -> ... i j', i = src.shape[-1])
+    seq = torch.arange(max_length, device = device)
+    seq = repeat(seq, '... j -> ... i j', i = repeats.shape[-1])
 
     cumsum = rearrange(cumsum, '... i -> ... i 1')
     cumsum_exclusive = rearrange(cumsum_exclusive, '... i -> ... i 1')
 
-    mask = (tgt_arange < cumsum) & (tgt_arange >= cumsum_exclusive)
+    lengths = rearrange(lengths, 'b -> b 1 1')
+    mask = (seq < cumsum) & (seq >= cumsum_exclusive) & (seq < lengths)
     return mask
 
 # sinusoidal positional embeds
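For reference, below is a minimal standalone sketch of the renamed helper and the alignment mask it produces. The function body is copied from this commit; the imports, the toy repeats tensor, and the printed outputs are illustrative assumptions, not part of the change.

# standalone sketch (assumes torch and einops are installed)
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

def generate_mask_from_repeats(repeats):
    # same body as the function added in this commit
    repeats = repeats.int()
    device = repeats.device

    lengths = repeats.sum(dim = -1)
    max_length = lengths.amax().item()
    cumsum = repeats.cumsum(dim = -1)
    cumsum_exclusive = F.pad(cumsum, (1, -1), value = 0.)

    seq = torch.arange(max_length, device = device)
    seq = repeat(seq, '... j -> ... i j', i = repeats.shape[-1])

    cumsum = rearrange(cumsum, '... i -> ... i 1')
    cumsum_exclusive = rearrange(cumsum_exclusive, '... i -> ... i 1')

    lengths = rearrange(lengths, 'b -> b 1 1')
    mask = (seq < cumsum) & (seq >= cumsum_exclusive) & (seq < lengths)
    return mask

# toy input: per-phoneme frame counts for a batch of 2 sequences
repeats = torch.tensor([[2, 1, 3], [1, 2, 0]])
mask = generate_mask_from_repeats(repeats)
print(mask.shape)    # torch.Size([2, 3, 6]), i.e. (batch, phonemes, frames)
print(mask[0].int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 0, 1, 1, 1]], dtype=torch.int32)
print(mask[1].int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [0, 1, 1, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.int32)

Each row of the mask marks the frame span assigned to one phoneme, and the extra `seq < lengths` term keeps frames beyond a batch item's total duration unset, which the old length-based version did not enforce.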
@@ -1424,7 +1426,7 @@ def sample(
         duration, pitch = self.duration_pitch(phoneme_enc, prompt_enc)
         pitch = rearrange(pitch, 'b n -> b 1 n')
 
-        aln_mask = generate_mask_from_lengths(duration).float()
+        aln_mask = generate_mask_from_repeats(duration).float()
 
         cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)
 