Skip to content

Commit 9b734c9

Browse files
authored
rewrite_conformer_evalnet (#176)
* conformer_change * update readme * update eval_net * fix ds2 bugs
1 parent adc2b94 commit 9b734c9

File tree

9 files changed

+34
-30
lines changed

9 files changed

+34
-30
lines changed

examples/conformer/README_CN.md

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,21 @@ python train.py --config_path ./conformer.yaml
8080
此样例使用 8张NPU.
8181
```shell
8282
# Distribute_training
83-
mpirun -n 8 python train.py ----config_path ./conformer.yaml
83+
mpirun -n 8 python train.py --config_path ./conformer.yaml
8484
```
85-
注意:如果脚本是由root用户执行的,必须在mpirun中添加——allow-run-as-root参数,如下所示:
85+
注意:
86+
87+
1.采用多卡训练时需确保yaml文件中的is_distributed为True,可通过更改yaml或在命令行中添加参数进行配置。
88+
89+
```shell
90+
# Distribute_training
91+
mpirun -n 8 python train.py --config_path ./conformer.yaml --is_distributed True
92+
```
93+
94+
2.如果脚本是由root用户执行的,必须在mpirun中添加--allow-run-as-root参数,如下所示:
95+
8696
```shell
87-
mpirun --allow-run-as-root -n 8 python train.py ----config_path ./conformer.yaml
97+
mpirun --allow-run-as-root -n 8 python train.py --config_path ./conformer.yaml
8898
```
8999

90100
如在GPU中进行训练,可更改yaml文件中的配置。

examples/conformer/asr_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,11 +355,14 @@ def creadte_asr_model(config, input_dim, vocab_size):
355355
class create_asr_eval_net(nn.Cell):
356356
"""Create ASR eval network."""
357357

358-
def __init__(self, network):
358+
def __init__(self, network, device_num):
359359
super(create_asr_eval_net, self).__init__()
360360
self.network = network
361-
self.device_num = 1
362-
self.all_reduce = None
361+
self.device_num = device_num
362+
if device_num > 1:
363+
self.all_reduce = ops.AllReduce()
364+
else:
365+
self.all_reduce = None
363366

364367
def construct(self, *inputs, **kwargs):
365368
loss = self.network(*inputs, **kwargs)

examples/conformer/conformer.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ is_distributed: False
102102
mixed_precision: True
103103
resume_ckpt: ""
104104
save_graphs: False
105-
training_with_eval: True
105+
training_with_eval: False
106106

107107
# decode option
108108
test_data: "/data/test.csv"

examples/conformer/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def create_dataset(
694694
group_size=group_size,
695695
)
696696

697-
sampler = DistributedSampler(dataset, rank, group_size, shuffle=True)
697+
sampler = DistributedSampler(dataset, rank, group_size, shuffle=True, group=False)
698698

699699
ds = de.GeneratorDataset(
700700
dataset,

examples/conformer/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def train():
145145
]
146146

147147
if config.training_with_eval:
148-
eval_net = create_asr_eval_net(net_with_loss)
148+
eval_net = create_asr_eval_net(net_with_loss, device_num)
149149
callback_list.append(
150150
EvalCallback(
151151
eval_net,

examples/deepspeech2/README_CN.md

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,15 @@ DeepSpeech2是一种采用CTC损失训练的语音识别模型。它用神经网
3434
如未下载数据集,可使用提供的脚本进行一键下载以及数据准备,如下所示:
3535

3636
```shell
37-
cd mindaudio/data
3837
# Download and creat json
39-
python librispeech_prepare.py --root_path "your_data_path"
38+
python mindaudio/data/librispeech.py --root_path "your_data_path"
4039
```
4140

4241
如已下载好压缩文件,请按如下命令操作:
4342

4443
```shell
4544
# create json
46-
python librispeech_prepare.py --root_path "your_data_path" --data_ready True
45+
python mindaudio/data/librispeech.py --root_path "your_data_path" --data_ready True
4746
```
4847

4948
LibriSpeech存储flac音频格式的文件。要在MindAudio中使用它们,须将所有flac文件转换为wav文件,用户可以使用[ffmpeg](https://gist.github.com/seungwonpark/4f273739beef2691cd53b5c39629d830)或[sox](https://sourceforge.net/projects/sox/)进行转换。
@@ -94,19 +93,13 @@ mpirun -n 8 python train.py -c "./deepspeech2.yaml"
9493
mpirun --allow-run-as-root -n 8 python train.py -c "./deepspeech2.yaml"
9594
```
9695

97-
#### 在GPU上进行多卡训练
98-
If you want to use the GPU for distributed training, see the following command:
99-
```shell
100-
# Distribute_training
101-
# assume you have 8 GPUs
102-
mpirun -n 8 python train.py -c "./deepspeech2.yaml" --device_target "GPU"
103-
```
10496

10597
### 3.评估模型
10698

99+
将训练好的权重路径更新到deepspeech2.yaml配置文件的Pretrained_model中,然后执行以下命令:
107100
```shell
108101
# Validate a trained model
109-
python eval.py -c "./deepspeech2.yaml" --pre_trained_model_path "xx.ckpt"
102+
python eval.py -c "./deepspeech2.yaml"
110103
```
111104

112105

examples/deepspeech2/deepspeech2.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ EvalConfig:
4545
save_output: 'librispeech_val_output'
4646

4747
# use to finetune or eval model
48-
Pretrained_model: './ckpt'
48+
Pretrained_model: ''
4949

5050
labels: ["'", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
5151
"N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", " ", "_"]

examples/deepspeech2/eval.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import mindspore.ops as ops
77
import numpy as np
88
from dataset import create_dataset
9-
from hparams import parse_args
109
from mindspore import context, nn
1110
from mindspore.train.serialization import load_checkpoint, load_param_into_net
1211

1312
from mindaudio.models.decoders.greedydecoder import MSGreedyDecoder
1413
from mindaudio.models.deepspeech2 import DeepSpeechModel
14+
from mindaudio.utils.hparams import parse_args
1515

1616

1717
class PredictWithSoftmax(nn.Cell):
@@ -73,10 +73,7 @@ def construct(self, inputs, input_length):
7373
load_param_into_net(model, param_dict)
7474
print("Successfully loading the pre-trained model")
7575

76-
if args.Decoder_type == "greedy":
77-
decoder = MSGreedyDecoder(labels=labels, blank_index=labels.index("_"))
78-
else:
79-
raise NotImplementedError("Only greedy decoder is supported now")
76+
decoder = MSGreedyDecoder(labels=labels, blank_index=labels.index("_"))
8077
target_decoder = MSGreedyDecoder(labels, blank_index=labels.index("_"))
8178

8279
model.set_train(False)
@@ -106,8 +103,7 @@ def construct(self, inputs, input_length):
106103
decoded_output, _ = decoder.decode(out, output_sizes)
107104
target_strings = target_decoder.convert_to_strings(split_targets)
108105

109-
if args.save_output is not None:
110-
output_data.append((out.asnumpy(), output_sizes.asnumpy(), target_strings))
106+
output_data.append((out.asnumpy(), output_sizes.asnumpy(), target_strings))
111107
for doutput, toutput in zip(decoded_output, target_strings):
112108
transcript, reference = doutput[0], toutput[0]
113109
wer_inst = decoder.wer(transcript, reference)

mindaudio/utils/distributed.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ class DistributedSampler:
66
For mindspore.dataset.GeneratorDataset
77
"""
88

9-
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
9+
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0, group=True):
10+
self.group = group
1011
self.rank = rank
1112
self.group_size = group_size
1213
self.dataset_len = len(dataset)
@@ -20,7 +21,8 @@ def __iter__(self):
2021
indices = np.random.permutation(self.dataset_len)
2122
else:
2223
indices = np.arange(self.dataset_len)
23-
indices = indices[self.rank :: self.group_size]
24+
if self.group:
25+
indices = indices[self.rank :: self.group_size]
2426
return iter(indices)
2527

2628
def __len__(self):

0 commit comments

Comments
 (0)