Skip to content

Commit db67e00

Browse files
authored
Add HLG decoding for streaming CTC models (#731)
1 parent f8832cb commit db67e00

28 files changed

+668
-82
lines changed

.github/scripts/test-online-ctc.sh

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env bash
22

3-
set -e
3+
set -ex
44

55
log() {
66
# This function is from espnet
@@ -13,6 +13,26 @@ echo "PATH: $PATH"
1313

1414
which $EXE
1515

16+
log "------------------------------------------------------------"
17+
log "Run streaming Zipformer2 CTC HLG decoding "
18+
log "------------------------------------------------------------"
19+
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
20+
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
21+
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
22+
repo=$PWD/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
23+
ls -lh $repo
24+
echo "pwd: $PWD"
25+
26+
$EXE \
27+
--zipformer2-ctc-model=$repo/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
28+
--ctc-graph=$repo/HLG.fst \
29+
--tokens=$repo/tokens.txt \
30+
$repo/test_wavs/0.wav \
31+
$repo/test_wavs/1.wav \
32+
$repo/test_wavs/8k.wav
33+
34+
rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
35+
1636
log "------------------------------------------------------------"
1737
log "Run streaming Zipformer2 CTC "
1838
log "------------------------------------------------------------"

.github/scripts/test-python.sh

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
#!/usr/bin/env bash
22

3-
set -e
3+
set -ex
44

55
log() {
66
# This function is from espnet
77
local fname=${BASH_SOURCE[1]##*/}
88
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
99
}
1010

11+
log "test streaming zipformer2 ctc HLG decoding"
12+
13+
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
14+
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
15+
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
16+
repo=sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
17+
18+
python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
19+
--debug 1 \
20+
--tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
21+
--graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
22+
--model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
23+
./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav
24+
25+
rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
26+
27+
1128
mkdir -p /tmp/icefall-models
1229
dir=/tmp/icefall-models
1330

.github/workflows/linux.yaml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ jobs:
124124
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
125125
path: build/bin/*
126126

127+
- name: Test online CTC
128+
shell: bash
129+
run: |
130+
export PATH=$PWD/build/bin:$PATH
131+
export EXE=sherpa-onnx
132+
133+
.github/scripts/test-online-ctc.sh
134+
127135
- name: Test C API
128136
shell: bash
129137
run: |
@@ -149,13 +157,6 @@ jobs:
149157
150158
.github/scripts/test-kws.sh
151159
152-
- name: Test online CTC
153-
shell: bash
154-
run: |
155-
export PATH=$PWD/build/bin:$PATH
156-
export EXE=sherpa-onnx
157-
158-
.github/scripts/test-online-ctc.sh
159160
160161
- name: Test offline Whisper
161162
if: matrix.build_type != 'Debug'

cmake/kaldi-decoder.cmake

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
function(download_kaldi_decoder)
22
include(FetchContent)
33

4-
set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")
5-
set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")
6-
set(kaldi_decoder_HASH "SHA256=136d96c2f1f8ec44de095205f81a6ce98981cd867fe4ba840f9415a0b58fe601")
4+
set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
5+
set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
6+
set(kaldi_decoder_HASH "SHA256=f663e58aef31b33cd8086eaa09ff1383628039845f31300b5abef817d8cc2fff")
77

88
set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
99
set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_decoder)
1212
# If you don't have access to the Internet,
1313
# please pre-download kaldi-decoder
1414
set(possible_file_locations
15-
$ENV{HOME}/Downloads/kaldi-decoder-0.2.4.tar.gz
16-
${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.4.tar.gz
17-
${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.4.tar.gz
18-
/tmp/kaldi-decoder-0.2.4.tar.gz
19-
/star-fj/fangjun/download/github/kaldi-decoder-0.2.4.tar.gz
15+
$ENV{HOME}/Downloads/kaldi-decoder-0.2.5.tar.gz
16+
${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.5.tar.gz
17+
${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.5.tar.gz
18+
/tmp/kaldi-decoder-0.2.5.tar.gz
19+
/star-fj/fangjun/download/github/kaldi-decoder-0.2.5.tar.gz
2020
)
2121

2222
foreach(f IN LISTS possible_file_locations)
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#!/usr/bin/env python3
2+
3+
# This file shows how to use a streaming zipformer CTC model and an HLG
4+
# graph for decoding.
5+
#
6+
# We use the following model as an example
7+
#
8+
"""
9+
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
10+
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
11+
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
12+
13+
python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
14+
--tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
15+
--graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
16+
--model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
17+
./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav
18+
19+
"""
20+
# (The above model is from https://github.com/k2-fsa/icefall/pull/1557)
21+
22+
import argparse
23+
import time
24+
import wave
25+
from pathlib import Path
26+
from typing import List, Tuple
27+
28+
import numpy as np
29+
import sherpa_onnx
30+
31+
32+
def get_args():
33+
parser = argparse.ArgumentParser(
34+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
35+
)
36+
37+
parser.add_argument(
38+
"--tokens",
39+
type=str,
40+
required=True,
41+
help="Path to tokens.txt",
42+
)
43+
44+
parser.add_argument(
45+
"--model",
46+
type=str,
47+
required=True,
48+
help="Path to the ONNX model",
49+
)
50+
51+
parser.add_argument(
52+
"--graph",
53+
type=str,
54+
required=True,
55+
help="Path to H.fst, HL.fst, or HLG.fst",
56+
)
57+
58+
parser.add_argument(
59+
"--num-threads",
60+
type=int,
61+
default=1,
62+
help="Number of threads for neural network computation",
63+
)
64+
65+
parser.add_argument(
66+
"--provider",
67+
type=str,
68+
default="cpu",
69+
help="Valid values: cpu, cuda, coreml",
70+
)
71+
72+
parser.add_argument(
73+
"--debug",
74+
type=int,
75+
default=0,
76+
help="Valid values: 1, 0",
77+
)
78+
79+
parser.add_argument(
80+
"sound_file",
81+
type=str,
82+
help="The input sound file to decode. It must be of WAVE"
83+
"format with a single channel, and each sample has 16-bit, "
84+
"i.e., int16_t. "
85+
"The sample rate of the file can be arbitrary and does not need to "
86+
"be 16 kHz",
87+
)
88+
89+
return parser.parse_args()
90+
91+
92+
def assert_file_exists(filename: str):
93+
assert Path(filename).is_file(), (
94+
f"{filename} does not exist!\n"
95+
"Please refer to "
96+
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
97+
)
98+
99+
100+
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
101+
"""
102+
Args:
103+
wave_filename:
104+
Path to a wave file. It should be single channel and each sample should
105+
be 16-bit. Its sample rate does not need to be 16kHz.
106+
Returns:
107+
Return a tuple containing:
108+
- A 1-D array of dtype np.float32 containing the samples, which are
109+
normalized to the range [-1, 1].
110+
- sample rate of the wave file
111+
"""
112+
113+
with wave.open(wave_filename) as f:
114+
assert f.getnchannels() == 1, f.getnchannels()
115+
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
116+
num_samples = f.getnframes()
117+
samples = f.readframes(num_samples)
118+
samples_int16 = np.frombuffer(samples, dtype=np.int16)
119+
samples_float32 = samples_int16.astype(np.float32)
120+
121+
samples_float32 = samples_float32 / 32768
122+
return samples_float32, f.getframerate()
123+
124+
125+
def main():
126+
args = get_args()
127+
print(vars(args))
128+
129+
assert_file_exists(args.tokens)
130+
assert_file_exists(args.graph)
131+
assert_file_exists(args.model)
132+
133+
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
134+
tokens=args.tokens,
135+
model=args.model,
136+
num_threads=args.num_threads,
137+
provider=args.provider,
138+
sample_rate=16000,
139+
feature_dim=80,
140+
ctc_graph=args.graph,
141+
)
142+
143+
wave_filename = args.sound_file
144+
assert_file_exists(wave_filename)
145+
samples, sample_rate = read_wave(wave_filename)
146+
duration = len(samples) / sample_rate
147+
148+
print("Started")
149+
150+
start_time = time.time()
151+
s = recognizer.create_stream()
152+
s.accept_waveform(sample_rate, samples)
153+
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
154+
s.accept_waveform(sample_rate, tail_paddings)
155+
s.input_finished()
156+
while recognizer.is_ready(s):
157+
recognizer.decode_stream(s)
158+
159+
result = recognizer.get_result(s).lower()
160+
end_time = time.time()
161+
162+
elapsed_seconds = end_time - start_time
163+
rtf = elapsed_seconds / duration
164+
print(f"num_threads: {args.num_threads}")
165+
print(f"Wave duration: {duration:.3f} s")
166+
print(f"Elapsed time: {elapsed_seconds:.3f} s")
167+
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
168+
print(result)
169+
170+
171+
if __name__ == "__main__":
172+
main()

sherpa-onnx/csrc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ set(sources
5151
offline-zipformer-ctc-model-config.cc
5252
offline-zipformer-ctc-model.cc
5353
online-conformer-transducer-model.cc
54+
online-ctc-fst-decoder-config.cc
55+
online-ctc-fst-decoder.cc
5456
online-ctc-greedy-search-decoder.cc
5557
online-ctc-model.cc
5658
online-lm-config.cc

sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
#include <sstream>
88
#include <string>
99

10+
#include "sherpa-onnx/csrc/file-utils.h"
11+
#include "sherpa-onnx/csrc/macros.h"
12+
1013
namespace sherpa_onnx {
1114

1215
std::string OfflineCtcFstDecoderConfig::ToString() const {
@@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
2932
"Decoder max active states. Larger->slower; more accurate");
3033
}
3134

35+
bool OfflineCtcFstDecoderConfig::Validate() const {
36+
if (!graph.empty() && !FileExists(graph)) {
37+
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
38+
return false;
39+
}
40+
return true;
41+
}
42+
3243
} // namespace sherpa_onnx

sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig {
2424
std::string ToString() const;
2525

2626
void Register(ParseOptions *po);
27+
bool Validate() const;
2728
};
2829

2930
} // namespace sherpa_onnx

sherpa-onnx/csrc/offline-ctc-fst-decoder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace sherpa_onnx {
2020
// @param filename Path to a StdVectorFst or StdConstFst graph
2121
// @return The caller should free the returned pointer using `delete` to
2222
// avoid memory leak.
23-
static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
23+
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
2424
// read decoding network FST
2525
std::ifstream is(filename, std::ios::binary);
2626
if (!is.good()) {

sherpa-onnx/csrc/offline-recognizer.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const {
6767
return false;
6868
}
6969

70+
if (!ctc_fst_decoder_config.graph.empty() &&
71+
!ctc_fst_decoder_config.Validate()) {
72+
SHERPA_ONNX_LOGE("Errors in fst_decoder");
73+
return false;
74+
}
75+
7076
return model_config.Validate();
7177
}
7278

sherpa-onnx/csrc/online-ctc-decoder.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
66
#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
77

8+
#include <memory>
89
#include <vector>
910

11+
#include "kaldi-decoder/csrc/faster-decoder.h"
1012
#include "onnxruntime_cxx_api.h" // NOLINT
1113

1214
namespace sherpa_onnx {
1315

16+
class OnlineStream;
17+
1418
struct OnlineCtcDecoderResult {
1519
/// Number of frames after subsampling we have decoded so far
1620
int32_t frame_offset = 0;
@@ -37,7 +41,13 @@ class OnlineCtcDecoder {
3741
* @param results Input & Output parameters..
3842
*/
3943
virtual void Decode(Ort::Value log_probs,
40-
std::vector<OnlineCtcDecoderResult> *results) = 0;
44+
std::vector<OnlineCtcDecoderResult> *results,
45+
OnlineStream **ss = nullptr, int32_t n = 0) = 0;
46+
47+
virtual std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
48+
const {
49+
return nullptr;
50+
}
4151
};
4252

4353
} // namespace sherpa_onnx

0 commit comments

Comments
 (0)