From a57f4dc9a768fbd83667c0148013b9cc71724711 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 1 Apr 2025 16:36:31 +0800
Subject: [PATCH] Refactor rknn code

---
 sherpa-onnx/csrc/online-recognizer-impl.cc    |  20 ++
 .../rknn/online-zipformer-ctc-model-rknn.cc   | 123 +------
 .../online-zipformer-transducer-model-rknn.cc | 301 ++----------------
 .../csrc/rknn/silero-vad-model-rknn.cc        |  80 +----
 sherpa-onnx/csrc/rknn/utils.cc                | 129 +++++++-
 sherpa-onnx/csrc/rknn/utils.h                 |  16 +-
 6 files changed, 218 insertions(+), 451 deletions(-)

diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc
index c8f4c26959..328081c489 100644
--- a/sherpa-onnx/csrc/online-recognizer-impl.cc
+++ b/sherpa-onnx/csrc/online-recognizer-impl.cc
@@ -92,6 +92,26 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
 template <typename Manager>
 std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     Manager *mgr, const OnlineRecognizerConfig &config) {
+  if (config.model_config.provider_config.provider == "rknn") {
+#if SHERPA_ONNX_ENABLE_RKNN
+    // Currently, only zipformer v1 is supported for rknn
+    if (config.model_config.transducer.encoder.empty() &&
+        config.model_config.zipformer2_ctc.model.empty()) {
+      SHERPA_ONNX_LOGE(
+          "Only Zipformer transducers and CTC models are currently supported "
+          "by rknn. Fallback to CPU");
+    } else if (!config.model_config.transducer.encoder.empty()) {
+      return std::make_unique(mgr, config);
+    } else if (!config.model_config.zipformer2_ctc.model.empty()) {
+      return std::make_unique(mgr, config);
+    }
+#else
+    SHERPA_ONNX_LOGE(
+        "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
+        "want to use rknn. Fallback to CPU");
+#endif
+  }
+
   if (!config.model_config.transducer.encoder.empty()) {
     Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
 
diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
index 52a7b2ba62..eeb0800771 100644
--- a/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
+++ b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
@@ -42,39 +42,17 @@ class OnlineZipformerCtcModelRknn::Impl {
       Init(buf.data(), buf.size());
     }
 
-    int32_t ret = RKNN_SUCC;
-    switch (config_.num_threads) {
-      case 1:
-        ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO);
-        break;
-      case 0:
-        ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0);
-        break;
-      case -1:
-        ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1);
-        break;
-      case -2:
-        ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2);
-        break;
-      case -3:
-        ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1);
-        break;
-      case -4:
-        ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2);
-        break;
-      default:
-        SHERPA_ONNX_LOGE(
-            "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
-            "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
-            config_.num_threads);
-        break;
-    }
-    if (ret != RKNN_SUCC) {
-      SHERPA_ONNX_LOGE(
-          "Failed to select npu core to run the model (You can ignore it if "
-          "you "
-          "are not using RK3588.");
+    SetCoreMask(ctx_, config_.num_threads);
+  }
+
+  template <typename Manager>
+  Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) {
+    {
+      auto buf = ReadFile(mgr, config.zipformer2_ctc.model);
+      Init(buf.data(), buf.size());
     }
+
+    SetCoreMask(ctx_, config_.num_threads);
   }
 
   // TODO(fangjun): Support Android
@@ -209,86 +187,13 @@ class OnlineZipformerCtcModelRknn::Impl {
 
  private:
   void Init(void *model_data, size_t model_data_length) {
-    auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
-    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'",
-                           config_.zipformer2_ctc.model.c_str());
-
-    if (config_.debug) {
-      rknn_sdk_version v;
-      ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
-
-      SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
-                       v.drv_version);
-    }
-
-    rknn_input_output_num io_num;
-    ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
-    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
-
-    if (config_.debug) {
-      SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
-                       static_cast<int32_t>(io_num.n_input),
-                       static_cast<int32_t>(io_num.n_output));
-    }
-
-    input_attrs_.resize(io_num.n_input);
-    output_attrs_.resize(io_num.n_output);
-
-    int32_t i = 0;
-    for (auto &attr : input_attrs_) {
-      memset(&attr, 0, sizeof(attr));
-      attr.index = i;
-      ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
-      i += 1;
-    }
-
-    if (config_.debug) {
-      std::ostringstream os;
-      std::string sep;
-      for (auto &attr : input_attrs_) {
-        os << sep << ToString(attr);
-        sep = "\n";
-      }
-      SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
-                       os.str().c_str());
-    }
-
-    i = 0;
-    for (auto &attr : output_attrs_) {
-      memset(&attr, 0, sizeof(attr));
-      attr.index = i;
-      ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
-      i += 1;
-    }
+    InitContext(model_data, model_data_length, config_.debug, &ctx_);
 
