Skip to content

Commit deb97db

Browse files
committed
account for variable sequence lengths when generating alignment mask from durations
1 parent 4be2fd3 commit deb97db

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

naturalspeech2_pytorch/naturalspeech2_pytorch.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,21 +75,23 @@ def prob_mask_like(shape, prob, device):
7575
else:
7676
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
7777

78-
def generate_mask_from_lengths(lengths):
79-
src = lengths.int()
80-
device = src.device
81-
tgt_length = src.sum(dim = -1).amax().item()
78+
def generate_mask_from_repeats(repeats):
79+
repeats = repeats.int()
80+
device = repeats.device
8281

83-
cumsum = src.cumsum(dim = -1)
82+
lengths = repeats.sum(dim = -1)
83+
max_length = lengths.amax().item()
84+
cumsum = repeats.cumsum(dim = -1)
8485
cumsum_exclusive = F.pad(cumsum, (1, -1), value = 0.)
8586

86-
tgt_arange = torch.arange(tgt_length, device = device)
87-
tgt_arange = repeat(tgt_arange, '... j -> ... i j', i = src.shape[-1])
87+
seq = torch.arange(max_length, device = device)
88+
seq = repeat(seq, '... j -> ... i j', i = repeats.shape[-1])
8889

8990
cumsum = rearrange(cumsum, '... i -> ... i 1')
9091
cumsum_exclusive = rearrange(cumsum_exclusive, '... i -> ... i 1')
9192

92-
mask = (tgt_arange < cumsum) & (tgt_arange >= cumsum_exclusive)
93+
lengths = rearrange(lengths, 'b -> b 1 1')
94+
mask = (seq < cumsum) & (seq >= cumsum_exclusive) & (seq < lengths)
9395
return mask
9496

9597
# sinusoidal positional embeds
@@ -1424,7 +1426,7 @@ def sample(
14241426
duration, pitch = self.duration_pitch(phoneme_enc, prompt_enc)
14251427
pitch = rearrange(pitch, 'b n -> b 1 n')
14261428

1427-
aln_mask = generate_mask_from_lengths(duration).float()
1429+
aln_mask = generate_mask_from_repeats(duration).float()
14281430

14291431
cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)
14301432

naturalspeech2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.2'
1+
__version__ = '0.1.4'

0 commit comments

Comments
 (0)