Skip to content

Commit 68b8b88

Browse files
authored
Add Python API for punctuation models. (#762)
1 parent 329fe1a commit 68b8b88

File tree

14 files changed

+136
-6
lines changed

14 files changed

+136
-6
lines changed

.github/scripts/test-offline-punctuation.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ echo "PATH: $PATH"
1414
which $EXE
1515

1616
log "------------------------------------------------------------"
17-
log "Download model "
17+
log "Download the punctuation model "
1818
log "------------------------------------------------------------"
1919

2020
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2

.github/scripts/test-python.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@ log() {
88
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
99
}
1010

11+
log "test offline punctuation"
12+
13+
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
14+
tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
15+
rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
16+
repo=sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
17+
ls -lh $repo
18+
19+
python3 ./python-api-examples/add-punctuation.py
20+
21+
rm -rf $repo
22+
1123
log "test audio tagging"
1224

1325
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,4 @@ sr-data
9191
*xcworkspace/xcuserdata/*
9292

9393
vits-icefall-*
94+
sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12

go-api-examples/vad-asr-paraformer/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
if [ ! -f ./silero_vad.onnx ]; then
5-
curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
5+
curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
66
fi
77

88
if [ ! -f ./sherpa-onnx-paraformer-trilingual-zh-cantonese-en/model.int8.onnx ]; then

go-api-examples/vad-asr-whisper/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
if [ ! -f ./silero_vad.onnx ]; then
5-
curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
5+
curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
66
fi
77

88
if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then

go-api-examples/vad-speaker-identification/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ if [ ! -f ./sr-data/enroll/fangjun-sr-1.wav ]; then
99
fi
1010

1111
if [ ! -f ./silero_vad.onnx ]; then
12-
curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
12+
curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
1313
fi
1414

1515
go mod tidy

go-api-examples/vad-spoken-language-identification/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
if [ ! -f ./silero_vad.onnx ]; then
5-
curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
5+
curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
66
fi
77

88
if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then

go-api-examples/vad/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
if [ ! -f ./silero_vad.onnx ]; then
5-
curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
5+
curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
66
fi
77

88
go mod tidy
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
This script shows how to add punctuation to text using the sherpa-onnx Python API.
5+
6+
Please download the model from
7+
https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
8+
9+
The following is an example
10+
11+
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
12+
tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
13+
rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
14+
"""
15+
16+
from pathlib import Path
17+
18+
import sherpa_onnx
19+
20+
21+
def main():
22+
model = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx"
23+
if not Path(model).is_file():
24+
raise ValueError(f"{model} does not exist")
25+
config = sherpa_onnx.OfflinePunctuationConfig(
26+
model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model),
27+
)
28+
29+
punct = sherpa_onnx.OfflinePunctuation(config)
30+
31+
text_list = [
32+
"这是一个测试你好吗How are you我很好thank you are you ok谢谢你",
33+
"我们都是木头人不会说话不会动",
34+
"The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry",
35+
]
36+
for text in text_list:
37+
text_with_punct = punct.add_punctuation(text)
38+
print("----------")
39+
print(f"input: {text}")
40+
print(f"output: {text_with_punct}")
41+
42+
print("----------")
43+
44+
45+
if __name__ == "__main__":
46+
main()

sherpa-onnx/python/csrc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ set(srcs
1212
offline-model-config.cc
1313
offline-nemo-enc-dec-ctc-model-config.cc
1414
offline-paraformer-model-config.cc
15+
offline-punctuation.cc
1516
offline-recognizer.cc
1617
offline-stream.cc
1718
offline-tdnn-model-config.cc
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// sherpa-onnx/python/csrc/offline-punctuation.cc
2+
//
3+
// Copyright (c) 2024 Xiaomi Corporation
4+
5+
#include "sherpa-onnx/python/csrc/offline-punctuation.h"
6+
7+
#include "sherpa-onnx/csrc/offline-punctuation.h"
8+
9+
namespace sherpa_onnx {
10+
11+
static void PybindOfflinePunctuationModelConfig(py::module *m) {
12+
using PyClass = OfflinePunctuationModelConfig;
13+
py::class_<PyClass>(*m, "OfflinePunctuationModelConfig")
14+
.def(py::init<>())
15+
.def(py::init<const std::string &, int32_t, bool, const std::string &>(),
16+
py::arg("ct_transformer"), py::arg("num_threads") = 1,
17+
py::arg("debug") = false, py::arg("provider") = "cpu")
18+
.def_readwrite("ct_transformer", &PyClass::ct_transformer)
19+
.def_readwrite("num_threads", &PyClass::num_threads)
20+
.def_readwrite("debug", &PyClass::debug)
21+
.def_readwrite("provider", &PyClass::provider)
22+
.def("validate", &PyClass::Validate)
23+
.def("__str__", &PyClass::ToString);
24+
}
25+
26+
static void PybindOfflinePunctuationConfig(py::module *m) {
27+
PybindOfflinePunctuationModelConfig(m);
28+
using PyClass = OfflinePunctuationConfig;
29+
30+
py::class_<PyClass>(*m, "OfflinePunctuationConfig")
31+
.def(py::init<>())
32+
.def(py::init<const OfflinePunctuationModelConfig &>(), py::arg("model"))
33+
.def_readwrite("model", &PyClass::model)
34+
.def("validate", &PyClass::Validate)
35+
.def("__str__", &PyClass::ToString);
36+
}
37+
38+
void PybindOfflinePunctuation(py::module *m) {
39+
PybindOfflinePunctuationConfig(m);
40+
using PyClass = OfflinePunctuation;
41+
42+
py::class_<PyClass>(*m, "OfflinePunctuation")
43+
.def(py::init<const OfflinePunctuationConfig &>(), py::arg("config"),
44+
py::call_guard<py::gil_scoped_release>())
45+
.def("add_punctuation", &PyClass::AddPunctuation, py::arg("text"),
46+
py::call_guard<py::gil_scoped_release>());
47+
}
48+
49+
} // namespace sherpa_onnx
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// sherpa-onnx/python/csrc/offline-punctuation.h
2+
//
3+
// Copyright (c) 2024 Xiaomi Corporation
4+
5+
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_
6+
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_
7+
8+
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
9+
10+
namespace sherpa_onnx {
11+
12+
void PybindOfflinePunctuation(py::module *m);
13+
14+
}
15+
16+
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_

sherpa-onnx/python/csrc/sherpa-onnx.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
1515
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
1616
#include "sherpa-onnx/python/csrc/offline-model-config.h"
17+
#include "sherpa-onnx/python/csrc/offline-punctuation.h"
1718
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
1819
#include "sherpa-onnx/python/csrc/offline-stream.h"
1920
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
@@ -40,6 +41,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
4041

4142
PybindWaveWriter(&m);
4243
PybindAudioTagging(&m);
44+
PybindOfflinePunctuation(&m);
4345

4446
PybindFeatures(&m);
4547
PybindOnlineCtcFstDecoderConfig(&m);

sherpa-onnx/python/sherpa_onnx/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
AudioTaggingModelConfig,
77
CircularBuffer,
88
Display,
9+
OfflinePunctuation,
10+
OfflinePunctuationConfig,
11+
OfflinePunctuationModelConfig,
912
OfflineStream,
1013
OfflineTts,
1114
OfflineTtsConfig,

0 commit comments

Comments
 (0)