Skip to content

Fix keyword spotting. #1689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 1 addition & 32 deletions .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -574,29 +574,6 @@ echo "sherpa_onnx version: $sherpa_onnx_version"
pwd
ls -lh

repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
log "Start testing ${repo}"

pushd $dir
curl -LS -O https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
rm sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
popd

repo=$dir/$repo
ls -lh $repo

python3 ./python-api-examples/keyword-spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav

rm -rf $repo

if [[ x$OS != x'windows-latest' ]]; then
echo "OS: $OS"

Expand All @@ -612,15 +589,7 @@ if [[ x$OS != x'windows-latest' ]]; then
repo=$dir/$repo
ls -lh $repo

python3 ./python-api-examples/keyword-spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav \
$repo/test_wavs/5.wav
python3 ./python-api-examples/keyword-spotter.py

python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose

Expand Down
21 changes: 21 additions & 0 deletions .github/workflows/c-api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,27 @@ jobs:
otool -L ./install/lib/libsherpa-onnx-c-api.dylib
fi

- name: Test kws (zh)
shell: bash
run: |
gcc -o kws-c-api ./c-api-examples/kws-c-api.c \
-I ./build/install/include \
-L ./build/install/lib/ \
-l sherpa-onnx-c-api \
-l onnxruntime

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2

export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH

./kws-c-api

rm ./kws-c-api
rm -rf sherpa-onnx-kws-*

- name: Test Kokoro TTS (en)
shell: bash
run: |
Expand Down
22 changes: 22 additions & 0 deletions .github/workflows/cxx-api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,28 @@ jobs:
otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib
fi

- name: Test KWS (zh)
shell: bash
run: |
g++ -std=c++17 -o kws-cxx-api ./cxx-api-examples/kws-cxx-api.cc \
-I ./build/install/include \
-L ./build/install/lib/ \
-l sherpa-onnx-cxx-api \
-l sherpa-onnx-c-api \
-l onnxruntime

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2

export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH

./kws-cxx-api

rm kws-cxx-api
rm -rf sherpa-onnx-kws-*

- name: Test Kokoro TTS (en)
shell: bash
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,24 +151,27 @@ class MainActivity : AppCompatActivity() {
stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
while (kws.isReady(stream)) {
kws.decode(stream)
}

val text = kws.getResult(stream).keyword
val text = kws.getResult(stream).keyword

var textToDisplay = lastText

var textToDisplay = lastText
if (text.isNotBlank()) {
// Remember to reset the stream right after detecting a keyword

if (text.isNotBlank()) {
if (lastText.isBlank()) {
textToDisplay = "$idx: $text"
} else {
textToDisplay = "$idx: $text\n$lastText"
kws.reset(stream)
if (lastText.isBlank()) {
textToDisplay = "$idx: $text"
} else {
textToDisplay = "$idx: $text\n$lastText"
}
lastText = "$idx: $text\n$lastText"
idx += 1
}
lastText = "$idx: $text\n$lastText"
idx += 1
}

runOnUiThread {
textView.text = textToDisplay
runOnUiThread {
textView.text = textToDisplay
}
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions c-api-examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ include_directories(${CMAKE_SOURCE_DIR})
add_executable(decode-file-c-api decode-file-c-api.c)
target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs)

add_executable(kws-c-api kws-c-api.c)
target_link_libraries(kws-c-api sherpa-onnx-c-api)

if(SHERPA_ONNX_ENABLE_TTS)
add_executable(offline-tts-c-api offline-tts-c-api.c)
target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs)
Expand Down
150 changes: 150 additions & 0 deletions c-api-examples/kws-c-api.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// c-api-examples/kws-c-api.c
//
// Copyright (c) 2025 Xiaomi Corporation
//
// This file demonstrates how to use keywords spotter with sherpa-onnx's C
// clang-format off
//
// Usage
//
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
//
// ./kws-c-api
//
// clang-format on
#include <stdio.h>
#include <stdlib.h> // exit
#include <string.h> // memset

#include "sherpa-onnx/c-api/c-api.h"

int32_t main() {
SherpaOnnxKeywordSpotterConfig config;

memset(&config, 0, sizeof(config));
config.model_config.transducer.encoder =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
"encoder-epoch-12-avg-2-chunk-16-left-64.onnx";

config.model_config.transducer.decoder =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
"decoder-epoch-12-avg-2-chunk-16-left-64.onnx";

config.model_config.transducer.joiner =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
"joiner-epoch-12-avg-2-chunk-16-left-64.onnx";

config.model_config.tokens =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";

config.model_config.provider = "cpu";
config.model_config.num_threads = 1;
config.model_config.debug = 1;

config.keywords_file =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
"test_keywords.txt";

const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&config);
if (!kws) {
fprintf(stderr, "Please check your config");
exit(-1);
}

fprintf(stderr,
"--Test pre-defined keywords from test_wavs/test_keywords.txt--\n");

const char *wav_filename =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";

float tail_paddings[8000] = {0}; // 0.5 seconds

const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
if (wave == NULL) {
fprintf(stderr, "Failed to read %s\n", wav_filename);
exit(-1);
}

const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
if (!stream) {
fprintf(stderr, "Failed to create stream\n");
exit(-1);
}

SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
wave->num_samples);

SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
sizeof(tail_paddings) / sizeof(float));
SherpaOnnxOnlineStreamInputFinished(stream);
while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
SherpaOnnxDecodeKeywordStream(kws, stream);
const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
if (r && r->json && strlen(r->keyword)) {
fprintf(stderr, "Detected keyword: %s\n", r->json);

// Remember to reset the keyword stream right after a keyword is detected
SherpaOnnxResetKeywordStream(kws, stream);
}
SherpaOnnxDestroyKeywordResult(r);
}
SherpaOnnxDestroyOnlineStream(stream);

// --------------------------------------------------------------------------

fprintf(stderr, "--Use pre-defined keywords + add a new keyword--\n");

stream = SherpaOnnxCreateKeywordStreamWithKeywords(kws, "y ǎn y uán @演员");

SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
wave->num_samples);

SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
sizeof(tail_paddings) / sizeof(float));
SherpaOnnxOnlineStreamInputFinished(stream);
while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
SherpaOnnxDecodeKeywordStream(kws, stream);
const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
if (r && r->json && strlen(r->keyword)) {
fprintf(stderr, "Detected keyword: %s\n", r->json);

// Remember to reset the keyword stream
SherpaOnnxResetKeywordStream(kws, stream);
}
SherpaOnnxDestroyKeywordResult(r);
}
SherpaOnnxDestroyOnlineStream(stream);

// --------------------------------------------------------------------------

fprintf(stderr, "--Use pre-defined keywords + add two new keywords--\n");

stream = SherpaOnnxCreateKeywordStreamWithKeywords(
kws, "y ǎn y uán @演员/zh ī m íng @知名");

SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
wave->num_samples);

SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
sizeof(tail_paddings) / sizeof(float));
SherpaOnnxOnlineStreamInputFinished(stream);
while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
SherpaOnnxDecodeKeywordStream(kws, stream);
const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
if (r && r->json && strlen(r->keyword)) {
fprintf(stderr, "Detected keyword: %s\n", r->json);

// Remember to reset the keyword stream
SherpaOnnxResetKeywordStream(kws, stream);
}
SherpaOnnxDestroyKeywordResult(r);
}
SherpaOnnxDestroyOnlineStream(stream);

SherpaOnnxFreeWave(wave);
SherpaOnnxDestroyKeywordSpotter(kws);

return 0;
}
3 changes: 3 additions & 0 deletions cxx-api-examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)

add_executable(kws-cxx-api ./kws-cxx-api.cc)
target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api)

add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc)
target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api)

Expand Down
Loading
Loading