@@ -56,6 +56,9 @@ def default(val, d):
56
56
return val
57
57
return d () if callable (d ) else d
58
58
59
def divisible_by(num, den):
    """Return True when `num` is an exact multiple of `den`."""
    return not (num % den)
59
62
def identity (t , * args , ** kwargs ):
60
63
return t
61
64
@@ -94,7 +97,7 @@ def generate_mask_from_lengths(lengths):
94
97
class LearnedSinusoidalPosEmb (nn .Module ):
95
98
def __init__ (self , dim ):
96
99
super ().__init__ ()
97
- assert (dim % 2 ) == 0
100
+ assert divisible_by (dim , 2 )
98
101
half_dim = dim // 2
99
102
self .weights = nn .Parameter (torch .randn (half_dim ))
100
103
@@ -115,19 +118,37 @@ def compute_pitch_pytorch(wav, sample_rate):
115
118
116
119
#as mentioned in paper using pyworld
117
120
118
def compute_pitch_pyworld(wav, sample_rate, hop_length, pitch_fmax = 640.0):
    """Batched F0 (pitch) extraction using pyworld (dio + stonemask refinement).

    Args:
        wav: batch of raw waveforms, shape (batch, samples) — torch tensor or
            numpy array (callers assert audio.ndim == 2 before calling).
        sample_rate: audio sampling rate in Hz.
        hop_length: hop size in samples; sets pyworld's frame period so the
            F0 frame count lines up with the mel spectrogram.
        pitch_fmax: upper F0 bound passed to pyworld as f0_ceil.

    Returns:
        Per-frame F0 of shape (batch, frames); a torch tensor on the input's
        device when the input was a tensor, otherwise a numpy array.
    """
    is_tensor_input = torch.is_tensor(wav)

    if is_tensor_input:
        device = wav.device
        # pyworld operates on numpy; round-trip through CPU
        wav = wav.contiguous().cpu().numpy()

    # Pad half a hop on the time axis so the frame count aligns with the
    # spectrogram. NOTE(fix): the length check must use the sample axis
    # (shape[-1]), not len(wav) which is the batch size, and the pad must be
    # applied to the last axis only — a scalar pad width would reflect-pad
    # the batch axis as well.
    if divisible_by(wav.shape[-1], hop_length):
        wav = np.pad(wav, ((0, 0), (0, hop_length // 2)), mode = "reflect")

    wav = wav.astype(np.double)

    outs = []

    # pyworld is not batched; process one waveform at a time
    for sample in wav:
        f0, t = pw.dio(
            sample,
            fs = sample_rate,
            f0_ceil = pitch_fmax,
            frame_period = 1000 * hop_length / sample_rate,
        )

        # stonemask refines the coarse dio estimate at the same time points
        f0 = pw.stonemask(sample, f0, t, sample_rate)
        outs.append(f0)

    outs = np.stack(outs)

    if is_tensor_input:
        outs = torch.from_numpy(outs).to(device)

    return outs
131
152
132
153
def f0_to_coarse (f0 , f0_bin = 256 , f0_max = 1100.0 , f0_min = 50.0 ):
133
154
f0_mel_max = 1127 * torch .log (1 + torch .tensor (f0_max ) / 700 )
@@ -1115,6 +1136,8 @@ def __init__(
1115
1136
num_phoneme_tokens : int = 150 ,
1116
1137
pitch_emb_dim : int = 256 ,
1117
1138
pitch_emb_pp_hidden_dim : int = 512 ,
1139
+ calc_pitch_with_pyworld = True , # pyworld or kaldi from torchaudio
1140
+ mel_hop_length = 160 ,
1118
1141
audio_to_mel_kwargs : dict = dict (),
1119
1142
scale = 1. , # this will be set to < 1. for better convergence when training on higher resolution images
1120
1143
duration_loss_weight = 1. ,
@@ -1145,11 +1168,16 @@ def __init__(
1145
1168
if exists (self .target_sample_hz ):
1146
1169
audio_to_mel_kwargs .update (sampling_rate = self .target_sample_hz )
1147
1170
1171
+ self .mel_hop_length = mel_hop_length
1172
+
1148
1173
self .audio_to_mel = AudioToMel (
1149
1174
n_mels = aligner_dim_in ,
1175
+ hop_length = mel_hop_length ,
1150
1176
** audio_to_mel_kwargs
1151
1177
)
1152
1178
1179
+ self .calc_pitch_with_pyworld = calc_pitch_with_pyworld
1180
+
1153
1181
self .phoneme_enc = PhonemeEncoder (tokenizer = tokenizer , num_tokens = num_phoneme_tokens )
1154
1182
self .prompt_enc = SpeechPromptEncoder (dim_codebook = dim_codebook )
1155
1183
self .duration_pitch = DurationPitchPredictor (dim = duration_pitch_dim )
@@ -1456,21 +1484,31 @@ def forward(
1456
1484
prompt_enc = self .prompt_enc (prompt )
1457
1485
phoneme_enc = self .phoneme_enc (text )
1458
1486
1459
- # process pitch
1487
+ # process pitch with kaldi
1460
1488
1461
1489
if not exists (pitch ):
1462
1490
assert exists (audio ) and audio .ndim == 2
1463
1491
assert exists (self .target_sample_hz )
1464
1492
1465
- pitch = compute_pitch_pytorch (audio , self .target_sample_hz )
1493
+ if self .calc_pitch_with_pyworld :
1494
+ pitch = compute_pitch_pyworld (
1495
+ audio ,
1496
+ sample_rate = self .target_sample_hz ,
1497
+ hop_length = self .mel_hop_length
1498
+ )
1499
+ else :
1500
+ pitch = compute_pitch_pytorch (audio , self .target_sample_hz )
1501
+
1466
1502
pitch = rearrange (pitch , 'b n -> b 1 n' )
1467
1503
1468
1504
# process mel
1469
1505
1470
1506
if not exists (mel ):
1471
1507
assert exists (audio ) and audio .ndim == 2
1472
1508
mel = self .audio_to_mel (audio )
1473
- mel = mel [..., :pitch .shape [- 1 ]]
1509
+
1510
+ if exists (pitch ):
1511
+ mel = mel [..., :pitch .shape [- 1 ]]
1474
1512
1475
1513
mel_max_length = mel .shape [- 1 ]
1476
1514
@@ -1803,7 +1841,7 @@ def train(self):
1803
1841
if accelerator .is_main_process :
1804
1842
self .ema .update ()
1805
1843
1806
- if self .step % self .save_and_sample_every == 0 :
1844
+ if divisible_by ( self .step , self .save_and_sample_every ) :
1807
1845
milestone = self .step // self .save_and_sample_every
1808
1846
1809
1847
models = [(self .unwrapped_model , str (self .step ))]
0 commit comments