Skip to content

Support KWS + RKNN. #2190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ if(SHERPA_ONNX_ENABLE_RKNN)
./rknn/online-zipformer-ctc-model-rknn.cc
./rknn/online-zipformer-transducer-model-rknn.cc
./rknn/silero-vad-model-rknn.cc
./rknn/transducer-keyword-decoder-rknn.cc
./rknn/utils.cc
)

Expand Down
34 changes: 33 additions & 1 deletion sherpa-onnx/csrc/keyword-spotter-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

#include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h"

#if SHERPA_ONNX_ENABLE_RKNN
#include "sherpa-onnx/csrc/rknn/keyword-spotter-transducer-rknn-impl.h"
#endif

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
Expand All @@ -19,17 +23,45 @@ namespace sherpa_onnx {

std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
const KeywordSpotterConfig &config) {
if (config.model_config.provider_config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<KeywordSpotterTransducerRknnImpl>(config);
}
#else
SHERPA_ONNX_LOGE(
"Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
"want to use rknn.");
SHERPA_ONNX_EXIT(-1);
return nullptr;
#endif
}

if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<KeywordSpotterTransducerImpl>(config);
}

SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
SHERPA_ONNX_EXIT(-1);
}

template <typename Manager>
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
Manager *mgr, const KeywordSpotterConfig &config) {
if (config.model_config.provider_config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<KeywordSpotterTransducerRknnImpl>(mgr, config);
}
#else
SHERPA_ONNX_LOGE(
"Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
"want to use rknn.");
SHERPA_ONNX_EXIT(-1);
return nullptr;
#endif
}

if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<KeywordSpotterTransducerImpl>(mgr, config);
}
Expand Down
7 changes: 3 additions & 4 deletions sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@

namespace sherpa_onnx {

static KeywordResult Convert(const TransducerKeywordResult &src,
const SymbolTable &sym_table, float frame_shift_ms,
int32_t subsampling_factor,
int32_t frames_since_start) {
KeywordResult Convert(const TransducerKeywordResult &src,
const SymbolTable &sym_table, float frame_shift_ms,
int32_t subsampling_factor, int32_t frames_since_start) {
KeywordResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
config.model_config.zipformer2_ctc.model.empty()) {
SHERPA_ONNX_LOGE(
"Only Zipformer transducers and CTC models are currently supported "
"by rknn. Fallback to CPU");
"by rknn. Fallback to CPU. Make sure you pass an onnx model");
} else if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config);
} else if (!config.model_config.zipformer2_ctc.model.empty()) {
Expand Down
Loading
Loading