Skip to content

Add C++ runtime for *streaming* faster conformer transducer from NeMo. #889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ set(sources
online-transducer-model-config.cc
online-transducer-model.cc
online-transducer-modified-beam-search-decoder.cc
online-transducer-nemo-model.cc
online-transducer-greedy-search-nemo-decoder.cc
online-wenet-ctc-model-config.cc
online-wenet-ctc-model.cc
online-zipformer-transducer-model.cc
Expand Down
29 changes: 27 additions & 2 deletions sherpa-onnx/csrc/online-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,27 @@
#include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
const OnlineRecognizerConfig &config) {

if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

auto decoder_model = ReadFile(config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
} else {
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
}
}

if (!config.model_config.paraformer.encoder.empty()) {
Expand All @@ -34,7 +48,18 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
AAssetManager *mgr, const OnlineRecognizerConfig &config) {
if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

auto decoder_model = ReadFile(config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
} else {
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(mgr, config);
}
}

if (!config.model_config.paraformer.encoder.empty()) {
Expand Down
203 changes: 203 additions & 0 deletions sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
//
// Copyright (c) 2022-2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_

#include <fstream>
#include <ios>
#include <memory>
#include <regex> // NOLINT
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/utils.h"

namespace sherpa_onnx {

// defined in ./online-recognizer-transducer-impl.h
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t frames_since_start);

class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
public:
/// Builds a recognizer for streaming NeMo transducer models.
///
/// @param config Recognizer configuration; only
///               decoding_method == "greedy_search" is supported.
explicit OnlineRecognizerTransducerNeMoImpl(
    const OnlineRecognizerConfig &config)
    : config_(config),
      symbol_table_(config_.model_config.tokens),
      model_(std::make_unique<OnlineTransducerNeMoModel>(
          config_.model_config)) {
  // Guard clause: NeMo transducer currently implements greedy search only.
  if (config_.decoding_method != "greedy_search") {
    SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                     config_.decoding_method.c_str());
    exit(-1);
  }

  decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
      model_.get(), config_.blank_penalty);

  PostInit();
}

#if __ANDROID_API__ >= 9
/// Android variant: loads tokens and model files through the asset manager.
///
/// @param mgr    Android asset manager used to read bundled model files.
/// @param config Recognizer configuration; only
///               decoding_method == "greedy_search" is supported.
explicit OnlineRecognizerTransducerNeMoImpl(
    AAssetManager *mgr, const OnlineRecognizerConfig &config)
    : config_(config),
      symbol_table_(mgr, config_.model_config.tokens),
      model_(std::make_unique<OnlineTransducerNeMoModel>(
          mgr, config_.model_config)) {
  // Guard clause: NeMo transducer currently implements greedy search only.
  if (config_.decoding_method != "greedy_search") {
    SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                     config_.decoding_method.c_str());
    exit(-1);
  }

  decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
      model_.get(), config_.blank_penalty);

  PostInit();
}
#endif

std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
InitOnlineStream(stream.get());
return stream;
}

void DecodeStreams(OnlineStream **ss, int32_t n) const override {
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();

int32_t feature_dim = ss[0]->FeatureDim();

std::vector<OnlineTransducerDecoderResult> results(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<int64_t> all_processed_frames(n);

for (int32_t i = 0; i != n; ++i) {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_size);

// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;

std::copy(features.begin(), features.end(),
features_vec.data() + i * chunk_size * feature_dim);

results[i] = std::move(ss[i]->GetResult());
states_vec[i] = std::move(ss[i]->GetStates());
all_processed_frames[i] = num_processed_frames;
}

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};

Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
features_vec.size(), x_shape.data(),
x_shape.size());

std::array<int64_t, 1> processed_frames_shape{
static_cast<int64_t>(all_processed_frames.size())};

Ort::Value processed_frames = Ort::Value::CreateTensor(
memory_info, all_processed_frames.data(), all_processed_frames.size(),
processed_frames_shape.data(), processed_frames_shape.size());

