
Commit e637ea2

Merge pull request #82 from r9y9/pytorch0.4: Support PyTorch >= v0.4

2 parents: a0f65c6 + 2b58330

File tree

14 files changed: +103 -120 lines

Note: several hunks below are whitespace-only; where a removed (-) line and an added (+) line show identical text, only trailing whitespace differs.

README.md

Lines changed: 10 additions & 9 deletions

@@ -24,7 +24,7 @@ A notebook supposed to be executed on https://colab.research.google.com is avail
 - Multi-speaker and single speaker versions of DeepVoice3
 - Audio samples and pre-trained models
 - Preprocessor for [LJSpeech (en)](https://keithito.com/LJ-Speech-Dataset/), [JSUT (jp)](https://sites.google.com/site/shinnosuketakamichi/publication/jsut) and [VCTK](http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html) datasets, as well as [carpedm20/multi-speaker-tacotron-tensorflow](https://github.com/carpedm20/multi-Speaker-tacotron-tensorflow) compatible custom dataset (in JSON format)
-- Language-dependent frontend text processor for English and Japanese
+- Language-dependent frontend text processor for English and Japanese
 
 ### Samples
 
@@ -61,6 +61,7 @@ See "Synthesize from a checkpoint" section in the README for how to generate spe
 
 - Python 3
 - CUDA >= 8.0
+- PyTorch >= v0.4.0
 - TensorFlow >= v1.3
 - [nnmnkwii](https://github.com/r9y9/nnmnkwii) >= v0.0.11
 - [MeCab](http://taku910.github.io/mecab/) (Japanese only)
@@ -104,7 +105,7 @@ python train.py --preset=presets/deepvoice3_ljspeech.json --data-root=./data/ljs
 - LJSpeech (en): https://keithito.com/LJ-Speech-Dataset/
 - VCTK (en): http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
 - JSUT (jp): https://sites.google.com/site/shinnosuketakamichi/publication/jsut
-- NIKL (ko) (**Need korean cellphone number to access it**): http://www.korean.go.kr/front/board/boardStandardView.do?board_id=4&mn_id=17&b_seq=464
+- NIKL (ko) (**Need korean cellphone number to access it**): http://www.korean.go.kr/front/board/boardStandardView.do?board_id=4&mn_id=17&b_seq=464
 
 ### 1. Preprocessing
 
@@ -147,14 +148,14 @@ python preprocess.py json_meta "./datasets/datasetA/alignment.json,./datasets/da
 
 #### 1-2. Preprocessing custom english datasets with long silence. (Based on [vctk_preprocess](vctk_preprocess/))
 
-Some dataset, especially automatically generated dataset may include long silence and undesirable leading/trailing noises, undermining the char-level seq2seq model.
+Some dataset, especially automatically generated dataset may include long silence and undesirable leading/trailing noises, undermining the char-level seq2seq model.
 (e.g. VCTK, although this is covered in vctk_preprocess)
 
 To deal with the problem, `gentle_web_align.py` will
-- **Prepare phoneme alignments for all utterances**
-- Cut silences during preprocessing
+- **Prepare phoneme alignments for all utterances**
+- Cut silences during preprocessing
 
-`gentle_web_align.py` uses [Gentle](https://github.com/lowerquality/gentle), a kaldi based speech-text alignment tool. This accesses web-served Gentle application, aligns given sound segments with transcripts and converts the result to HTK-style label files, to be processed in `preprocess.py`. Gentle can be run in Linux/Mac/Windows(via Docker).
+`gentle_web_align.py` uses [Gentle](https://github.com/lowerquality/gentle), a kaldi based speech-text alignment tool. This accesses web-served Gentle application, aligns given sound segments with transcripts and converts the result to HTK-style label files, to be processed in `preprocess.py`. Gentle can be run in Linux/Mac/Windows(via Docker).
 
 Preliminary results show that while HTK/festival/merlin-based method in `vctk_preprocess/prepare_vctk_labels.py` works better on VCTK, Gentle is more stable with audio clips with ambient noise. (e.g. movie excerpts)
 
@@ -182,7 +183,7 @@ python train.py --data-root=${data-root} --preset=<json> --hparams="parameters y
 Suppose you build a DeepVoice3-style model using LJSpeech dataset, then you can train your model by:
 
 ```
-python train.py --preset=presets/deepvoice3_ljspeech.json --data-root=./data/ljspeech/
+python train.py --preset=presets/deepvoice3_ljspeech.json --data-root=./data/ljspeech/
 ```
 
 Model checkpoints (.pth) and alignments (.png) are saved in `./checkpoints` directory per 10000 steps by default.
@@ -290,9 +291,9 @@ From my experience, it can get reasonable speech quality very quickly rather tha
 There are two important options used above:
 
 - `--restore-parts=<N>`: It specifies where to load model parameters. The differences from the option `--checkpoint=<N>` are 1) `--restore-parts=<N>` ignores all invalid parameters, while `--checkpoint=<N>` doesn't. 2) `--restore-parts=<N>` tell trainer to start from 0-step, while `--checkpoint=<N>` tell trainer to continue from last step. `--checkpoint=<N>` should be ok if you are using exactly same model and continue to train, but it would be useful if you want to customize your model architecture and take advantages of pre-trained model.
-- `--speaker-id=<N>`: It specifies what speaker of data is used for training. This should only be specified if you are using multi-speaker dataset. As for VCTK, speaker id is automatically assigned incrementally (0, 1, ..., 107) according to the `speaker_info.txt` in the dataset.
+- `--speaker-id=<N>`: It specifies what speaker of data is used for training. This should only be specified if you are using multi-speaker dataset. As for VCTK, speaker id is automatically assigned incrementally (0, 1, ..., 107) according to the `speaker_info.txt` in the dataset.
 
-If you are training multi-speaker model, speaker adaptation will only work **when `n_speakers` is identical**.
+If you are training multi-speaker model, speaker adaptation will only work **when `n_speakers` is identical**.
 
 ## Acknowledgements
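The substantive README change is the explicit PyTorch >= v0.4.0 requirement. A minimal pre-flight check, assuming you want to verify the installed version before running train.py (this snippet is hypothetical, not part of the repo):

```python
import torch

# Hypothetical sanity check: confirm the installed PyTorch satisfies the
# new ">= v0.4.0" requirement before training.
major, minor = (int(v) for v in torch.__version__.split(".")[:2])
assert (major, minor) >= (0, 4), "PyTorch >= 0.4 required, found " + torch.__version__
```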

audio.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ def load_wav(path):
 
 
 def save_wav(wav, path):
-    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+    wav = wav * 32767 / max(0.01, np.max(np.abs(wav)))
     wavfile.write(path, hparams.sample_rate, wav.astype(np.int16))
 
 
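The save_wav fix swaps in-place scaling for out-of-place scaling, so the function no longer mutates the caller's array and also works when the input is integer-typed. A small numpy illustration with hypothetical data:

```python
import numpy as np

# Hypothetical int16 waveform, the kind scipy.io.wavfile.read returns.
wav = np.array([1000, -2000, 3000], dtype=np.int16)
gain = 32767 / max(0.01, np.max(np.abs(wav)))

# wav *= gain   # in-place: raises a casting error, since the float result
#               # cannot be written back into the int16 array
wav = wav * gain  # out-of-place: new float64 array, caller's array untouched
print(wav.astype(np.int16))  # [ 10922 -21844  32767]
```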

deepvoice3_pytorch/conv.py

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ def incremental_forward(self, input):
                 self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
             # append next input
             self.input_buffer[:, -1, :] = input[:, -1, :]
-            input = self.input_buffer.clone()
+            input = self.input_buffer
             if dilation > 1:
                 input = input[:, 0::dilation, :].contiguous()
         output = F.linear(input.view(bsz, -1), weight, self.bias)
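Dropping the trailing .clone() avoids copying the whole convolution buffer on every decoding step; aliasing the buffer is safe here because the following F.linear only reads it. A minimal sketch of the same shift-and-append pattern, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

B, kw, C = 2, 3, 4                  # hypothetical batch, kernel width, channels
input_buffer = torch.zeros(B, kw, C)
weight = torch.randn(5, kw * C)     # linearized conv weight, 5 output units

for t in range(4):
    frame = torch.full((B, 1, C), float(t))
    input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()  # shift left
    input_buffer[:, -1, :] = frame[:, -1, :]                  # append newest
    x = input_buffer                # alias, no copy: F.linear only reads it
    out = F.linear(x.view(B, -1), weight)
```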

deepvoice3_pytorch/deepvoice3.py

Lines changed: 4 additions & 6 deletions

@@ -3,7 +3,6 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.autograd import Variable
 import math
 import numpy as np
 
@@ -207,9 +206,9 @@ def __init__(self, embed_dim, n_speakers, speaker_embed_dim,
 
         # Position encodings for query (decoder states) and keys (encoder states)
         self.embed_query_positions = SinusoidalEncoding(
-            max_positions, convolutions[0][0], padding_idx)
+            max_positions, convolutions[0][0])
         self.embed_keys_positions = SinusoidalEncoding(
-            max_positions, embed_dim, padding_idx)
+            max_positions, embed_dim)
         # Used for compute multiplier for positional encodings
         if n_speakers > 1:
             self.speaker_proj1 = Linear(speaker_embed_dim, 1, dropout=dropout)
@@ -393,12 +392,11 @@ def incremental_forward(self, encoder_out, text_positions, speaker_embed=None,
         num_attention_layers = sum([layer is not None for layer in self.attention])
         t = 0
         if initial_input is None:
-            initial_input = Variable(
-                keys.data.new(B, 1, self.in_dim * self.r).zero_())
+            initial_input = keys.data.new(B, 1, self.in_dim * self.r).zero_()
         current_input = initial_input
         while True:
             # frame pos start with 1.
-            frame_pos = Variable(keys.data.new(B, 1).fill_(t + 1)).long()
+            frame_pos = keys.data.new(B, 1).fill_(t + 1).long()
             w = self.query_position_rate
             if self.speaker_proj2 is not None:
                 w = w * F.sigmoid(self.speaker_proj2(speaker_embed)).view(-1)
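These hunks track the PyTorch 0.4 merge of Variable into Tensor: wrapping a tensor in Variable became a no-op, so the wrappers are simply dropped (nyanko.py below gets the identical treatment). A before/after sketch with a stand-in keys tensor and hypothetical sizes:

```python
import torch

B, in_dim, r = 2, 80, 4
keys = torch.zeros(B, 6, 128)   # stand-in for encoder output (B, T, D)

# PyTorch < 0.4 (removed by this commit):
#   initial_input = Variable(keys.data.new(B, 1, in_dim * r).zero_())
# PyTorch >= 0.4: tensors carry autograd state themselves.
initial_input = keys.data.new(B, 1, in_dim * r).zero_()
frame_pos = keys.data.new(B, 1).fill_(1).long()  # frame positions are 1-based
```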

deepvoice3_pytorch/modules.py

Lines changed: 7 additions & 7 deletions

@@ -32,24 +32,24 @@ def sinusoidal_encode(x, w):
 
 
 class SinusoidalEncoding(nn.Embedding):
-    def __init__(self, num_embeddings, embedding_dim, padding_idx=0,
+
+    def __init__(self, num_embeddings, embedding_dim,
                  *args, **kwargs):
         super(SinusoidalEncoding, self).__init__(num_embeddings, embedding_dim,
-                                                 padding_idx, *args, **kwargs)
+                                                 padding_idx=0,
+                                                 *args, **kwargs)
         self.weight.data = position_encoding_init(num_embeddings, embedding_dim,
                                                   position_rate=1.0,
                                                   sinusoidal=False)
 
     def forward(self, x, w=1.0):
         isscaler = np.isscalar(w)
-        padding_idx = self.padding_idx
-        if padding_idx is None:
-            padding_idx = -1
+        assert self.padding_idx is not None
 
         if isscaler or w.size(0) == 1:
             weight = sinusoidal_encode(self.weight, w)
             return F.embedding(
-                x, weight, padding_idx, self.max_norm,
+                x, weight, self.padding_idx, self.max_norm,
                 self.norm_type, self.scale_grad_by_freq, self.sparse)
         else:
             # TODO: cannot simply apply for batch
@@ -58,7 +58,7 @@ def forward(self, x, w=1.0):
             for batch_idx, we in enumerate(w):
                 weight = sinusoidal_encode(self.weight, we)
                 pe.append(F.embedding(
-                    x[batch_idx], weight, padding_idx, self.max_norm,
+                    x[batch_idx], weight, self.padding_idx, self.max_norm,
                     self.norm_type, self.scale_grad_by_freq, self.sparse))
             pe = torch.stack(pe)
             return pe
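SinusoidalEncoding now hard-codes padding_idx=0 in the super().__init__ call and asserts it is set, reserving embedding index 0 for padding so positions stay 1-based. A standalone demonstration of that contract with a plain nn.Embedding and hypothetical sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

table = nn.Embedding(10, 4, padding_idx=0)  # index 0 reserved for padding
assert table.padding_idx is not None        # the invariant the commit asserts

positions = torch.tensor([[1, 2, 3, 0]])    # 1-based positions, 0 = pad slot
out = F.embedding(positions, table.weight, table.padding_idx,
                  table.max_norm, table.norm_type,
                  table.scale_grad_by_freq, table.sparse)
print(out[0, -1])                           # the padding row is all zeros
```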

deepvoice3_pytorch/nyanko.py

Lines changed: 2 additions & 4 deletions

@@ -3,7 +3,6 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.autograd import Variable
 import math
 import numpy as np
 
@@ -270,12 +269,11 @@ def incremental_forward(self, encoder_out, text_positions,
 
         t = 0
         if initial_input is None:
-            initial_input = Variable(
-                keys.data.new(B, 1, self.in_dim * self.r).zero_())
+            initial_input = keys.data.new(B, 1, self.in_dim * self.r).zero_()
         current_input = initial_input
         while True:
             # frame pos start with 1.
-            frame_pos = Variable(keys.data.new(B, 1).fill_(t + 1)).long()
+            frame_pos = keys.data.new(B, 1).fill_(t + 1).long()
             frame_pos_embed = self.embed_query_positions(frame_pos)
 
             if test_inputs is not None:

hparams.py

Lines changed: 7 additions & 4 deletions

@@ -99,6 +99,7 @@
     adam_beta1=0.5,
     adam_beta2=0.9,
     adam_eps=1e-6,
+    amsgrad=False,
     initial_learning_rate=5e-4,  # 0.001,
     lr_schedule="noam_learning_rate_decay",
     lr_schedule_kwargs={},
@@ -125,14 +126,16 @@
     # Forced garbage collection probability
     # Use only when MemoryError continues in Windows (Disabled by default)
     #gc_probability = 0.001,
-
+
     # json_meta mode only
     # 0: "use all",
     # 1: "ignore only unmatched_alignment",
     # 2: "fully ignore recognition",
-    ignore_recognition_level = 2,
-    min_text=20,  # when dealing with non-dedicated speech dataset(e.g. movie excerpts), setting min_text above 15 is desirable. Can be adjusted by dataset.
-    process_only_htk_aligned = False,  # if true, data without phoneme alignment file(.lab) will be ignored
+    ignore_recognition_level=2,
+    # when dealing with non-dedicated speech dataset(e.g. movie excerpts), setting min_text above 15 is desirable. Can be adjusted by dataset.
+    min_text=20,
+    # if true, data without phoneme alignment file(.lab) will be ignored
+    process_only_htk_aligned=False,
 )
 
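The new amsgrad hyperparameter presumably feeds the Adam constructor in the training code, which this diff does not show; a sketch of that wiring under that assumption, using the hyperparameter values above and a stand-in model:

```python
import torch
from torch import nn

model = nn.Linear(16, 16)                      # stand-in model
optimizer = torch.optim.Adam(model.parameters(),
                             lr=5e-4,           # initial_learning_rate
                             betas=(0.5, 0.9),  # adam_beta1, adam_beta2
                             eps=1e-6,          # adam_eps
                             amsgrad=False)     # the new hparam
```

The amsgrad option itself was introduced in PyTorch 0.4, which is likely why it appears in this commit.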

setup.py

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ def create_readme_rst():
     install_requires=[
         "numpy",
         "scipy",
-        "torch >= 0.3.0",
+        "torch >= 0.4.0",
         "unidecode",
         "inflect",
         "librosa",

synthesis.py

Lines changed: 8 additions & 13 deletions

@@ -25,7 +25,6 @@
 import audio
 
 import torch
-from torch.autograd import Variable
 import numpy as np
 import nltk
 
@@ -36,6 +35,7 @@
 from tqdm import tqdm
 
 use_cuda = torch.cuda.is_available()
+device = torch.device("cuda" if use_cuda else "cpu")
 _frontend = None  # to be set later
 
 
@@ -46,25 +46,20 @@ def tts(model, text, p=0, speaker_id=None, fast=False):
         text (str) : Input text to be synthesized
         p (float) : Replace word to pronounciation if p > 0. Default is 0.
     """
-    if use_cuda:
-        model = model.cuda()
+    model = model.to(device)
     model.eval()
     if fast:
         model.make_generation_fast_()
 
     sequence = np.array(_frontend.text_to_sequence(text, p=p))
-    sequence = Variable(torch.from_numpy(sequence)).unsqueeze(0).long()
-    text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long()
-    text_positions = Variable(text_positions)
-    speaker_ids = None if speaker_id is None else Variable(torch.LongTensor([speaker_id]))
-    if use_cuda:
-        sequence = sequence.cuda()
-        text_positions = text_positions.cuda()
-        speaker_ids = None if speaker_ids is None else speaker_ids.cuda()
+    sequence = torch.from_numpy(sequence).unsqueeze(0).long().to(device)
+    text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long().to(device)
+    speaker_ids = None if speaker_id is None else torch.LongTensor([speaker_id]).to(device)
 
     # Greedy decoding
-    mel_outputs, linear_outputs, alignments, done = model(
-        sequence, text_positions=text_positions, speaker_ids=speaker_ids)
+    with torch.no_grad():
+        mel_outputs, linear_outputs, alignments, done = model(
+            sequence, text_positions=text_positions, speaker_ids=speaker_ids)
 
     linear_output = linear_outputs[0].cpu().data.numpy()
     spectrogram = audio._denormalize(linear_output)
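The tts() rewrite adopts the two central 0.4 idioms at once: a single torch.device drives placement via .to(device), and torch.no_grad() replaces Variable/volatile for inference so no autograd graph is built during decoding. The pattern in isolation, with a stand-in model:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(8, 8).to(device)  # stand-in for the TTS model
model.eval()

x = torch.randn(1, 8, device=device)
with torch.no_grad():                      # no graph is built during decoding
    y = model(x)
print(y.cpu().numpy().shape)               # (1, 8)
```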

tests/test_conv.py

Lines changed: 1 addition & 2 deletions

@@ -3,7 +3,6 @@
 
 import torch
 from torch import nn
-from torch.autograd import Variable
 from torch.nn import functional as F
 from deepvoice3_pytorch.conv import Conv1d
 
@@ -36,7 +35,7 @@ def __test(kernel_size, dilation, T, B, C, causual=True):
     conv_online.bias.data.zero_()
 
     # (B, C, T)
-    bct = Variable(torch.zeros(B, C, T) + torch.arange(0, T))
+    bct = torch.zeros(B, C, T) + torch.arange(0, T).float()
     output_conv = conv(bct)
 
     # Remove future time stamps
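Besides dropping Variable, the test gains an explicit .float() cast: torch.arange can yield an integer dtype from integer endpoints, and PyTorch of this era did not auto-promote mixed-dtype arithmetic, so the cast keeps the addition float-on-float. In isolation:

```python
import torch

B, C, T = 2, 3, 5
# Explicit cast so both operands are float; early PyTorch raised a type
# error when adding an integer tensor to a float tensor.
bct = torch.zeros(B, C, T) + torch.arange(0, T).float()
print(bct.dtype)  # torch.float32
```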
