Skip to content

Add C++ runtime for *streaming* faster conformer transducer from NeMo. #889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ set(sources
online-transducer-model-config.cc
online-transducer-model.cc
online-transducer-modified-beam-search-decoder.cc
online-transducer-nemo-model.cc
online-transducer-greedy-search-nemo-decoder.cc
online-wenet-ctc-model-config.cc
online-wenet-ctc-model.cc
online-zipformer-transducer-model.cc
Expand Down
29 changes: 27 additions & 2 deletions sherpa-onnx/csrc/online-recognizer-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,27 @@
#include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {

std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
const OnlineRecognizerConfig &config) {

if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

auto decoder_model = ReadFile(config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
} else {
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
}
}

if (!config.model_config.paraformer.encoder.empty()) {
Expand All @@ -34,7 +48,18 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
AAssetManager *mgr, const OnlineRecognizerConfig &config) {
if (!config.model_config.transducer.encoder.empty()) {
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
} else {
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(mgr, config);
}
}

if (!config.model_config.paraformer.encoder.empty()) {
Expand Down
205 changes: 205 additions & 0 deletions sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
//
// Copyright (c) 2022-2024 Xiaomi Corporation
// Copyright (c) 2024 Sangeet Sagar

#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_

#include <fstream>
#include <ios>
#include <memory>
#include <regex> // NOLINT
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/utils.h"

namespace sherpa_onnx {

// defined in ./online-recognizer-transducer-impl.h
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t frames_since_start);

class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
public:
// Builds a streaming NeMo transducer recognizer from the given config.
// Only "greedy_search" is supported as the decoding method; any other
// value is a fatal configuration error.
explicit OnlineRecognizerTransducerNeMoImpl(
    const OnlineRecognizerConfig &config)
    : config_(config),
      symbol_table_(config_.model_config.tokens),
      model_(std::make_unique<OnlineTransducerNeMoModel>(
          config_.model_config)) {
  if (config_.decoding_method != "greedy_search") {
    SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                     config_.decoding_method.c_str());
    exit(-1);
  }

  decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
      model_.get(), config_.blank_penalty);

  PostInit();
}

#if __ANDROID_API__ >= 9
// Android variant: loads the tokens file and model through the asset
// manager. Behavior otherwise matches the non-Android constructor.
explicit OnlineRecognizerTransducerNeMoImpl(
    AAssetManager *mgr, const OnlineRecognizerConfig &config)
    : config_(config),
      symbol_table_(mgr, config_.model_config.tokens),
      model_(std::make_unique<OnlineTransducerNeMoModel>(
          mgr, config_.model_config)) {
  if (config_.decoding_method != "greedy_search") {
    SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                     config_.decoding_method.c_str());
    exit(-1);
  }

  decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
      model_.get(), config_.blank_penalty);

  PostInit();
}
#endif

// Creates a new stream, seeding it with the model's initial encoder
// states plus an empty result and initial decoder states.
std::unique_ptr<OnlineStream> CreateStream() const override {
  auto s = std::make_unique<OnlineStream>(config_.feat_config);
  s->SetStates(model_->GetInitStates());
  InitOnlineStream(s.get());
  return s;
}

void DecodeStreams(OnlineStream **ss, int32_t n) const override {
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();

int32_t feature_dim = ss[0]->FeatureDim();

std::vector<OnlineTransducerDecoderResult> results(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<int64_t> all_processed_frames(n);

for (int32_t i = 0; i != n; ++i) {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_size);

// Question: should num_processed_frames include chunk_shift?
ss[i]->GetNumProcessedFrames() += chunk_shift;

std::copy(features.begin(), features.end(),
features_vec.data() + i * chunk_size * feature_dim);

results[i] = std::move(ss[i]->GetResult());
states_vec[i] = std::move(ss[i]->GetStates());
all_processed_frames[i] = num_processed_frames;
}

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};

Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
features_vec.size(), x_shape.data(),
x_shape.size());

auto states = model_->StackStates(states_vec);
int32_t num_states = states.size();
auto t = model_->RunEncoder(std::move(x), std::move(states));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
// t[1] encoder_out_length, int64 tensor, (batch_size,)

std::vector<Ort::Value> out_states;
out_states.reserve(num_states);

for (int32_t k = 1; k != num_states + 1; ++k) {
out_states.push_back(std::move(t[k]));
}

Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);

// defined in online-transducer-greedy-search-nemo-decoder.h
decoder_-> Decode(std::move(encoder_out), std::move(t[1]),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, you don't need to pass the encoder model states to the greedy search decoder.

Please pass the decoder model states to it instead.

Please read carefully our python streaming transducer greedy search decoding example.

https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py#L190

I have posted the example here one more time in case you have not read it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decoder_-> Decode(std::move(encoder_out), std::move(out_states), &results, ss, n);

is this correct?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. out_states is from the encoder.

Remember that out_states is used only for the internal states of the encoder model. We don't need to use it in the greedy search decoding.

We need to pass the LSTM states from the decoder model to the greedy search decoder.

I suggest you again that you re-read the python decoding example and figure out how the decoding works.

(you need to know how LSTM works)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

working on it.

Copy link
Contributor Author

@sangeet2020 sangeet2020 May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that when initializing the model at the beginning, init_cache_state initializes the initial states of the encoder model. Further, when decoding begins, before the first chunk is decoded, the decoder model comes into action and initializes the decoder states. After a chunk has been decoded, it emits the next states for the decoder model, which become the current states for the next chunk.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, so you don't need the encoder states during greedy search decoding.

Copy link
Contributor Author

@sangeet2020 sangeet2020 May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    std::vector<Ort::Value> decoder_states = model_->GetDecoderInitStates(1);
    decoder_-> Decode(std::move(encoder_out), std::move(decoder_states), &results, ss, n);
    }