-    if (config_.debug) {
-      std::ostringstream os;
-      std::string sep;
-      for (auto &attr : output_attrs_) {
-        os << sep << ToString(attr);
-        sep = "\n";
-      }
-      SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
-                       os.str().c_str());
-    }
+    InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);
 
-    rknn_custom_string custom_string;
-    ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
-                     sizeof(custom_string));
-    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
-    if (config_.debug) {
-      SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
-    }
-    auto meta = Parse(custom_string);
+    rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);
 
-    if (config_.debug) {
-      for (const auto &p : meta) {
-        SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
-      }
-    }
+    auto meta = Parse(custom_string, config_.debug);
 
     if (meta.count("T")) {
       T_ = atoi(meta.at("T").c_str());
diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
index 7b6d505d40..9c4f6ea6b9 100644
--- a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
+++ b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
@@ -62,65 +62,31 @@ class OnlineZipformerTransducerModelRknn::Impl {
       InitJoiner(buf.data(), buf.size());
     }
 
-    // Now select which core to run for RK3588
-    int32_t ret_encoder = RKNN_SUCC;
-    int32_t ret_decoder = RKNN_SUCC;
-    int32_t ret_joiner = RKNN_SUCC;
-    switch (config_.num_threads) {
-      case 1:
-        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_AUTO);
-        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_AUTO);
-        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_AUTO);
-        break;
-      case 0:
-        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0);
-        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0);
-        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0);
-        break;
-      case -1:
-        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_1);
-        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_1);
-        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_1);
-        break;
-      case -2:
-        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_2);
-        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_2);
-        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_2);
-        break;
-      case -3:
-        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1);
-        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1);
-        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1);
-        break;
-      case -4:
-        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1_2);
-        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1_2);
-        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1_2);
-        break;
-      default:
-        SHERPA_ONNX_LOGE(
-            "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
-            "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
-            config_.num_threads);
-        break;
-    }
-    if (ret_encoder != RKNN_SUCC) {
-      SHERPA_ONNX_LOGE(
-          "Failed to select npu core to run encoder (You can ignore it if you "
-          "are not using RK3588.");
+    SetCoreMask(encoder_ctx_, config_.num_threads);
+    SetCoreMask(decoder_ctx_, config_.num_threads);
+    SetCoreMask(joiner_ctx_, config_.num_threads);
+  }
+
+  template <typename Manager>
+  Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) {
+    {
+      auto buf = ReadFile(mgr, config.transducer.encoder);
+      InitEncoder(buf.data(), buf.size());
     }
-    if (ret_decoder != RKNN_SUCC) {
-      SHERPA_ONNX_LOGE(
-          "Failed to select npu core to run decoder (You can ignore it if you "
-          "are not using RK3588.");
+    {
+      auto buf = ReadFile(mgr, config.transducer.decoder);
+      InitDecoder(buf.data(), buf.size());
     }
-    if (ret_decoder != RKNN_SUCC) {
-      SHERPA_ONNX_LOGE(
-          "Failed to select npu core to run joiner (You can ignore it if you "
-          "are not using RK3588.");
+    {
+      auto buf = ReadFile(mgr, config.transducer.joiner);
+      InitJoiner(buf.data(), buf.size());
     }
+
+    SetCoreMask(encoder_ctx_, config_.num_threads);
+    SetCoreMask(decoder_ctx_, config_.num_threads);
+    SetCoreMask(joiner_ctx_, config_.num_threads);
   }
 
   // TODO(fangjun): Support Android
@@ -325,93 +291,15 @@ class OnlineZipformerTransducerModelRknn::Impl {
 
  private:
   void InitEncoder(void *model_data, size_t model_data_length) {
-    auto ret =
-        rknn_init(&encoder_ctx_, model_data, model_data_length, 0, nullptr);
-    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init encoder '%s'",
-                           config_.transducer.encoder.c_str());
-
-    if (config_.debug) {
-      rknn_sdk_version v;
-      ret = rknn_query(encoder_ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
-
-      SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
-                       v.drv_version);
-    }
-
-    rknn_input_output_num io_num;
-    ret = rknn_query(encoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num,
-                     sizeof(io_num));
-    SHERPA_ONNX_RKNN_CHECK(ret,
-                           "Failed to get I/O information for the encoder");
-
-    if (config_.debug) {
-      SHERPA_ONNX_LOGE("encoder: %d inputs, %d outputs",
-                       static_cast<int32_t>(io_num.n_input),
-                       static_cast<int32_t>(io_num.n_output));
-    }
-
-    encoder_input_attrs_.resize(io_num.n_input);
-    encoder_output_attrs_.resize(io_num.n_output);
-
-    int32_t i = 0;
-    for (auto &attr : encoder_input_attrs_) {
-      memset(&attr, 0, sizeof(attr));
-      attr.index = i;
-      ret =
-          rknn_query(encoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder input %d", i);
-      i += 1;
-    }
-
-    if (config_.debug) {
-      std::ostringstream os;
-      std::string sep;
-      for (auto &attr : encoder_input_attrs_) {
-        os << sep << ToString(attr);
-        sep = "\n";
-      }
-      SHERPA_ONNX_LOGE("\n----------Encoder inputs info----------\n%s",
-                       os.str().c_str());
-    }
-
-    i = 0;
-    for (auto &attr : encoder_output_attrs_) {
-      memset(&attr, 0, sizeof(attr));
-      attr.index = i;
-      ret =
-          rknn_query(encoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder output %d",
-                             i);
-      i += 1;
-    }
+    InitContext(model_data, model_data_length, config_.debug, &encoder_ctx_);
 
-    if (config_.debug) {
-      std::ostringstream os;
-      std::string sep;
-      for (auto &attr : encoder_output_attrs_) {
-        os << sep << ToString(attr);
-        sep = "\n";
-      }
-      SHERPA_ONNX_LOGE("\n----------Encoder outputs info----------\n%s",
-                       os.str().c_str());
-    }
+    InitInputOutputAttrs(encoder_ctx_, config_.debug, &encoder_input_attrs_,
+                         &encoder_output_attrs_);
 
-    rknn_custom_string custom_string;
-    ret = rknn_query(encoder_ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
-                     sizeof(custom_string));
-    SHERPA_ONNX_RKNN_CHECK(
-        ret, "Failed to read custom string from the encoder model");
-    if (config_.debug) {
-      SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
-    }
-    auto meta = Parse(custom_string);
+    rknn_custom_string custom_string =
+        GetCustomString(encoder_ctx_, config_.debug);
 
-    if (config_.debug) {
-      for (const auto &p : meta) {
-        SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
-      }
-    }
+    auto meta = Parse(custom_string, config_.debug);
 
     if (meta.count("encoder_dims")) {
       SplitStringToIntegers(meta.at("encoder_dims"), ",", false,
@@ -479,58 +367,10 @@ class OnlineZipformerTransducerModelRknn::Impl {
   }
 
   void InitDecoder(void *model_data, size_t model_data_length) {
-    auto ret =
-        rknn_init(&decoder_ctx_, model_data, model_data_length, 0, nullptr);
-    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init decoder '%s'",
-                           config_.transducer.decoder.c_str());
-
-    rknn_input_output_num io_num;
-    ret = rknn_query(decoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num,
-                     sizeof(io_num));
-    SHERPA_ONNX_RKNN_CHECK(ret,
-                           "Failed to get I/O information for the decoder");
-
-    if (io_num.n_input != 1) {
-      SHERPA_ONNX_LOGE("Expect only 1 decoder input. Given %d",
-                       static_cast<int32_t>(io_num.n_input));
-      SHERPA_ONNX_EXIT(-1);
-    }
+    InitContext(model_data, model_data_length, config_.debug, &decoder_ctx_);
 
-    if (io_num.n_output != 1) {
-      SHERPA_ONNX_LOGE("Expect only 1 decoder output. Given %d",
-                       static_cast<int32_t>(io_num.n_output));
-      SHERPA_ONNX_EXIT(-1);
-    }
-
-    if (config_.debug) {
-      SHERPA_ONNX_LOGE("decoder: %d inputs, %d outputs",
-                       static_cast<int32_t>(io_num.n_input),
-                       static_cast<int32_t>(io_num.n_output));
-    }
-
-    decoder_input_attrs_.resize(io_num.n_input);
-    decoder_output_attrs_.resize(io_num.n_output);
-
-    int32_t i = 0;
-    for (auto &attr : decoder_input_attrs_) {
-      memset(&attr, 0, sizeof(attr));
-      attr.index = i;
-      ret =
-          rknn_query(decoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder input %d", i);
-      i += 1;
-    }
-
-    if (config_.debug) {
-      std::ostringstream os;
-      std::string sep;
-      for (auto &attr : decoder_input_attrs_) {
-        os << sep << ToString(attr);
-        sep = "\n";
-      }
-      SHERPA_ONNX_LOGE("\n----------Decoder inputs info----------\n%s",
-                       os.str().c_str());
-    }
+    InitInputOutputAttrs(decoder_ctx_, config_.debug, &decoder_input_attrs_,
+                         &decoder_output_attrs_);
 
     if (decoder_input_attrs_[0].type != RKNN_TENSOR_INT64) {
       SHERPA_ONNX_LOGE("Expect int64 for decoder input. Given: %d, %s",
@@ -543,90 +383,13 @@ class OnlineZipformerTransducerModelRknn::Impl {
     if (config_.debug) {
       SHERPA_ONNX_LOGE("context_size: %d", context_size_);
     }
-
-    i = 0;
-    for (auto &attr : decoder_output_attrs_) {
-      memset(&attr, 0, sizeof(attr));
-      attr.index = i;
-      ret =
-          rknn_query(decoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder output %d",
-                             i);
-      i += 1;
-    }
-
-    if (config_.debug) {
-      std::ostringstream os;
-      std::string sep;
-      for (auto &attr : decoder_output_attrs_) {
-        os << sep << ToString(attr);
-        sep = "\n";
-      }
-      SHERPA_ONNX_LOGE("\n----------Decoder outputs info----------\n%s",
-                       os.str().c_str());
-    }
   }
 
   void InitJoiner(void *model_data, size_t model_data_length) {
-    auto ret =
-        rknn_init(&joiner_ctx_, model_data, model_data_length, 0, nullptr);
-    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init joiner '%s'",
-                           config_.transducer.joiner.c_str());
+    InitContext(model_data, model_data_length, config_.debug, &joiner_ctx_);
 
-    rknn_input_output_num io_num;
-    ret =
-        rknn_query(joiner_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
-    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the joiner");
-
-    if (config_.debug) {
-      SHERPA_ONNX_LOGE("joiner: %d inputs, %d outputs",
-                       static_cast<int32_t>(io_num.n_input),
-                       static_cast<int32_t>(io_num.n_output));
-    }
-
-    joiner_input_attrs_.resize(io_num.n_input);
-    joiner_output_attrs_.resize(io_num.n_output);
-
-    int32_t i = 0;
-    for (auto &attr : joiner_input_attrs_) {
-      memset(&attr, 0, sizeof(attr));
-      attr.index = i;
-      ret = rknn_query(joiner_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner input %d", i);
-      i += 1;
-    }
-
-    if (config_.debug) {
-      std::ostringstream os;
-      std::string sep;
-      for (auto &attr : joiner_input_attrs_) {
-        os << sep << ToString(attr);
-        sep = "\n";
-      }
-      SHERPA_ONNX_LOGE("\n----------Joiner inputs info----------\n%s",
-                       os.str().c_str());
-    }
-
-    i = 0;
-    for (auto &attr : joiner_output_attrs_) {
-      memset(&attr, 0, sizeof(attr));
-      attr.index = i;
-      ret =
-          rknn_query(joiner_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
-      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner output %d", i);
-      i += 1;
-    }
-
-    if (config_.debug) {
-      std::ostringstream os;
-      std::string sep;
-      for (auto &attr : joiner_output_attrs_) {
-        os << sep << ToString(attr);
-        sep = "\n";
-      }
-      SHERPA_ONNX_LOGE("\n----------Joiner outputs info----------\n%s",
-                       os.str().c_str());
-    }
+    InitInputOutputAttrs(joiner_ctx_, config_.debug, &joiner_input_attrs_,
+                         &joiner_output_attrs_);
 
     vocab_size_ = joiner_output_attrs_[0].dims[1];
     if (config_.debug) {
diff --git a/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc b/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc
index eb1963e965..d47d4e6e64 100644
--- a/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc
+++ b/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc
@@ -4,6 +4,7 @@
 
 #include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h"
 
+#include
 #include
 #include
 #include
@@ -39,6 +40,8 @@ class SileroVadModelRknn::Impl {
     auto buf = ReadFile(config.silero_vad.model);
     Init(buf.data(), buf.size());
 
+    SetCoreMask(ctx_, config_.num_threads);
+
     if (sample_rate_ != 16000) {
       SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
Given: %d", config.sample_rate); @@ -57,6 +60,8 @@ class SileroVadModelRknn::Impl { auto buf = ReadFile(mgr, config.silero_vad.model); Init(buf.data(), buf.size()); + SetCoreMask(ctx_, config_.num_threads); + if (sample_rate_ != 16000) { SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", config.sample_rate); @@ -172,80 +177,13 @@ class SileroVadModelRknn::Impl { private: void Init(void *model_data, size_t model_data_length) { - auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr); - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init silero vad model '%s'", - config_.silero_vad.model.c_str()); - - if (config_.debug) { - rknn_sdk_version v; - ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); - - SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, - v.drv_version); - } + InitContext(model_data, model_data_length, config_.debug, &ctx_); - rknn_input_output_num io_num; - ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model"); + InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_); - if (config_.debug) { - SHERPA_ONNX_LOGE("model: %d inputs, %d outputs", - static_cast(io_num.n_input), - static_cast(io_num.n_output)); - } - - input_attrs_.resize(io_num.n_input); - output_attrs_.resize(io_num.n_output); + rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug); - int32_t i = 0; - for (auto &attr : input_attrs_) { - memset(&attr, 0, sizeof(attr)); - attr.index = i; - ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i); - i += 1; - } - - if (config_.debug) { - std::ostringstream os; - std::string sep; - for (auto &attr : input_attrs_) { - os << sep << ToString(attr); - sep = "\n"; - } - SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s", - os.str().c_str()); - } - - i = 0; - for (auto &attr : output_attrs_) { - memset(&attr, 0, sizeof(attr)); - attr.index = i; - ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i); - i += 1; - } - - if (config_.debug) { - std::ostringstream os; - std::string sep; - for (auto &attr : output_attrs_) { - os << sep << ToString(attr); - sep = "\n"; - } - SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s", - os.str().c_str()); - } - - rknn_custom_string custom_string; - ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string, - sizeof(custom_string)); - SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model"); - if (config_.debug) { - SHERPA_ONNX_LOGE("customs string: %s", custom_string.string); - } - auto meta = Parse(custom_string); + auto meta = Parse(custom_string, config_.debug); if (config_.silero_vad.window_size != 512) { SHERPA_ONNX_LOGE("we require window_size to be 512. 
Given: %d", diff --git a/sherpa-onnx/csrc/rknn/utils.cc b/sherpa-onnx/csrc/rknn/utils.cc index 165bf09686..5f092c61ff 100644 --- a/sherpa-onnx/csrc/rknn/utils.cc +++ b/sherpa-onnx/csrc/rknn/utils.cc @@ -4,12 +4,15 @@ #include "sherpa-onnx/csrc/rknn/utils.h" +#include + #include #include #include #include #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/rknn/macros.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { @@ -52,7 +55,7 @@ std::string ToString(const rknn_tensor_attr &attr) { } std::unordered_map Parse( - const rknn_custom_string &custom_string) { + const rknn_custom_string &custom_string, bool debug /*= false*/) { std::unordered_map ans; std::vector fields; SplitStringToVector(custom_string.string, ";", false, &fields); @@ -68,7 +71,131 @@ std::unordered_map Parse( ans[std::move(tmp[0])] = std::move(tmp[1]); } + if (debug) { + for (const auto &p : ans) { + SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str()); + } + } + return ans; } +void InitContext(void *model_data, size_t model_data_length, bool debug, + rknn_context *ctx) { + auto ret = rknn_init(ctx, model_data, model_data_length, 0, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init rknn"); + + if (debug) { + rknn_sdk_version v; + ret = rknn_query(*ctx, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); + + SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, + v.drv_version); + } +} + +void InitInputOutputAttrs(rknn_context ctx, bool debug, + std::vector *input_attrs, + std::vector *output_attrs) { + rknn_input_output_num io_num; + auto ret = rknn_query(ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model"); + + if (debug) { + SHERPA_ONNX_LOGE("model: %d inputs, %d outputs", + static_cast(io_num.n_input), + static_cast(io_num.n_output)); + } + + input_attrs->resize(io_num.n_input); + output_attrs->resize(io_num.n_output); + + int32_t i = 0; + for (auto &attr : *input_attrs) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = rknn_query(ctx, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i); + i += 1; + } + + if (debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : *input_attrs) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s", + os.str().c_str()); + } + + i = 0; + for (auto &attr : *output_attrs) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = rknn_query(ctx, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i); + i += 1; + } + + if (debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : *output_attrs) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s", + os.str().c_str()); + } +} + +rknn_custom_string GetCustomString(rknn_context ctx, bool debug) { + rknn_custom_string custom_string; + auto ret = rknn_query(ctx, RKNN_QUERY_CUSTOM_STRING, &custom_string, + sizeof(custom_string)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model"); + if (debug) { + SHERPA_ONNX_LOGE("customs string: %s", custom_string.string); + } + return custom_string; +} + +void SetCoreMask(rknn_context ctx, int32_t num_threads) { + int32_t ret = RKNN_SUCC; + switch (num_threads) { + case 1: + ret = 
+      ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_AUTO);
+      break;
+    case 0:
+      ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0);
+      break;
+    case -1:
+      ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_1);
+      break;
+    case -2:
+      ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_2);
+      break;
+    case -3:
+      ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1);
+      break;
+    case -4:
+      ret = rknn_set_core_mask(ctx, RKNN_NPU_CORE_0_1_2);
+      break;
+    default:
+      SHERPA_ONNX_LOGE(
+          "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
+          "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
+          num_threads);
+      break;
+  }
+  if (ret != RKNN_SUCC) {
+    SHERPA_ONNX_LOGE(
+        "Failed to select npu core to run the model (You can ignore it if "
+        "you are not using RK3588.)");
+  }
+}
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/rknn/utils.h b/sherpa-onnx/csrc/rknn/utils.h
index 077d3f6591..4ed329bd35 100644
--- a/sherpa-onnx/csrc/rknn/utils.h
+++ b/sherpa-onnx/csrc/rknn/utils.h
@@ -7,17 +7,31 @@
 #include
 #include
+#include
 
 #include "rknn_api.h"  // NOLINT
 
 namespace sherpa_onnx {
+
 void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel,
                        int32_t height, int32_t width, float *dst);
 
 std::string ToString(const rknn_tensor_attr &attr);
 
 std::unordered_map<std::string, std::string> Parse(
-    const rknn_custom_string &custom_string);
+    const rknn_custom_string &custom_string, bool debug = false);
+
+void InitContext(void *model_data, size_t model_data_length, bool debug,
+                 rknn_context *ctx);
+
+void InitInputOutputAttrs(rknn_context ctx, bool debug,
+                          std::vector<rknn_tensor_attr> *input_attrs,
+                          std::vector<rknn_tensor_attr> *output_attrs);
+
+rknn_custom_string GetCustomString(rknn_context ctx, bool debug);
+
+void SetCoreMask(rknn_context ctx, int32_t num_threads);
 
 }  // namespace sherpa_onnx
 
 #endif  // SHERPA_ONNX_CSRC_RKNN_UTILS_H_