From 4cae5fb7ae418ed352ace903080ea3c621589387 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 24 Mar 2025 21:47:47 +0800 Subject: [PATCH 1/2] begin to add rknn for kws --- sherpa-onnx/csrc/keyword-spotter-impl.cc | 30 +- .../keyword-spotter-transducer-rknn-impl.h | 372 ++++++++++++++++++ sherpa-onnx/csrc/online-recognizer-impl.cc | 4 +- sherpa-onnx/csrc/sherpa-onnx-microphone.cc | 2 +- 4 files changed, 404 insertions(+), 4 deletions(-) create mode 100644 sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h diff --git a/sherpa-onnx/csrc/keyword-spotter-impl.cc b/sherpa-onnx/csrc/keyword-spotter-impl.cc index affb212c05..d6006446aa 100644 --- a/sherpa-onnx/csrc/keyword-spotter-impl.cc +++ b/sherpa-onnx/csrc/keyword-spotter-impl.cc @@ -6,6 +6,10 @@ #include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h" +#if SHERPA_ONNX_ENABLE_RKNN +#include "sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h" +#endif + #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" @@ -19,17 +23,41 @@ namespace sherpa_onnx { std::unique_ptr KeywordSpotterImpl::Create( const KeywordSpotterConfig &config) { + if (config.model_config.provider_config.provider == "rknn") { +#if SHERPA_ONNX_ENABLE_RKNN + if (!config.model_config.transducer.encoder.empty()) { + return std::make_unique(config); + } +#else + SHERPA_ONNX_LOGE( + "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " + "want to use rknn. Fallback to CPU. 
Make sure you pass an onnx model."); +#endif + } + if (!config.model_config.transducer.encoder.empty()) { return std::make_unique(config); } SHERPA_ONNX_LOGE("Please specify a model"); - exit(-1); + SHERPA_ONNX_EXIT(-1); } template std::unique_ptr KeywordSpotterImpl::Create( Manager *mgr, const KeywordSpotterConfig &config) { + if (config.model_config.provider_config.provider == "rknn") { +#if SHERPA_ONNX_ENABLE_RKNN + if (!config.model_config.transducer.encoder.empty()) { + return std::make_unique(mgr, config); + } +#else + SHERPA_ONNX_LOGE( + "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " + "want to use rknn. Fallback to CPU. Make sure you pass an onnx model."); +#endif + } + if (!config.model_config.transducer.encoder.empty()) { return std::make_unique(mgr, config); } diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h new file mode 100644 index 0000000000..b82705401e --- /dev/null +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h @@ -0,0 +1,372 @@ +// sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_RKNN_IMPL_H_ +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_RKNN_IMPL_H_ + +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/keyword-spotter-impl.h" +#include "sherpa-onnx/csrc/keyword-spotter.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-transducer-model.h" +#include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" +#include "sherpa-onnx/csrc/utils.h" + +namespace sherpa_onnx { + +static KeywordResult Convert(const TransducerKeywordResult &src, + const SymbolTable &sym_table, float frame_shift_ms, + int32_t subsampling_factor, + int32_t frames_since_start) { + 
KeywordResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.tokens.size()); + r.keyword = src.keyword; + bool from_tokens = src.keyword.empty(); + + for (auto i : src.tokens) { + auto sym = sym_table[i]; + if (from_tokens) { + r.keyword.append(sym); + } + r.tokens.push_back(std::move(sym)); + } + if (from_tokens && r.keyword.size()) { + r.keyword = r.keyword.substr(1); + } + + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + r.start_time = frames_since_start * frame_shift_ms / 1000.; + + return r; +} + +class KeywordSpotterTransducerRknnImpl : public KeywordSpotterImpl { + public: + explicit KeywordSpotterTransducerRknnImpl(const KeywordSpotterConfig &config) + : config_(config), + model_(OnlineTransducerModel::Create(config.model_config)) { + if (!config.model_config.tokens_buf.empty()) { + sym_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + sym_ = SymbolTable(config.model_config.tokens, true); + } + + if (sym_.Contains("")) { + unk_id_ = sym_[""]; + } + + model_->SetFeatureDim(config.feat_config.feature_dim); + + if (config.keywords_buf.empty()) { + InitKeywords(); + } else { + InitKeywordsFromBufStr(); + } + + decoder_ = std::make_unique( + model_.get(), config_.max_active_paths, config_.num_trailing_blanks, + unk_id_); + } + + template + KeywordSpotterTransducerRknnImpl(Manager *mgr, + const KeywordSpotterConfig &config) + : config_(config), + model_(OnlineTransducerModel::Create(mgr, config.model_config)), + sym_(mgr, config.model_config.tokens) { + if (sym_.Contains("")) { + unk_id_ = sym_[""]; + } + + model_->SetFeatureDim(config.feat_config.feature_dim); + + InitKeywords(mgr); + + decoder_ = std::make_unique( + model_.get(), config_.max_active_paths, config_.num_trailing_blanks, + unk_id_); + } + + std::unique_ptr 
CreateStream() const override { + auto stream = + std::make_unique(config_.feat_config, keywords_graph_); + InitOnlineStream(stream.get()); + return stream; + } + + std::unique_ptr CreateStream( + const std::string &keywords) const override { + auto kws = std::regex_replace(keywords, std::regex("/"), "\n"); + std::istringstream is(kws); + + std::vector> current_ids; + std::vector current_kws; + std::vector current_scores; + std::vector current_thresholds; + + if (!EncodeKeywords(is, sym_, ¤t_ids, ¤t_kws, ¤t_scores, + ¤t_thresholds)) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Encode keywords %{public}s failed.", keywords.c_str()); +#else + SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str()); +#endif + return nullptr; + } + + int32_t num_kws = current_ids.size(); + int32_t num_default_kws = keywords_id_.size(); + + current_ids.insert(current_ids.end(), keywords_id_.begin(), + keywords_id_.end()); + + if (!current_kws.empty() && !keywords_.empty()) { + current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end()); + } else if (!current_kws.empty() && keywords_.empty()) { + current_kws.insert(current_kws.end(), num_default_kws, std::string()); + } else if (current_kws.empty() && !keywords_.empty()) { + current_kws.insert(current_kws.end(), num_kws, std::string()); + current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end()); + } else { + // Do nothing. + } + + if (!current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else if (!current_scores.empty() && boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_default_kws, + config_.keywords_score); + } else if (current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_kws, + config_.keywords_score); + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else { + // Do nothing. 
+ } + + if (!current_thresholds.empty() && !thresholds_.empty()) { + current_thresholds.insert(current_thresholds.end(), thresholds_.begin(), + thresholds_.end()); + } else if (!current_thresholds.empty() && thresholds_.empty()) { + current_thresholds.insert(current_thresholds.end(), num_default_kws, + config_.keywords_threshold); + } else if (current_thresholds.empty() && !thresholds_.empty()) { + current_thresholds.insert(current_thresholds.end(), num_kws, + config_.keywords_threshold); + current_thresholds.insert(current_thresholds.end(), thresholds_.begin(), + thresholds_.end()); + } else { + // Do nothing. + } + + auto keywords_graph = std::make_shared( + current_ids, config_.keywords_score, config_.keywords_threshold, + current_scores, current_kws, current_thresholds); + + auto stream = + std::make_unique(config_.feat_config, keywords_graph); + InitOnlineStream(stream.get()); + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkSize() < + s->NumFramesReady(); + } + void Reset(OnlineStream *s) const override { InitOnlineStream(s); } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + for (int32_t i = 0; i < n; ++i) { + auto s = ss[i]; + auto r = s->GetKeywordResult(true); + int32_t num_trailing_blanks = r.num_trailing_blanks; + // assume subsampling_factor is 4 + // assume frameshift is 0.01 second + float trailing_slience = num_trailing_blanks * 4 * 0.01; + + // it resets automatically after detecting 1.5 seconds of silence + float threshold = 1.5; + if (trailing_slience > threshold) { + Reset(s); + } + } + + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feature_dim = ss[0]->FeatureDim(); + + std::vector results(n); + std::vector features_vec(n * chunk_size * feature_dim); + std::vector> states_vec(n); + std::vector all_processed_frames(n); + + for (int32_t i = 0; i != n; ++i) { + 
SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr); + + const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); + std::vector features = + ss[i]->GetFrames(num_processed_frames, chunk_size); + + // Question: should num_processed_frames include chunk_shift? + ss[i]->GetNumProcessedFrames() += chunk_shift; + + std::copy(features.begin(), features.end(), + features_vec.data() + i * chunk_size * feature_dim); + + results[i] = std::move(ss[i]->GetKeywordResult()); + states_vec[i] = std::move(ss[i]->GetStates()); + all_processed_frames[i] = num_processed_frames; + } + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape{n, chunk_size, feature_dim}; + + Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), + features_vec.size(), x_shape.data(), + x_shape.size()); + + std::array processed_frames_shape{ + static_cast(all_processed_frames.size())}; + + Ort::Value processed_frames = Ort::Value::CreateTensor( + memory_info, all_processed_frames.data(), all_processed_frames.size(), + processed_frames_shape.data(), processed_frames_shape.size()); + + auto states = model_->StackStates(states_vec); + + auto pair = model_->RunEncoder(std::move(x), std::move(states), + std::move(processed_frames)); + + decoder_->Decode(std::move(pair.first), ss, &results); + + std::vector> next_states = + model_->UnStackStates(pair.second); + + for (int32_t i = 0; i != n; ++i) { + ss[i]->SetKeywordResult(results[i]); + ss[i]->SetStates(std::move(next_states[i])); + } + } + + KeywordResult GetResult(OnlineStream *s) const override { + TransducerKeywordResult decoder_result = s->GetKeywordResult(true); + + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetNumFramesSinceStart()); + } + + private: + void InitKeywords(std::istream &is) { + if 
(!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_, + &thresholds_)) { + SHERPA_ONNX_LOGE("Encode keywords failed."); + exit(-1); + } + keywords_graph_ = std::make_shared( + keywords_id_, config_.keywords_score, config_.keywords_threshold, + boost_scores_, keywords_, thresholds_); + } + + void InitKeywords() { +#ifdef SHERPA_ONNX_ENABLE_WASM_KWS + // Due to the limitations of the wasm file system, + // the keyword_file variable is directly parsed as a string of keywords + // if WASM KWS on + std::istringstream is(config_.keywords_file); + InitKeywords(is); +#else + // each line in keywords_file contains space-separated words + std::ifstream is(config_.keywords_file); + if (!is) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Open keywords file failed: %{public}s", + config_.keywords_file.c_str()); +#else + SHERPA_ONNX_LOGE("Open keywords file failed: %s", + config_.keywords_file.c_str()); +#endif + exit(-1); + } + InitKeywords(is); +#endif + } + + template + void InitKeywords(Manager *mgr) { + // each line in keywords_file contains space-separated words + + auto buf = ReadFile(mgr, config_.keywords_file); + + std::istrstream is(buf.data(), buf.size()); + + if (!is) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Open keywords file failed: %{public}s", + config_.keywords_file.c_str()); +#else + SHERPA_ONNX_LOGE("Open keywords file failed: %s", + config_.keywords_file.c_str()); +#endif + exit(-1); + } + InitKeywords(is); + } + + void InitKeywordsFromBufStr() { + // keywords_buf's content is supposed to be same as the keywords_file's + std::istringstream is(config_.keywords_buf); + InitKeywords(is); + } + + void InitOnlineStream(OnlineStream *stream) const { + auto r = decoder_->GetEmptyResult(); + SHERPA_ONNX_CHECK_EQ(r.hyps.Size(), 1); + + SHERPA_ONNX_CHECK(stream->GetContextGraph() != nullptr); + r.hyps.begin()->second.context_state = stream->GetContextGraph()->Root(); + + stream->SetKeywordResult(r); + stream->SetStates(model_->GetEncoderInitStates()); + } + + private: 
+ KeywordSpotterConfig config_; + std::vector> keywords_id_; + std::vector boost_scores_; + std::vector thresholds_; + std::vector keywords_; + ContextGraphPtr keywords_graph_; + std::unique_ptr model_; + std::unique_ptr decoder_; + SymbolTable sym_; + int32_t unk_id_ = -1; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_RKNN_IMPL_H_ diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index c8f4c26959..676bba8378 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -42,7 +42,7 @@ std::unique_ptr OnlineRecognizerImpl::Create( config.model_config.zipformer2_ctc.model.empty()) { SHERPA_ONNX_LOGE( "Only Zipformer transducers and CTC models are currently supported " - "by rknn. Fallback to CPU"); + "by rknn. Fallback to CPU. Make sure you pass an onnx model"); } else if (!config.model_config.transducer.encoder.empty()) { return std::make_unique(config); } else if (!config.model_config.zipformer2_ctc.model.empty()) { @@ -51,7 +51,7 @@ std::unique_ptr OnlineRecognizerImpl::Create( #else SHERPA_ONNX_LOGE( "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " - "want to use rknn. Fallback to CPU"); + "want to use rknn. Fallback to CPU. Make sure you pass an onnx model."); #endif } diff --git a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc index f9f8ab326f..bf1ac92e34 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-microphone.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-microphone.cc @@ -146,8 +146,8 @@ for a list of pre-trained models to download. 
param.hostApiSpecificStreamInfo = nullptr; const char *pSampleRateStr = std::getenv("SHERPA_ONNX_MIC_SAMPLE_RATE"); if (pSampleRateStr) { - fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate); mic_sample_rate = atof(pSampleRateStr); + fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate); } float sample_rate = 16000; From 37653d537d9e1654135587f1e0e1f0f9da6c7c8b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 25 Mar 2025 15:18:58 +0800 Subject: [PATCH 2/2] Add RKNN support for KWS --- sherpa-onnx/csrc/CMakeLists.txt | 1 + sherpa-onnx/csrc/keyword-spotter-impl.cc | 2 +- .../csrc/keyword-spotter-transducer-impl.h | 7 +- .../keyword-spotter-transducer-rknn-impl.h | 165 ++++++------------ ...ducer-modified-beam-search-decoder-rknn.cc | 4 +- .../rknn/transducer-keyword-decoder-rknn.cc | 149 ++++++++++++++++ .../rknn/transducer-keyword-decoder-rknn.h | 42 +++++ .../csrc/sherpa-onnx-keyword-spotter.cc | 99 ++++++++--- 8 files changed, 327 insertions(+), 142 deletions(-) rename sherpa-onnx/csrc/{ => rknn}/keyword-spotter-transducer-rknn-impl.h (64%) create mode 100644 sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.cc create mode 100644 sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.h diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index e59a7a17d7..2813df381f 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -159,6 +159,7 @@ if(SHERPA_ONNX_ENABLE_RKNN) ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc ./rknn/online-zipformer-ctc-model-rknn.cc ./rknn/online-zipformer-transducer-model-rknn.cc + ./rknn/transducer-keyword-decoder-rknn.cc ./rknn/utils.cc ) diff --git a/sherpa-onnx/csrc/keyword-spotter-impl.cc b/sherpa-onnx/csrc/keyword-spotter-impl.cc index d6006446aa..a1c577a595 100644 --- a/sherpa-onnx/csrc/keyword-spotter-impl.cc +++ b/sherpa-onnx/csrc/keyword-spotter-impl.cc @@ -7,7 +7,7 @@ #include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h" 
#if SHERPA_ONNX_ENABLE_RKNN -#include "sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h" +#include "sherpa-onnx/csrc/rknn/keyword-spotter-transducer-rknn-impl.h" #endif #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index e62b3b2c07..526d0d62f5 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -24,10 +24,9 @@ namespace sherpa_onnx { -static KeywordResult Convert(const TransducerKeywordResult &src, - const SymbolTable &sym_table, float frame_shift_ms, - int32_t subsampling_factor, - int32_t frames_since_start) { +KeywordResult Convert(const TransducerKeywordResult &src, + const SymbolTable &sym_table, float frame_shift_ms, + int32_t subsampling_factor, int32_t frames_since_start) { KeywordResult r; r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.tokens.size()); diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h b/sherpa-onnx/csrc/rknn/keyword-spotter-transducer-rknn-impl.h similarity index 64% rename from sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h rename to sherpa-onnx/csrc/rknn/keyword-spotter-transducer-rknn-impl.h index b82705401e..35bcbfa8c6 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-rknn-impl.h +++ b/sherpa-onnx/csrc/rknn/keyword-spotter-transducer-rknn-impl.h @@ -2,8 +2,8 @@ // // Copyright (c) 2025 Xiaomi Corporation -#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_RKNN_IMPL_H_ -#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_RKNN_IMPL_H_ +#ifndef SHERPA_ONNX_CSRC_RKNN_KEYWORD_SPOTTER_TRANSDUCER_RKNN_IMPL_H_ +#define SHERPA_ONNX_CSRC_RKNN_KEYWORD_SPOTTER_TRANSDUCER_RKNN_IMPL_H_ #include #include @@ -17,50 +17,24 @@ #include "sherpa-onnx/csrc/keyword-spotter-impl.h" #include "sherpa-onnx/csrc/keyword-spotter.h" #include "sherpa-onnx/csrc/macros.h" -#include "sherpa-onnx/csrc/online-transducer-model.h" +#include 
"sherpa-onnx/csrc/rknn/online-stream-rknn.h" +#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" +#include "sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.h" #include "sherpa-onnx/csrc/symbol-table.h" -#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" #include "sherpa-onnx/csrc/utils.h" namespace sherpa_onnx { -static KeywordResult Convert(const TransducerKeywordResult &src, - const SymbolTable &sym_table, float frame_shift_ms, - int32_t subsampling_factor, - int32_t frames_since_start) { - KeywordResult r; - r.tokens.reserve(src.tokens.size()); - r.timestamps.reserve(src.tokens.size()); - r.keyword = src.keyword; - bool from_tokens = src.keyword.empty(); - - for (auto i : src.tokens) { - auto sym = sym_table[i]; - if (from_tokens) { - r.keyword.append(sym); - } - r.tokens.push_back(std::move(sym)); - } - if (from_tokens && r.keyword.size()) { - r.keyword = r.keyword.substr(1); - } - - float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; - for (auto t : src.timestamps) { - float time = frame_shift_s * t; - r.timestamps.push_back(time); - } - - r.start_time = frames_since_start * frame_shift_ms / 1000.; - - return r; -} +KeywordResult Convert(const TransducerKeywordResult &src, + const SymbolTable &sym_table, float frame_shift_ms, + int32_t subsampling_factor, int32_t frames_since_start); class KeywordSpotterTransducerRknnImpl : public KeywordSpotterImpl { public: explicit KeywordSpotterTransducerRknnImpl(const KeywordSpotterConfig &config) : config_(config), - model_(OnlineTransducerModel::Create(config.model_config)) { + model_(std::make_unique( + config.model_config)) { if (!config.model_config.tokens_buf.empty()) { sym_ = SymbolTable(config.model_config.tokens_buf, false); } else { @@ -72,15 +46,13 @@ class KeywordSpotterTransducerRknnImpl : public KeywordSpotterImpl { unk_id_ = sym_[""]; } - model_->SetFeatureDim(config.feat_config.feature_dim); - if (config.keywords_buf.empty()) { InitKeywords(); } else { 
InitKeywordsFromBufStr(); } - decoder_ = std::make_unique( + decoder_ = std::make_unique( model_.get(), config_.max_active_paths, config_.num_trailing_blanks, unk_id_); } @@ -89,24 +61,24 @@ class KeywordSpotterTransducerRknnImpl : public KeywordSpotterImpl { KeywordSpotterTransducerRknnImpl(Manager *mgr, const KeywordSpotterConfig &config) : config_(config), - model_(OnlineTransducerModel::Create(mgr, config.model_config)), + model_(std::make_unique( + mgr, config.model_config)), sym_(mgr, config.model_config.tokens) { if (sym_.Contains("")) { unk_id_ = sym_[""]; } - model_->SetFeatureDim(config.feat_config.feature_dim); - InitKeywords(mgr); - decoder_ = std::make_unique( + decoder_ = std::make_unique( model_.get(), config_.max_active_paths, config_.num_trailing_blanks, unk_id_); } std::unique_ptr CreateStream() const override { - auto stream = - std::make_unique(config_.feat_config, keywords_graph_); + auto stream = std::make_unique(config_.feat_config, + keywords_graph_); + InitOnlineStream(stream.get()); return stream; } @@ -183,7 +155,7 @@ class KeywordSpotterTransducerRknnImpl : public KeywordSpotterImpl { current_scores, current_kws, current_thresholds); auto stream = - std::make_unique(config_.feat_config, keywords_graph); + std::make_unique(config_.feat_config, keywords_graph); InitOnlineStream(stream.get()); return stream; } @@ -192,81 +164,47 @@ class KeywordSpotterTransducerRknnImpl : public KeywordSpotterImpl { return s->GetNumProcessedFrames() + model_->ChunkSize() < s->NumFramesReady(); } - void Reset(OnlineStream *s) const override { InitOnlineStream(s); } - void DecodeStreams(OnlineStream **ss, int32_t n) const override { - for (int32_t i = 0; i < n; ++i) { - auto s = ss[i]; - auto r = s->GetKeywordResult(true); - int32_t num_trailing_blanks = r.num_trailing_blanks; - // assume subsampling_factor is 4 - // assume frameshift is 0.01 second - float trailing_slience = num_trailing_blanks * 4 * 0.01; - - // it resets automatically after detecting 1.5 
seconds of silence - float threshold = 1.5; - if (trailing_slience > threshold) { - Reset(s); - } + void Reset(OnlineStream *s) const override { + InitOnlineStream(reinterpret_cast(s)); + } + + void DecodeStream(OnlineStreamRknn *s) const { + auto r = s->GetKeywordResult(true); + int32_t num_trailing_blanks = r.num_trailing_blanks; + // assume subsampling_factor is 4 + // assume frameshift is 0.01 second + float trailing_slience = num_trailing_blanks * 4 * 0.01; + + // it resets automatically after detecting 1.5 seconds of silence + float threshold = 1.5; + if (trailing_slience > threshold) { + Reset(s); } int32_t chunk_size = model_->ChunkSize(); int32_t chunk_shift = model_->ChunkShift(); - int32_t feature_dim = ss[0]->FeatureDim(); + int32_t feature_dim = s->FeatureDim(); - std::vector results(n); - std::vector features_vec(n * chunk_size * feature_dim); - std::vector> states_vec(n); - std::vector all_processed_frames(n); + const auto num_processed_frames = s->GetNumProcessedFrames(); - for (int32_t i = 0; i != n; ++i) { - SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr); + std::vector features = + s->GetFrames(num_processed_frames, chunk_size); + s->GetNumProcessedFrames() += chunk_shift; - const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); - std::vector features = - ss[i]->GetFrames(num_processed_frames, chunk_size); + auto &states = s->GetZipformerEncoderStates(); - // Question: should num_processed_frames include chunk_shift? 
- ss[i]->GetNumProcessedFrames() += chunk_shift; + auto p = model_->RunEncoder(features, std::move(states)); - std::copy(features.begin(), features.end(), - features_vec.data() + i * chunk_size * feature_dim); + states = std::move(p.second); - results[i] = std::move(ss[i]->GetKeywordResult()); - states_vec[i] = std::move(ss[i]->GetStates()); - all_processed_frames[i] = num_processed_frames; - } - - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - std::array x_shape{n, chunk_size, feature_dim}; - - Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), - features_vec.size(), x_shape.data(), - x_shape.size()); - - std::array processed_frames_shape{ - static_cast(all_processed_frames.size())}; - - Ort::Value processed_frames = Ort::Value::CreateTensor( - memory_info, all_processed_frames.data(), all_processed_frames.size(), - processed_frames_shape.data(), processed_frames_shape.size()); - - auto states = model_->StackStates(states_vec); - - auto pair = model_->RunEncoder(std::move(x), std::move(states), - std::move(processed_frames)); - - decoder_->Decode(std::move(pair.first), ss, &results); - - std::vector> next_states = - model_->UnStackStates(pair.second); + decoder_->Decode(std::move(p.first), s); + } - for (int32_t i = 0; i != n; ++i) { - ss[i]->SetKeywordResult(results[i]); - ss[i]->SetStates(std::move(next_states[i])); + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + for (int32_t i = 0; i < n; ++i) { + DecodeStream(reinterpret_cast(ss[i])); } } @@ -343,7 +281,7 @@ class KeywordSpotterTransducerRknnImpl : public KeywordSpotterImpl { InitKeywords(is); } - void InitOnlineStream(OnlineStream *stream) const { + void InitOnlineStream(OnlineStreamRknn *stream) const { auto r = decoder_->GetEmptyResult(); SHERPA_ONNX_CHECK_EQ(r.hyps.Size(), 1); @@ -351,7 +289,7 @@ class KeywordSpotterTransducerRknnImpl : public KeywordSpotterImpl { r.hyps.begin()->second.context_state = 
stream->GetContextGraph()->Root(); stream->SetKeywordResult(r); - stream->SetStates(model_->GetEncoderInitStates()); + stream->SetZipformerEncoderStates(model_->GetEncoderInitStates()); } private: @@ -361,12 +299,13 @@ class KeywordSpotterTransducerRknnImpl : public KeywordSpotterImpl { std::vector thresholds_; std::vector keywords_; ContextGraphPtr keywords_graph_; - std::unique_ptr model_; - std::unique_ptr decoder_; + std::unique_ptr model_; + + std::unique_ptr decoder_; SymbolTable sym_; int32_t unk_id_ = -1; }; } // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_RKNN_IMPL_H_ +#endif // SHERPA_ONNX_CSRC_RKNN_KEYWORD_SPOTTER_TRANSDUCER_RKNN_IMPL_H_ diff --git a/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc b/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc index cb4456b27e..3d3f2cb5bc 100644 --- a/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc +++ b/sherpa-onnx/csrc/rknn/online-transducer-modified-beam-search-decoder-rknn.cc @@ -42,7 +42,7 @@ void OnlineTransducerModifiedBeamSearchDecoderRknn::StripLeadingBlanks( r->num_trailing_blanks = hyp.num_trailing_blanks; } -static std::vector> GetDecoderOut( +std::vector> GetDecoderOut( OnlineZipformerTransducerModelRknn *model, const Hypotheses &hyp_vec) { std::vector> ans; ans.reserve(hyp_vec.Size()); @@ -61,7 +61,7 @@ static std::vector> GetDecoderOut( return ans; } -static std::vector> GetJoinerOutLogSoftmax( +std::vector> GetJoinerOutLogSoftmax( OnlineZipformerTransducerModelRknn *model, const float *p_encoder_out, const std::vector> &decoder_out) { std::vector> ans; diff --git a/sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.cc b/sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.cc new file mode 100644 index 0000000000..98cac8c22a --- /dev/null +++ b/sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.cc @@ -0,0 +1,149 @@ +// sherpa-onnx/csrc/rknn/transducer-keywords-decoder-rknn.cc 
+// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.h" + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/log.h" + +namespace sherpa_onnx { + +TransducerKeywordResult TransducerKeywordDecoderRknn::GetEmptyResult() const { + int32_t context_size = model_->ContextSize(); + int32_t blank_id = 0; // always 0 + TransducerKeywordResult r; + std::vector blanks(context_size, -1); + blanks.back() = blank_id; + + Hypotheses blank_hyp({{blanks, 0}}); + r.hyps = std::move(blank_hyp); + return r; +} + +std::vector> GetDecoderOut( + OnlineZipformerTransducerModelRknn *model, const Hypotheses &hyp_vec); + +std::vector> GetJoinerOutLogSoftmax( + OnlineZipformerTransducerModelRknn *model, const float *p_encoder_out, + const std::vector> &decoder_out); + +void TransducerKeywordDecoderRknn::Decode(std::vector encoder_out, + OnlineStreamRknn *s) { + auto attr = model_->GetEncoderOutAttr(); + int32_t num_frames = attr.dims[1]; + int32_t encoder_out_dim = attr.dims[2]; + + int32_t vocab_size = model_->VocabSize(); + int32_t context_size = model_->ContextSize(); + + std::vector blanks(context_size, -1); + blanks.back() = 0; // blank_id is hardcoded to 0 + + auto r = s->GetKeywordResult(); + + Hypotheses cur = std::move(r.hyps); + std::vector prev; + + auto decoder_out = GetDecoderOut(model_, cur); + + const float *p_encoder_out = encoder_out.data(); + + int32_t frame_offset = r.frame_offset; + + for (int32_t t = 0; t != num_frames; ++t) { + prev = cur.Vec(); + cur.Clear(); + + auto log_probs = GetJoinerOutLogSoftmax(model_, p_encoder_out, decoder_out); + + auto log_probs_old = log_probs; + + p_encoder_out += encoder_out_dim; + + for (int32_t i = 0; i != prev.size(); ++i) { + auto log_prob = prev[i].log_prob; + for (auto &p : log_probs[i]) { + p += log_prob; + } + } + + auto topk = TopkIndex(log_probs, max_active_paths_); + + Hypotheses hyps; + + for (auto k : topk) { + int32_t hyp_index = k / 
vocab_size; + int32_t new_token = k % vocab_size; + + Hypothesis new_hyp = prev[hyp_index]; + float context_score = 0; + auto context_state = new_hyp.context_state; + + // blank is hardcoded to 0 + // also, it treats unk as blank + if (new_token != 0 && new_token != unk_id_) { + new_hyp.ys.push_back(new_token); + new_hyp.timestamps.push_back(t + frame_offset); + new_hyp.ys_probs.push_back(exp(log_probs_old[hyp_index][new_token])); + + new_hyp.num_trailing_blanks = 0; + auto context_res = + s->GetContextGraph()->ForwardOneStep(context_state, new_token); + context_score = std::get<0>(context_res); + new_hyp.context_state = std::get<1>(context_res); + // Start matching from the start state, forget the decoder history. + if (new_hyp.context_state->token == -1) { + new_hyp.ys = blanks; + new_hyp.timestamps.clear(); + new_hyp.ys_probs.clear(); + } + } else { + ++new_hyp.num_trailing_blanks; + } + new_hyp.log_prob = log_probs[hyp_index][new_token] + context_score; + hyps.Add(std::move(new_hyp)); + } // for (auto k : topk) + + auto best_hyp = hyps.GetMostProbable(false); + + auto status = s->GetContextGraph()->IsMatched(best_hyp.context_state); + bool matched = std::get<0>(status); + const ContextState *matched_state = std::get<1>(status); + + if (matched) { + float ys_prob = 0.0; + for (int32_t i = 0; i < matched_state->level; ++i) { + ys_prob += best_hyp.ys_probs[i]; + } + ys_prob /= matched_state->level; + if (best_hyp.num_trailing_blanks > num_trailing_blanks_ && + ys_prob >= matched_state->ac_threshold) { + r.tokens = {best_hyp.ys.end() - matched_state->level, + best_hyp.ys.end()}; + r.timestamps = {best_hyp.timestamps.end() - matched_state->level, + best_hyp.timestamps.end()}; + r.keyword = matched_state->phrase; + + hyps = Hypotheses({{blanks, 0, s->GetContextGraph()->Root()}}); + } + } + + cur = std::move(hyps); + decoder_out = GetDecoderOut(model_, cur); + } + + auto best_hyp = cur.GetMostProbable(false); + r.hyps = std::move(cur); + r.frame_offset += num_frames; 
+ r.num_trailing_blanks = best_hyp.num_trailing_blanks; + + s->SetKeywordResult(r); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.h b/sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.h new file mode 100644 index 0000000000..bac1ea6475 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.h @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/rknn/transducer-keyword-decoder-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_TRANSDUCER_KEYWORD_DECODER_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_TRANSDUCER_KEYWORD_DECODER_RKNN_H_ + +#include <string> +#include <utility> +#include <vector> + +#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h" +#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" + +namespace sherpa_onnx { + +class TransducerKeywordDecoderRknn { + public: + TransducerKeywordDecoderRknn(OnlineZipformerTransducerModelRknn *model, + int32_t max_active_paths, + int32_t num_trailing_blanks, int32_t unk_id) + : model_(model), + max_active_paths_(max_active_paths), + num_trailing_blanks_(num_trailing_blanks), + unk_id_(unk_id) {} + + TransducerKeywordResult GetEmptyResult() const; + + void Decode(std::vector<float> encoder_out, OnlineStreamRknn *s); + + private: + OnlineZipformerTransducerModelRknn *model_; // Not owned + + int32_t max_active_paths_; + int32_t num_trailing_blanks_; + int32_t unk_id_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_TRANSDUCER_KEYWORD_DECODER_RKNN_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc index 77d5cd4ef1..e52ef8ac9f 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc @@ -4,6 +4,7 @@ #include +#include <chrono> // NOLINT #include #include #include @@ -67,10 +68,9 @@ for a list of pre-trained models to download.
sherpa_onnx::KeywordSpotter keyword_spotter(config); - std::vector ss; + if (po.NumArgs() == 1) { + const std::string wav_filename = po.GetArg(1); - for (int32_t i = 1; i <= po.NumArgs(); ++i) { - const std::string wav_filename = po.GetArg(i); int32_t sampling_rate = -1; bool is_ok = false; @@ -82,6 +82,8 @@ for a list of pre-trained models to download. return -1; } + auto begin = std::chrono::steady_clock::now(); + auto s = keyword_spotter.CreateStream(); s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); @@ -90,32 +92,85 @@ for a list of pre-trained models to download. s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size()); - // Call InputFinished() to indicate that no audio samples are available s->InputFinished(); - ss.push_back({std::move(s), wav_filename}); - } - std::vector ready_streams; - for (;;) { - ready_streams.clear(); - for (auto &s : ss) { - const auto p_ss = s.online_stream.get(); - if (keyword_spotter.IsReady(p_ss)) { - ready_streams.push_back(p_ss); - } - std::ostringstream os; - const auto r = keyword_spotter.GetResult(p_ss); + while (keyword_spotter.IsReady(s.get())) { + keyword_spotter.DecodeStream(s.get()); + + auto r = keyword_spotter.GetResult(s.get()); if (!r.keyword.empty()) { - os << s.filename << "\n"; - os << r.AsJsonString() << "\n\n"; - fprintf(stderr, "%s", os.str().c_str()); + keyword_spotter.Reset(s.get()); + + fprintf(stderr, "%s\n%s\n\n", wav_filename.c_str(), + r.AsJsonString().c_str()); + } + } + + auto end = std::chrono::steady_clock::now(); + + float duration = samples.size() / static_cast(sampling_rate); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Number of threads: %d\n", config.model_config.num_threads); + fprintf(stderr, "Audio duration: %.3f s\n", duration); + fprintf(stderr, "Elapsed seconds: %.3f\n", elapsed_seconds); + fprintf(stderr, "RTF = %.3f/%.3f = %.3f\n", 
elapsed_seconds, duration, rtf); + + } else { + std::vector ss; + + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + const std::string wav_filename = po.GetArg(i); + int32_t sampling_rate = -1; + + bool is_ok = false; + const std::vector samples = + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); + + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); + return -1; } + + auto s = keyword_spotter.CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + std::vector tail_paddings(static_cast(0.8 * sampling_rate)); + // Note: We can call AcceptWaveform() multiple times. + s->AcceptWaveform(sampling_rate, tail_paddings.data(), + tail_paddings.size()); + + // Call InputFinished() to indicate that no audio samples are available + s->InputFinished(); + ss.push_back({std::move(s), wav_filename}); } - if (ready_streams.empty()) { - break; + std::vector ready_streams; + for (;;) { + ready_streams.clear(); + for (auto &s : ss) { + const auto p_ss = s.online_stream.get(); + if (keyword_spotter.IsReady(p_ss)) { + ready_streams.push_back(p_ss); + } + std::ostringstream os; + const auto r = keyword_spotter.GetResult(p_ss); + if (!r.keyword.empty()) { + os << s.filename << "\n"; + os << r.AsJsonString() << "\n\n"; + fprintf(stderr, "%s", os.str().c_str()); + } + } + + if (ready_streams.empty()) { + break; + } + keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size()); } - keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size()); } return 0; }