GetNeMoDecoderStates fetches the initial states of the decoder. But I am not really sure about the implementation done above.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please read

  • (1) our python code for nemo streaming transducer greedy search decoding
  • (2) our c++ code for nemo non-streaming transducer greedy search decoding

make sure you indeed understand the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @csukuangfj ,
thanks again.
I did, and I do understand the code.

i see in offline-transducer-greedy-search-nemo-decoder.cc how RunDecoder method takes the initial state of the decoder.

      model->RunDecoder(std::move(decoder_input_pair.first),
                        std::move(decoder_input_pair.second),
                        model->GetDecoderInitStates(1));

and I realize that above, I did it similar way. I am missing where exactly I am doing wrong.

Copy link
Contributor Author

@sangeet2020 sangeet2020 May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I revised the code

    // defined in online-transducer-greedy-search-nemo-decoder.h
    std::vector<Ort::Value> decoder_states = model_->GetDecoderInitStates(1);
    // updated decoder states are returned
    decoder_states = decoder_->Decode(std::move(encoder_out), 
                                      std::move(decoder_states), 
                                      &results, ss, n);

    std::vector<std::vector<Ort::Value>> next_states =
        model_->UnStackStates(decoder_states);

Is this correct?

std::move(out_states), &results, ss, n);

std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(out_states);

for (int32_t i = 0; i != n; ++i) {
ss[i]->SetResult(results[i]);
ss[i]->SetNeMoDecoderStates(std::move(next_states[i]));
}
}

void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();

stream->SetResult(r);
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(batch_size_));
}

private:
// Finalizes configuration after the model is loaded: adjusts the
// feature-extraction settings to what NeMo models expect and validates
// the token table against the model's vocabulary. Exits on any mismatch.
void PostInit() {
  // NeMo models use librosa-style features with no dithering,
  // no DC-offset removal, and no low-frequency cutoff.
  config_.feat_config.low_freq = 0;
  // config_.feat_config.high_freq = 8000;
  config_.feat_config.is_librosa = true;
  config_.feat_config.remove_dc_offset = false;
  // config_.feat_config.window_type = "hann";
  config_.feat_config.dither = 0;

  // Take the feature-normalization method from the model's metadata.
  // (Previously this was assigned twice; the duplicate is removed.)
  config_.feat_config.nemo_normalize_type =
      model_->FeatureNormalizationMethod();

  int32_t vocab_size = model_->VocabSize();

  // NeMo transducers use <blk> as the blank token, and it must be
  // the last entry of tokens.txt.
  if (!symbol_table_.Contains("<blk>")) {
    SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>");
    exit(-1);
  }

  if (symbol_table_["<blk>"] != vocab_size - 1) {
    SHERPA_ONNX_LOGE("<blk> is not the last token!");
    exit(-1);
  }

  if (symbol_table_.NumSymbols() != vocab_size) {
    SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)",
                     symbol_table_.NumSymbols(), vocab_size);
    exit(-1);
  }
}

private:
OnlineRecognizerConfig config_;
SymbolTable symbol_table_;  // token-id <-> symbol map loaded from tokens.txt
std::unique_ptr<OnlineTransducerNeMoModel> model_;  // streaming NeMo transducer model
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;  // greedy-search decoder

// Batch size used when creating initial decoder states per stream
// (see InitOnlineStream).
int32_t batch_size_ = 1;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
15 changes: 15 additions & 0 deletions sherpa-onnx/csrc/online-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ class OnlineStream::Impl {

std::vector<Ort::Value> &GetStates() { return states_; }

void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
decoder_states_ = std::move(decoder_states);
}

std::vector<Ort::Value> &GetNeMoDecoderStates() { return decoder_states_; }

const ContextGraphPtr &GetContextGraph() const { return context_graph_; }

std::vector<float> &GetParaformerFeatCache() {
Expand Down Expand Up @@ -129,6 +135,7 @@ class OnlineStream::Impl {
TransducerKeywordResult empty_keyword_result_;
OnlineCtcDecoderResult ctc_result_;
std::vector<Ort::Value> states_; // states for transducer or ctc models
std::vector<Ort::Value> decoder_states_; // states for nemo transducer models
std::vector<float> paraformer_feat_cache_;
std::vector<float> paraformer_encoder_out_cache_;
std::vector<float> paraformer_alpha_cache_;
Expand Down Expand Up @@ -218,6 +225,14 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
return impl_->GetStates();
}

// Forwards the NeMo decoder states to the stream's implementation.
// Note: the previous version used `return` on a void call; dropped for
// idiomatic consistency with the other forwarding setters in this file.
void OnlineStream::SetNeMoDecoderStates(
    std::vector<Ort::Value> decoder_states) {
  impl_->SetNeMoDecoderStates(std::move(decoder_states));
}

// Returns a mutable reference to the NeMo decoder states held by the
// stream's implementation.
std::vector<Ort::Value> &OnlineStream::GetNeMoDecoderStates() {
  return impl_->GetNeMoDecoderStates();
}

const ContextGraphPtr &OnlineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/online-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class OnlineStream {
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();

// Set/get the decoder states used by streaming NeMo transducer models.
// States are moved into the stream; the getter returns a mutable
// reference to the stored states.
void SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states);
std::vector<Ort::Value> &GetNeMoDecoderStates();

/**
* Get the context graph corresponding to this stream.
*
Expand Down
Loading
Loading