auto states = model_->StackStates(states_vec);

auto [t, ns] = model_->RunEncoder(std::move(x), std::move(states),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note that this is NeMo transducer encoder, not the one from icefall.

please refer to

self.encoder.get_inputs()[0].name: x.numpy(),
self.encoder.get_inputs()[1].name: x_lens.numpy(),
self.encoder.get_inputs()[2].name: self.cache_last_channel,
self.encoder.get_inputs()[3].name: self.cache_last_time,
self.encoder.get_inputs()[4].name: self.cache_last_channel_len,

and

std::vector<Ort::Value> Forward(Ort::Value x,
std::vector<Ort::Value> states) {
Ort::Value &cache_last_channel = states[0];
Ort::Value &cache_last_time = states[1];
Ort::Value &cache_last_channel_len = states[2];
int32_t batch_size = x.GetTensorTypeAndShapeInfo().GetShape()[0];
std::array<int64_t, 1> length_shape{batch_size};
Ort::Value length = Ort::Value::CreateTensor<int64_t>(
allocator_, length_shape.data(), length_shape.size());
int64_t *p_length = length.GetTensorMutableData<int64_t>();
std::fill(p_length, p_length + batch_size, ChunkLength());
// (B, T, C) -> (B, C, T)
x = Transpose12(allocator_, &x);
std::array<Ort::Value, 5> inputs = {
std::move(x), View(&length), std::move(cache_last_channel),
std::move(cache_last_time), std::move(cache_last_channel_len)};
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
// out[0]: logit
// out[1] logit_length
// out[2:] states_next
//
// we need to remove out[1]
std::vector<Ort::Value> ans;
ans.reserve(out.size() - 1);
for (int32_t i = 0; i != out.size(); ++i) {
if (i == 1) {
continue;
}
ans.push_back(std::move(out[i]));
}
return ans;
}

You never need to use processed_frames.

I hope that you can understand what we have written.

Remember that the hybrid transducer + CTC shares the same encoder, which means you can borrow what we have done for the streaming NeMo CTC.

Please compare carefully between

self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
self.model.get_inputs()[2].name: self.cache_last_channel,
self.model.get_inputs()[3].name: self.cache_last_time,
self.model.get_inputs()[4].name: self.cache_last_channel_len,
},

and

self.encoder.get_inputs()[0].name: x.numpy(),
self.encoder.get_inputs()[1].name: x_lens.numpy(),
self.encoder.get_inputs()[2].name: self.cache_last_channel,
self.encoder.get_inputs()[3].name: self.cache_last_time,
self.encoder.get_inputs()[4].name: self.cache_last_channel_len,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. I have borrowed the Forward method for the RunEncoder method in online-transducer-nemo-model.cc.

I have a question regarding the initialization of the decoder states in online-recognizer-transducer-nemo-impl.h
I define these two methods.

  std::unique_ptr<OnlineStream> CreateStream() const override {
    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
    stream->SetStates(model_->GetInitStates());
    InitOnlineStream(stream.get());
    return stream;
  }

  void InitOnlineStream(OnlineStream *stream) const {
    auto r = decoder_->GetEmptyResult();

    stream->SetResult(r);
    stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(batch_size_));
  }

Should the line in InitOnlineStream be this?

stream->SetNeMoDecoderStates(decoder_->GetDecoderInitStates(batch_size_));

std::move(processed_frames));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
// t[1] encoder_out_length, int64 tensor, (batch_size,)

Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);

// defined in online-transducer-greedy-search-nemo-decoder.h
std::vector<OnlineTransducerDecoderResult> results = decoder_-> Decode(std::move(encoder_out), std::move(t[1]));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to pass the decoder model states of the previous chunk to the decoder_->Decode().

By the way, you can create a new method for decoder_
to take an additional argument containing the decoder_states.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    decoder_-> Decode(std::move(encoder_out), std::move(t[1]),
                      std::move(out_states), &results, ss, n);

I made some changes in online-recognizer-transducer-nemo-impl.h and the Deocde() method now takes in states of previous chunks.


