Skip to content

Refactor rknn code #2079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions sherpa-onnx/csrc/online-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,26 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
template <typename Manager>
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
Manager *mgr, const OnlineRecognizerConfig &config) {
if (config.model_config.provider_config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
    // Currently, only Zipformer transducer and zipformer2 CTC models are
    // supported for rknn
if (config.model_config.transducer.encoder.empty() &&
config.model_config.zipformer2_ctc.model.empty()) {
SHERPA_ONNX_LOGE(
"Only Zipformer transducers and CTC models are currently supported "
"by rknn. Fallback to CPU");
} else if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<OnlineRecognizerTransducerRknnImpl>(mgr, config);
} else if (!config.model_config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcRknnImpl>(mgr, config);
}
#else
SHERPA_ONNX_LOGE(
"Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
"want to use rknn. Fallback to CPU");
#endif
}

if (!config.model_config.transducer.encoder.empty()) {
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);

Expand Down
123 changes: 14 additions & 109 deletions sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,39 +42,17 @@ class OnlineZipformerCtcModelRknn::Impl {
Init(buf.data(), buf.size());
}

int32_t ret = RKNN_SUCC;
switch (config_.num_threads) {
case 1:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO);
break;
case 0:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0);
break;
case -1:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1);
break;
case -2:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2);
break;
case -3:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1);
break;
case -4:
ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2);
break;
default:
SHERPA_ONNX_LOGE(
"Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
"1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
config_.num_threads);
break;
}
if (ret != RKNN_SUCC) {
SHERPA_ONNX_LOGE(
"Failed to select npu core to run the model (You can ignore it if "
"you "
"are not using RK3588.");
SetCoreMask(ctx_, config_.num_threads);
}

template <typename Manager>
Impl(Manager *mgr, const OnlineModelConfig &config) : config_(config) {
{
auto buf = ReadFile(mgr, config.zipformer2_ctc.model);
Init(buf.data(), buf.size());
}

SetCoreMask(ctx_, config_.num_threads);
}

// TODO(fangjun): Support Android
Expand Down Expand Up @@ -209,86 +187,13 @@ class OnlineZipformerCtcModelRknn::Impl {

private:
void Init(void *model_data, size_t model_data_length) {
auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'",
config_.zipformer2_ctc.model.c_str());

if (config_.debug) {
rknn_sdk_version v;
ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");

SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
v.drv_version);
}

rknn_input_output_num io_num;
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");

if (config_.debug) {
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}

input_attrs_.resize(io_num.n_input);
output_attrs_.resize(io_num.n_output);

int32_t i = 0;
for (auto &attr : input_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
i += 1;
}

if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : input_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
os.str().c_str());
}

i = 0;
for (auto &attr : output_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
i += 1;
}
InitContext(model_data, model_data_length, config_.debug, &ctx_);

if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : output_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
os.str().c_str());
}
InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);

rknn_custom_string custom_string;
ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
sizeof(custom_string));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
if (config_.debug) {
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
}
auto meta = Parse(custom_string);
rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);

if (config_.debug) {
for (const auto &p : meta) {
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
}
}
auto meta = Parse(custom_string, config_.debug);

if (meta.count("T")) {
T_ = atoi(meta.at("T").c_str());
Expand Down
Loading
Loading