@@ -67,6 +67,15 @@ def has_int_squareroot(num):

# tensor helpers

+def pad_or_curtail_to_length(t, length):
+    if t.shape[-1] == length:
+        return t
+
+    if t.shape[-1] > length:
+        return t[..., :length]
+
+    return F.pad(t, (0, length - t.shape[-1]))
+
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
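The new helper above either right-truncates or zero-pads a tensor along its last axis so it matches a target length. A minimal standalone usage sketch (the tensor shapes are illustrative only; `torch` and `torch.nn.functional as F` are imported as elsewhere in the file):

    import torch
    import torch.nn.functional as F

    def pad_or_curtail_to_length(t, length):
        # return t unchanged, curtailed, or zero-padded on the last dim
        if t.shape[-1] == length:
            return t
        if t.shape[-1] > length:
            return t[..., :length]
        return F.pad(t, (0, length - t.shape[-1]))

    cond = torch.randn(2, 512, 100)                       # (batch, dim, time)
    print(pad_or_curtail_to_length(cond, 120).shape)      # torch.Size([2, 512, 120]) - zero padded
    print(pad_or_curtail_to_length(cond, 80).shape)       # torch.Size([2, 512, 80])  - curtailed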
@@ -834,6 +843,7 @@ def __init__(
        )

        # prompt condition
+
        self.cond_drop_prob = cond_drop_prob # for classifier free guidance
        self.condition_on_prompt = condition_on_prompt
        self.to_prompt_cond = None
@@ -861,6 +871,15 @@ def __init__(
            use_flash_attn = use_flash_attn
        )

+        # aligned conditioning from aligner + duration module
+
+        self.null_cond = None
+        self.cond_to_model_dim = None
+
+        if self.condition_on_prompt:
+            self.cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)
+            self.null_cond = nn.Parameter(torch.zeros(dim, 1))
+
        # conditioning includes time and optionally prompt

        dim_cond_mult = dim_cond_mult * (2 if condition_on_prompt else 1)
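A short sketch of what the two new members do: the 1x1 `Conv1d` projects the aligned conditioning from the prompt dimension to the model dimension (a per-timestep linear projection), while `null_cond` is a learned per-channel embedding that stands in for the condition when it is dropped for classifier-free guidance. Standalone illustration with assumed sizes (`dim_prompt = 512`, `dim = 128` are not from the diff):

    import torch
    import torch.nn as nn

    dim_prompt, dim = 512, 128                           # assumed sizes, for illustration only

    cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)    # 1x1 conv == per-timestep linear projection
    null_cond = nn.Parameter(torch.zeros(dim, 1))        # learned "no condition" embedding, broadcast over time

    cond = torch.randn(2, dim_prompt, 100)               # (batch, dim_prompt, time) aligned conditioning
    cond = cond_to_model_dim(cond)                       # -> (2, 128, 100)
    print(cond.shape, null_cond.shape)                   # torch.Size([2, 128, 100]) torch.Size([128, 1])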
@@ -913,23 +932,27 @@ def forward(
        times,
        prompt = None,
        prompt_mask = None,
-       cond = None,
+       cond = None,
        cond_drop_prob = None
    ):
        b = x.shape[0]
        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

-       drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+       # prepare prompt condition
+       # prob should remove going forward

        t = self.to_time_cond(times)
        c = None

        if exists(self.to_prompt_cond):
            assert exists(prompt)
+
+           prompt_cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+
            prompt_cond = self.to_prompt_cond(prompt)

            prompt_cond = torch.where(
-               rearrange(drop_mask, 'b -> b 1'),
+               rearrange(prompt_cond_drop_mask, 'b -> b 1'),
                self.null_prompt_cond,
                prompt_cond,
            )
@@ -939,12 +962,37 @@ def forward(
            resampled_prompt_tokens = self.perceiver_resampler(prompt, mask = prompt_mask)

            c = torch.where(
-               rearrange(drop_mask, 'b -> b 1 1'),
+               rearrange(prompt_cond_drop_mask, 'b -> b 1 1'),
                self.null_prompt_tokens,
                resampled_prompt_tokens
            )

+       # rearrange to channel first
+
        x = rearrange(x, 'b n d -> b d n')
+
+       # sum aligned condition to input sequence
+
+       if exists(self.cond_to_model_dim):
+           assert exists(cond)
+           cond = self.cond_to_model_dim(cond)
+
+           cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+
+           cond = torch.where(
+               rearrange(cond_drop_mask, 'b -> b 1 1'),
+               self.null_cond,
+               cond
+           )
+
+           # for now, conform the condition to the length of the latent features
+
+           cond = pad_or_curtail_to_length(cond, x.shape[-1])
+
+           x = x + cond
+
+       # main wavenet body
+
        x = self.wavenet(x, t)
        x = rearrange(x, 'b d n -> b n d')

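The added block follows the same classifier-free-guidance pattern already used for the prompt: a Bernoulli mask decides, per batch element, whether to swap the real condition for the learned null embedding; the condition is then conformed to the latent length and summed into the input sequence. A self-contained sketch of that drop-and-add step (shapes, the drop probability, and the simplified `prob_mask_like` are assumptions for illustration):

    import torch
    from einops import rearrange

    def prob_mask_like(shape, prob, device):
        # True with probability `prob`, independently per element (edge cases omitted)
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

    b, dim, n = 2, 128, 100
    x = torch.randn(b, dim, n)              # latent features, channel first
    cond = torch.randn(b, dim, n)           # aligned condition, already projected to model dim
    null_cond = torch.zeros(dim, 1)         # stands in for the learned null embedding

    drop_mask = prob_mask_like((b,), 0.25, x.device)    # drop the condition for ~25% of the batch
    cond = torch.where(rearrange(drop_mask, 'b -> b 1 1'), null_cond, cond)
    x = x + cond                            # condition is summed into the input sequence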
@@ -1527,6 +1575,7 @@ def forward(
        duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)

        pitch = average_over_durations(pitch, aln_hard)
+
        cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)

        # pitch and duration loss
# pitch and duration loss
@@ -1536,6 +1585,7 @@ def forward(
        pitch = rearrange(pitch, 'b 1 d -> b d')
        pitch_loss = F.l1_loss(pitch, pitch_pred)
        align_loss = self.aligner_loss(aln_log, text_lens, mel_lens)
+
        # weigh the losses

        aux_loss = (duration_loss * self.duration_loss_weight) \
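The hunk ends mid-statement, but the visible lines suggest the auxiliary loss is a weighted sum of the duration, pitch, and alignment losses. A hedged, standalone sketch of that pattern (only `duration_loss_weight` appears in the diff; the other weight names and values are assumptions, not the repository's confirmed code):

    import torch

    # illustrative values only
    duration_loss, pitch_loss, align_loss = torch.tensor(0.3), torch.tensor(0.2), torch.tensor(0.5)
    duration_loss_weight, pitch_loss_weight, aligner_loss_weight = 1.0, 1.0, 1.0   # assumed names

    aux_loss = (duration_loss * duration_loss_weight) \
             + (pitch_loss * pitch_loss_weight) \
             + (align_loss * aligner_loss_weight)
    print(aux_loss)   # tensor(1.)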