Skip to content

Commit 4be2fd3

Browse files
committed
just make a guess about pyworld and make sure it runs
1 parent fb9e1d5 commit 4be2fd3

File tree

3 files changed

+58
-19
lines changed

3 files changed

+58
-19
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,9 @@ trainer.train()
150150
- [x] complete perceiver then cross attention conditioning on ddpm side
151151
- [x] add classifier free guidance, even if not in paper
152152
- [x] complete duration / pitch prediction during training - thanks to Manmay
153+
- [x] make sure pyworld way of computing pitch can also work
153154

154-
- [ ] make sure pyworld way of computing pitch can also work
155+
- [ ] consult phd student in TTS field about pyworld usage
155156
- [ ] also offer direct summation conditioning using spear-tts text-to-semantic module, if available
156157
- [ ] add self-conditioning on ddpm side
157158
- [ ] take care of automatic slicing of audio for prompt, being aware of minimal audio segment as allowed by the codec model

naturalspeech2_pytorch/naturalspeech2_pytorch.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def default(val, d):
5656
return val
5757
return d() if callable(d) else d
5858

59+
def divisible_by(num, den):
60+
return (num % den) == 0
61+
5962
def identity(t, *args, **kwargs):
6063
return t
6164

@@ -94,7 +97,7 @@ def generate_mask_from_lengths(lengths):
9497
class LearnedSinusoidalPosEmb(nn.Module):
9598
def __init__(self, dim):
9699
super().__init__()
97-
assert (dim % 2) == 0
100+
assert divisible_by(dim, 2)
98101
half_dim = dim // 2
99102
self.weights = nn.Parameter(torch.randn(half_dim))
100103

@@ -115,19 +118,37 @@ def compute_pitch_pytorch(wav, sample_rate):
115118

116119
#as mentioned in paper using pyworld
117120

118-
def compute_pitch(spec, sample_rate, hop_length, pitch_fmax=640.0):
119-
# align F0 length to the spectrogram length
120-
if len(spec) % hop_length == 0:
121-
spec = np.pad(spec, (0, hop_length // 2), mode="reflect")
121+
def compute_pitch_pyworld(wav, sample_rate, hop_length, pitch_fmax=640.0):
122+
is_tensor_input = torch.is_tensor(wav)
122123

123-
f0, t = pw.dio(
124-
spec.astype(np.double),
125-
fs=sample_rate,
126-
f0_ceil=pitch_fmax,
127-
frame_period=1000 * hop_length / sample_rate,
128-
)
129-
f0 = pw.stonemask(spec.astype(np.double), f0, t, sample_rate)
130-
return f0
124+
if is_tensor_input:
125+
device = wav.device
126+
wav = wav.contiguous().cpu().numpy()
127+
128+
if divisible_by(len(wav), hop_length):
129+
wav = np.pad(wav, (0, hop_length // 2), mode="reflect")
130+
131+
wav = wav.astype(np.double)
132+
133+
outs = []
134+
135+
for sample in wav:
136+
f0, t = pw.dio(
137+
sample,
138+
fs = sample_rate,
139+
f0_ceil = pitch_fmax,
140+
frame_period = 1000 * hop_length / sample_rate,
141+
)
142+
143+
f0 = pw.stonemask(sample, f0, t, sample_rate)
144+
outs.append(f0)
145+
146+
outs = np.stack(outs)
147+
148+
if is_tensor_input:
149+
outs = torch.from_numpy(outs).to(device)
150+
151+
return outs
131152

132153
def f0_to_coarse(f0, f0_bin = 256, f0_max = 1100.0, f0_min = 50.0):
133154
f0_mel_max = 1127 * torch.log(1 + torch.tensor(f0_max) / 700)
@@ -1115,6 +1136,8 @@ def __init__(
11151136
num_phoneme_tokens: int = 150,
11161137
pitch_emb_dim: int = 256,
11171138
pitch_emb_pp_hidden_dim: int= 512,
1139+
calc_pitch_with_pyworld = True, # pyworld or kaldi from torchaudio
1140+
mel_hop_length = 160,
11181141
audio_to_mel_kwargs: dict = dict(),
11191142
scale = 1., # this will be set to < 1. for better convergence when training on higher resolution images
11201143
duration_loss_weight = 1.,
@@ -1145,11 +1168,16 @@ def __init__(
11451168
if exists(self.target_sample_hz):
11461169
audio_to_mel_kwargs.update(sampling_rate = self.target_sample_hz)
11471170

1171+
self.mel_hop_length = mel_hop_length
1172+
11481173
self.audio_to_mel = AudioToMel(
11491174
n_mels = aligner_dim_in,
1175+
hop_length = mel_hop_length,
11501176
**audio_to_mel_kwargs
11511177
)
11521178

1179+
self.calc_pitch_with_pyworld = calc_pitch_with_pyworld
1180+
11531181
self.phoneme_enc = PhonemeEncoder(tokenizer=tokenizer, num_tokens=num_phoneme_tokens)
11541182
self.prompt_enc = SpeechPromptEncoder(dim_codebook=dim_codebook)
11551183
self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
@@ -1456,21 +1484,31 @@ def forward(
14561484
prompt_enc = self.prompt_enc(prompt)
14571485
phoneme_enc = self.phoneme_enc(text)
14581486

1459-
# process pitch
1487+
# process pitch with kaldi
14601488

14611489
if not exists(pitch):
14621490
assert exists(audio) and audio.ndim == 2
14631491
assert exists(self.target_sample_hz)
14641492

1465-
pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
1493+
if self.calc_pitch_with_pyworld:
1494+
pitch = compute_pitch_pyworld(
1495+
audio,
1496+
sample_rate = self.target_sample_hz,
1497+
hop_length = self.mel_hop_length
1498+
)
1499+
else:
1500+
pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
1501+
14661502
pitch = rearrange(pitch, 'b n -> b 1 n')
14671503

14681504
# process mel
14691505

14701506
if not exists(mel):
14711507
assert exists(audio) and audio.ndim == 2
14721508
mel = self.audio_to_mel(audio)
1473-
mel = mel[..., :pitch.shape[-1]]
1509+
1510+
if exists(pitch):
1511+
mel = mel[..., :pitch.shape[-1]]
14741512

14751513
mel_max_length = mel.shape[-1]
14761514

@@ -1803,7 +1841,7 @@ def train(self):
18031841
if accelerator.is_main_process:
18041842
self.ema.update()
18051843

1806-
if self.step % self.save_and_sample_every == 0:
1844+
if divisible_by(self.step, self.save_and_sample_every):
18071845
milestone = self.step // self.save_and_sample_every
18081846

18091847
models = [(self.unwrapped_model, str(self.step))]

naturalspeech2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.1'
1+
__version__ = '0.1.2'

0 commit comments

Comments
 (0)