Skip to content

Commit 091e603

Browse files
committed
condition latent features with aligned conditions prior to wavenet stack
1 parent e84d0a1 commit 091e603

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

naturalspeech2_pytorch/naturalspeech2_pytorch.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ def has_int_squareroot(num):
6767

6868
# tensor helpers
6969

70+
def pad_or_curtail_to_length(t, length):
    """Return `t` with its last dimension made exactly `length` long.

    A tensor that is already the right length passes through unchanged;
    a longer one is truncated on the right; a shorter one is zero-padded
    on the right with F.pad.
    """
    current = t.shape[-1]

    if current == length:
        return t

    if current > length:
        return t[..., :length]

    return F.pad(t, (0, length - current))
78+
7079
def prob_mask_like(shape, prob, device):
7180
if prob == 1:
7281
return torch.ones(shape, device = device, dtype = torch.bool)
@@ -834,6 +843,7 @@ def __init__(
834843
)
835844

836845
# prompt condition
846+
837847
self.cond_drop_prob = cond_drop_prob # for classifier free guidance
838848
self.condition_on_prompt = condition_on_prompt
839849
self.to_prompt_cond = None
@@ -861,6 +871,15 @@ def __init__(
861871
use_flash_attn = use_flash_attn
862872
)
863873

874+
# aligned conditioning from aligner + duration module
875+
876+
self.null_cond = None
877+
self.cond_to_model_dim = None
878+
879+
if self.condition_on_prompt:
880+
self.cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)
881+
self.null_cond = nn.Parameter(torch.zeros(dim, 1))
882+
864883
# conditioning includes time and optionally prompt
865884

866885
dim_cond_mult = dim_cond_mult * (2 if condition_on_prompt else 1)
@@ -913,23 +932,27 @@ def forward(
913932
times,
914933
prompt = None,
915934
prompt_mask = None,
916-
cond= None,
935+
cond = None,
917936
cond_drop_prob = None
918937
):
919938
b = x.shape[0]
920939
cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
921940

922-
drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
941+
# prepare prompt condition
942+
# prob should remove going forward
923943

924944
t = self.to_time_cond(times)
925945
c = None
926946

927947
if exists(self.to_prompt_cond):
928948
assert exists(prompt)
949+
950+
prompt_cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
951+
929952
prompt_cond = self.to_prompt_cond(prompt)
930953

931954
prompt_cond = torch.where(
932-
rearrange(drop_mask, 'b -> b 1'),
955+
rearrange(prompt_cond_drop_mask, 'b -> b 1'),
933956
self.null_prompt_cond,
934957
prompt_cond,
935958
)
@@ -939,12 +962,37 @@ def forward(
939962
resampled_prompt_tokens = self.perceiver_resampler(prompt, mask = prompt_mask)
940963

941964
c = torch.where(
942-
rearrange(drop_mask, 'b -> b 1 1'),
965+
rearrange(prompt_cond_drop_mask, 'b -> b 1 1'),
943966
self.null_prompt_tokens,
944967
resampled_prompt_tokens
945968
)
946969

970+
# rearrange to channel first
971+
947972
x = rearrange(x, 'b n d -> b d n')
973+
974+
# sum aligned condition to input sequence
975+
976+
if exists(self.cond_to_model_dim):
977+
assert exists(cond)
978+
cond = self.cond_to_model_dim(cond)
979+
980+
cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
981+
982+
cond = torch.where(
983+
rearrange(cond_drop_mask, 'b -> b 1 1'),
984+
self.null_cond,
985+
cond
986+
)
987+
988+
# for now, conform the condition to the length of the latent features
989+
990+
cond = pad_or_curtail_to_length(cond, x.shape[-1])
991+
992+
x = x + cond
993+
994+
# main wavenet body
995+
948996
x = self.wavenet(x, t)
949997
x = rearrange(x, 'b d n -> b n d')
950998

@@ -1527,6 +1575,7 @@ def forward(
15271575
duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)
15281576

15291577
pitch = average_over_durations(pitch, aln_hard)
1578+
15301579
cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)
15311580

15321581
# pitch and duration loss
@@ -1536,6 +1585,7 @@ def forward(
15361585
pitch = rearrange(pitch, 'b 1 d -> b d')
15371586
pitch_loss = F.l1_loss(pitch, pitch_pred)
15381587
align_loss = self.aligner_loss(aln_log , text_lens, mel_lens)
1588+
15391589
# weigh the losses
15401590

15411591
aux_loss = (duration_loss * self.duration_loss_weight) \

naturalspeech2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.6'
1+
__version__ = '0.1.7'

0 commit comments

Comments (0)