std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(ns);

for (int32_t i = 0; i != n; ++i) {
ss[i]->SetResult(results[i]);
ss[i]->SetNeMoDecoderStates(std::move(next_states[i]));
}
}

/// Initializes per-stream decoding state: an empty result plus the NeMo
/// decoder's initial states for a batch of size batch_size_.
void InitOnlineStream(OnlineStream *stream) const {
  stream->SetResult(decoder_->GetEmptyResult());
  stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(batch_size_));
}

private:
/// Configures feature extraction for NeMo models and validates the token
/// table against the model's vocabulary. Called from both constructors.
void PostInit() {
  // NeMo models use librosa-style features with no low-frequency cutoff,
  // no DC-offset removal, and no dithering.
  config_.feat_config.low_freq = 0;
  // config_.feat_config.high_freq = 8000;
  config_.feat_config.is_librosa = true;
  config_.feat_config.remove_dc_offset = false;
  // config_.feat_config.window_type = "hann";
  config_.feat_config.dither = 0;

  // Use the normalization method the model was exported with.
  // (Previously this assignment appeared twice; once is enough.)
  config_.feat_config.nemo_normalize_type =
      model_->FeatureNormalizationMethod();

  int32_t vocab_size = model_->VocabSize();

  // Sanity checks: <blk> must exist, must be the last token, and the
  // token table must have exactly vocab_size entries.
  if (!symbol_table_.Contains("<blk>")) {
    SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
    exit(-1);
  }

  if (symbol_table_["<blk>"] != vocab_size - 1) {
    SHERPA_ONNX_LOGE("<blk> is not the last token!");
    exit(-1);
  }

  if (symbol_table_.NumSymbols() != vocab_size) {
    SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
                     symbol_table_.NumSymbols(), vocab_size);
    exit(-1);
  }
}

private:
OnlineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OnlineTransducerNeMoModel> model_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can build an instance of OnlineTransducerGreedySearchNeMoDecoder directly.

OnlineTransducerGreedySearchNeMoDecoder does not need to inherit from OnlineTransducerNeMoModel.

Suggested change
std::unique_ptr<OnlineTransducerDecoder> decoder_;
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I just fixed it now, after reading that the greedy search decoder does not inherit from the online transducer decoder. Thank you.


int32_t batch_size_ = 1;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
11 changes: 11 additions & 0 deletions sherpa-onnx/csrc/online-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ class OnlineStream::Impl {

std::vector<Ort::Value> &GetStates() { return states_; }

// Stores the NeMo transducer decoder (prediction network) states for this
// stream, replacing any previously stored states. Takes ownership by move.
void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
decoder_states_ = std::move(decoder_states);
}

std::vector<Ort::Value> &GetNeMoDecoderStates() { return decoder_states_; }

const ContextGraphPtr &GetContextGraph() const { return context_graph_; }

std::vector<float> &GetParaformerFeatCache() {
Expand Down Expand Up @@ -129,6 +135,7 @@ class OnlineStream::Impl {
TransducerKeywordResult empty_keyword_result_;
OnlineCtcDecoderResult ctc_result_;
std::vector<Ort::Value> states_; // states for transducer or ctc models
std::vector<Ort::Value> decoder_states_; // states for nemo transducer models
std::vector<float> paraformer_feat_cache_;
std::vector<float> paraformer_encoder_out_cache_;
std::vector<float> paraformer_alpha_cache_;
Expand Down Expand Up @@ -218,6 +225,10 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
return impl_->GetStates();
}

// Forwards to the pimpl; see OnlineStream::Impl::GetNeMoDecoderStates().
std::vector<Ort::Value> &OnlineStream::GetNeMoDecoderStates() {
return impl_->GetNeMoDecoderStates();
}

const ContextGraphPtr &OnlineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/online-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class OnlineStream {
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();

void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
std::vector<Ort::Value> &GetNeMoDecoderStates();

/**
* Get the context graph corresponding to this stream.
*
Expand Down
Loading
Loading