From cdca4e65e9c55ab6d536e20cbccd824d1c69f1e4 Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Fri, 17 May 2024 12:44:49 +0200 Subject: [PATCH 01/14] adding online nemo transducer model files --- .../csrc/online-transducer-nemo-model.cc | 434 ++++++++++++++++++ .../csrc/online-transducer-nemo-model.h | 151 ++++++ 2 files changed, 585 insertions(+) create mode 100644 sherpa-onnx/csrc/online-transducer-nemo-model.cc create mode 100644 sherpa-onnx/csrc/online-transducer-nemo-model.h diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc new file mode 100644 index 0000000000..685a5ed27f --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -0,0 +1,434 @@ +// sherpa-onnx/csrc/online-transducer-nemo-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-transducer-nemo-model.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-transducer-decoder.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/transpose.h" +#include "sherpa-onnx/csrc/unbind.h" + +namespace sherpa_onnx { + +class OnlineTransducerNeMoModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const 
OnlineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_WARNING), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } + } +#endif + + std::vector StackStates( + const std::vector> &states) const { + int32_t batch_size = static_cast(states.size()); + int32_t num_encoders = static_cast(num_encoder_layers_.size()); + + std::vector buf(batch_size); + + std::vector ans; + int32_t num_states = static_cast(states[0].size()); + ans.reserve(num_states); + + for (int32_t i = 0; i != (num_states - 2) / 6; ++i) { + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 1]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 2]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 3]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 4]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][6 * i + 5]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + } + + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_states - 2]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + + { + for 
(int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_states - 1]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + return ans; + } + + std::vector>UnStackStates( + const std::vector &states) const { + int32_t m = std::accumulate(num_encoder_layers_.begin(), + num_encoder_layers_.end(), 0); + assert(states.size() == m * 6 + 2); + + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; + int32_t num_encoders = num_encoder_layers_.size(); + + std::vector> ans; + ans.resize(batch_size); + + for (int32_t i = 0; i != m; ++i) { + { + auto v = Unbind(allocator_, &states[i * 6], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 1], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 2], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 3], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 4], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[i * 6 + 5], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + } + + { + auto v = Unbind(allocator_, &states[m * 6], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, &states[m * 6 + 1], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n 
!= batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + return ans; + } + + std::pair>RunEncoder(Ort::Value features, + std::vector states, + Ort::Value /* processed_frames */) { + std::vector encoder_inputs; + encoder_inputs.reserve(1 + states.size()); + + encoder_inputs.push_back(std::move(features)); + for (auto &v : states) { + encoder_inputs.push_back(std::move(v)); + } + + auto encoder_out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), + encoder_inputs.size(), encoder_output_names_ptr_.data(), + encoder_output_names_ptr_.size()); + + std::vector next_states; + next_states.reserve(states.size()); + + for (int32_t i = 1; i != static_cast(encoder_out.size()); ++i) { + next_states.push_back(std::move(encoder_out[i])); + } + return {std::move(encoder_out[0]), std::move(next_states)}; + } + + Ort::Value RunDecoder(Ort::Value decoder_input) { + auto decoder_out = decoder_sess_->Run( + {}, decoder_input_names_ptr_.data(), &decoder_input, 1, + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); + return std::move(decoder_out[0]); + } + + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { + std::array joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), + joiner_input.size(), joiner_output_names_ptr_.data(), + joiner_output_names_ptr_.size()); + + return std::move(logit[0]); +} + + std::vector GetDecoderInitStates(int32_t batch_size) const { + std::array s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + Ort::Value s0 = Ort::Value::CreateTensor(allocator_, s0_shape.data(), + s0_shape.size()); + + Fill(&s0, 0); + + std::array s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + + Ort::Value s1 = Ort::Value::CreateTensor(allocator_, s1_shape.data(), + s1_shape.size()); + + Fill(&s1, 0); + + std::vector states; + + states.reserve(2); + states.push_back(std::move(s0)); + 
states.push_back(std::move(s1)); + + return states; + } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + int32_t VocabSize() const { return vocab_size_; } + + OrtAllocator *Allocator() const { return allocator_; } + + std::string FeatureNormalizationMethod() const { return normalize_type_; } + +private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + + // need to increase by 1 since the blank token is not included in computing + // vocab_size in NeMo. 
+ vocab_size_ += 1; + + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); + SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); + SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); + + if (normalize_type_ == "NA") { + normalize_type_ = ""; + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + void InitJoiner(void *model_data, size_t model_data_length) { + joiner_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + } + + private: + OnlineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 8; + std::string normalize_type_; + int32_t pred_rnn_layers_ = -1; + int32_t pred_hidden_ = -1; +}; + +OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + const OnlineModelConfig 
&config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + AAssetManager *mgr, const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; + +int32_t ChunkLength() const { return window_size_; } + +int32_t ChunkShift() const { return chunk_shift_; } + +int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +int32_t OnlineTransducerNeMoModel::VocabSize() const { + return impl_->VocabSize(); +} + +OrtAllocator *OnlineTransducerNeMoModel::Allocator() const { + return impl_->Allocator(); +} + +std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const { + return impl_->FeatureNormalizationMethod(); +} + +} // namespace sherpa_onnx \ No newline at end of file diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h new file mode 100644 index 0000000000..e502136d45 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -0,0 +1,151 @@ +// sherpa-onnx/csrc/online-transducer-nemo-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-model-config.h" + +namespace sherpa_onnx { + +// see +// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40 +// Its decoder is stateful, not stateless. 
+class OnlineTransducerNeMoModel { + public: + explicit OnlineTransducerNeMoModel(const OnlineModelConfig &config); + +#if __ANDROID_API__ >= 9 + OnlineTransducerNeMoModel(AAssetManager *mgr, + const OfflineModelConfig &config); +#endif + + ~OnlineTransducerNeMoModel(); + + /** Stack a list of individual states into a batch. + * + * It is the inverse operation of `UnStackStates`. + * + * @param states states[i] contains the state for the i-th utterance. + * @return Return a single value representing the batched state. + */ + std::vector StackStates( + const std::vector> &states) const; + + /** Unstack a batch state into a list of individual states. + * + * It is the inverse operation of `StackStates`. + * + * @param states A batched state. + * @return ans[i] contains the state for the i-th utterance. + */ + std::vector> UnStackStates( + const std::vector &states) const; + + // /** Get the initial encoder states. + // * + // * @return Return the initial encoder state. + // */ + // std::vector GetEncoderInitStates() = 0; + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param states Encoder state of the previous chunk. It is changed in-place. + * @param processed_frames Processed frames before subsampling. It is a 1-D + * tensor with data type int64_t. + * + * @return Return a tuple containing: + * - encoder_out, a tensor of shape (N, T', encoder_out_dim) + * - next_states Encoder state for the next chunk. + */ + std::pair> RunEncoder( + Ort::Value features, std::vector states, + Ort::Value processed_frames) const; // NOLINT + + /** Run the decoder network. + * + * @param targets A int32 tensor of shape (batch_size, 1) + * @param targets_length A int32 tensor of shape (batch_size,) + * @param states The states for the decoder model. 
+ * @return Return a vector: + * - ans[0] is the decoder_out (a float tensor) + * - ans[1] is the decoder_out_length (a int32 tensor) + * - ans[2:] is the states_next + */ + std::pair> RunDecoder( + Ort::Value targets, Ort::Value targets_length, + std::vector states) const; + + std::vector GetDecoderInitStates(int32_t batch_size) const; + + /** Run the joint network. + * + * @param encoder_out Output of the encoder network. + * @param decoder_out Output of the decoder network. + * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. + */ + virtual Ort::Value RunJoiner( Ort::Value encoder_out, + Ort::Value decoder_out) const; + + // cache_last_time_dim3 in the model meta_data + int32_t ContextSize() const; + + /** We send this number of feature frames to the encoder at a time. */ + int32_t ChunkSize() const; + + /** Number of input frames to discard after each call to RunEncoder. + * + * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6. + * + * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8. + * Then we discard frame 0~5 since chunk_shift is 6. + * In the second call of RunEncoder, we use frames 6~13; and then we discard + * frames 6~11. + * In the third call of RunEncoder, we use frames 12~19; and then we discard + * frames 12~16. + * + * Note: ChunkSize() - ChunkShift() == right context size + */ + int32_t ChunkShift() const; + + /** Return the subsampling factor of the model. 
+ */ + int32_t SubsamplingFactor() const; + + int32_t VocabSize() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string FeatureNormalizationMethod() const; + + private: + class Impl; + std::unique_ptr impl_; + }; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ From ca4bfe806e8a5876eda581f96364e8f91cfca25e Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Sun, 19 May 2024 15:26:08 +0200 Subject: [PATCH 02/14] model file added; necessary methods added --- sherpa-onnx/csrc/CMakeLists.txt | 1 + .../csrc/online-transducer-nemo-model.cc | 300 ++++++++++-------- .../csrc/online-transducer-nemo-model.h | 7 +- 3 files changed, 171 insertions(+), 137 deletions(-) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index fc32e5a4f2..5a7f165592 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -74,6 +74,7 @@ set(sources online-transducer-model-config.cc online-transducer-model.cc online-transducer-modified-beam-search-decoder.cc + online-transducer-nemo-model.cc online-wenet-ctc-model-config.cc online-wenet-ctc-model.cc online-zipformer-transducer-model.cc diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index 685a5ed27f..27b3ecccfd 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -20,10 +20,12 @@ #include "android/asset_manager_jni.h" #endif +#include "sherpa-onnx/csrc/cat.h" #include "sherpa-onnx/csrc/macros.h" #include 
"sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" #include "sherpa-onnx/csrc/transpose.h" #include "sherpa-onnx/csrc/unbind.h" @@ -76,163 +78,74 @@ class OnlineTransducerNeMoModel::Impl { #endif std::vector StackStates( - const std::vector> &states) const { + std::vector> states) const { int32_t batch_size = static_cast(states.size()); - int32_t num_encoders = static_cast(num_encoder_layers_.size()); + if (batch_size == 1) { + return std::move(states[0]); + } + + std::vector ans; + // stack cache_last_channel std::vector buf(batch_size); - std::vector ans; - int32_t num_states = static_cast(states[0].size()); - ans.reserve(num_states); - - for (int32_t i = 0; i != (num_states - 2) / 6; ++i) { - { - for (int32_t n = 0; n != batch_size; ++n) { - buf[n] = &states[n][6 * i]; - } - auto v = Cat(allocator_, buf, 1); - ans.push_back(std::move(v)); - } - { - for (int32_t n = 0; n != batch_size; ++n) { - buf[n] = &states[n][6 * i + 1]; - } - auto v = Cat(allocator_, buf, 1); - ans.push_back(std::move(v)); - } - { - for (int32_t n = 0; n != batch_size; ++n) { - buf[n] = &states[n][6 * i + 2]; - } - auto v = Cat(allocator_, buf, 1); - ans.push_back(std::move(v)); - } - { - for (int32_t n = 0; n != batch_size; ++n) { - buf[n] = &states[n][6 * i + 3]; - } - auto v = Cat(allocator_, buf, 1); - ans.push_back(std::move(v)); - } - { - for (int32_t n = 0; n != batch_size; ++n) { - buf[n] = &states[n][6 * i + 4]; - } - auto v = Cat(allocator_, buf, 0); - ans.push_back(std::move(v)); - } - { - for (int32_t n = 0; n != batch_size; ++n) { - buf[n] = &states[n][6 * i + 5]; - } - auto v = Cat(allocator_, buf, 0); - ans.push_back(std::move(v)); - } - } + // there are 3 states to be stacked + for (int32_t i = 0; i != 3; ++i) { + buf.clear(); + buf.reserve(batch_size); - { - for (int32_t n = 0; n != batch_size; ++n) { - buf[n] = &states[n][num_states - 2]; + for (int32_t b = 
0; b != batch_size; ++b) { + assert(states[b].size() == 3); + buf.push_back(&states[b][i]); } - auto v = Cat(allocator_, buf, 0); - ans.push_back(std::move(v)); - } - { - for (int32_t n = 0; n != batch_size; ++n) { - buf[n] = &states[n][num_states - 1]; + Ort::Value c{nullptr}; + if (i == 2) { + c = Cat(allocator_, buf, 0); + } else { + c = Cat(allocator_, buf, 0); } - auto v = Cat(allocator_, buf, 0); - ans.push_back(std::move(v)); + + ans.push_back(std::move(c)); } + return ans; } - std::vector>UnStackStates( - const std::vector &states) const { - int32_t m = std::accumulate(num_encoder_layers_.begin(), - num_encoder_layers_.end(), 0); - assert(states.size() == m * 6 + 2); - - int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; - int32_t num_encoders = num_encoder_layers_.size(); + std::vector> UnStackStates( + std::vector states) const { + assert(states.size() == 3); std::vector> ans; - ans.resize(batch_size); - - for (int32_t i = 0; i != m; ++i) { - { - auto v = Unbind(allocator_, &states[i * 6], 1); - assert(v.size() == batch_size); - for (int32_t n = 0; n != batch_size; ++n) { - ans[n].push_back(std::move(v[n])); - } - } - { - auto v = Unbind(allocator_, &states[i * 6 + 1], 1); - assert(v.size() == batch_size); - - for (int32_t n = 0; n != batch_size; ++n) { - ans[n].push_back(std::move(v[n])); - } - } - { - auto v = Unbind(allocator_, &states[i * 6 + 2], 1); - assert(v.size() == batch_size); - - for (int32_t n = 0; n != batch_size; ++n) { - ans[n].push_back(std::move(v[n])); - } - } - { - auto v = Unbind(allocator_, &states[i * 6 + 3], 1); - assert(v.size() == batch_size); - - for (int32_t n = 0; n != batch_size; ++n) { - ans[n].push_back(std::move(v[n])); - } - } - { - auto v = Unbind(allocator_, &states[i * 6 + 4], 0); - assert(v.size() == batch_size); - - for (int32_t n = 0; n != batch_size; ++n) { - ans[n].push_back(std::move(v[n])); - } - } - { - auto v = Unbind(allocator_, &states[i * 6 + 5], 0); - assert(v.size() == batch_size); 
+ auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape(); + int32_t batch_size = shape[0]; + ans.resize(batch_size); - for (int32_t n = 0; n != batch_size; ++n) { - ans[n].push_back(std::move(v[n])); - } - } + if (batch_size == 1) { + ans[0] = std::move(states); + return ans; } - { - auto v = Unbind(allocator_, &states[m * 6], 0); - assert(v.size() == batch_size); - - for (int32_t n = 0; n != batch_size; ++n) { - ans[n].push_back(std::move(v[n])); + for (int32_t i = 0; i != 3; ++i) { + std::vector v; + if (i == 2) { + v = Unbind(allocator_, &states[i], 0); + } else { + v = Unbind(allocator_, &states[i], 0); } - } - { - auto v = Unbind(allocator_, &states[m * 6 + 1], 0); + assert(v.size() == batch_size); - for (int32_t n = 0; n != batch_size; ++n) { - ans[n].push_back(std::move(v[n])); + for (int32_t b = 0; b != batch_size; ++b) { + ans[b].push_back(std::move(v[b])); } } return ans; } - std::pair>RunEncoder(Ort::Value features, + std::pair> RunEncoder(Ort::Value features, std::vector states, Ort::Value /* processed_frames */) { std::vector encoder_inputs; @@ -257,11 +170,37 @@ class OnlineTransducerNeMoModel::Impl { return {std::move(encoder_out[0]), std::move(next_states)}; } - Ort::Value RunDecoder(Ort::Value decoder_input) { + std::pair> RunDecoder( + Ort::Value targets, Ort::Value targets_length, + std::vector states) { + std::vector decoder_inputs; + decoder_inputs.reserve(2 + states.size()); + + decoder_inputs.push_back(std::move(targets)); + decoder_inputs.push_back(std::move(targets_length)); + + for (auto &s : states) { + decoder_inputs.push_back(std::move(s)); + } + auto decoder_out = decoder_sess_->Run( - {}, decoder_input_names_ptr_.data(), &decoder_input, 1, - decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); - return std::move(decoder_out[0]); + {}, decoder_input_names_ptr_.data(), decoder_inputs.data(), + decoder_inputs.size(), decoder_output_names_ptr_.data(), + decoder_output_names_ptr_.size()); + + std::vector states_next; 
+ states_next.reserve(states.size()); + + // decoder_out[0]: decoder_output + // decoder_out[1]: decoder_output_length + // decoder_out[2:] states_next + + for (int32_t i = 0; i != states.size(); ++i) { + states_next.push_back(std::move(decoder_out[i + 2])); + } + + // we discard decoder_out[1] + return {std::move(decoder_out[0]), std::move(states_next)}; } Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { @@ -298,7 +237,13 @@ class OnlineTransducerNeMoModel::Impl { return states; } + + int32_t ChunkSize() const { return window_size_; } + + int32_t ChunkShift() const { return chunk_shift_; } + int32_t SubsamplingFactor() const { return subsampling_factor_; } + int32_t VocabSize() const { return vocab_size_; } OrtAllocator *Allocator() const { return allocator_; } @@ -332,14 +277,54 @@ class OnlineTransducerNeMoModel::Impl { // vocab_size in NeMo. vocab_size_ += 1; + SHERPA_ONNX_READ_META_DATA(window_size_, "window_size"); + SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift"); SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); + SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_, + "cache_last_channel_dim1"); + SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_, + "cache_last_channel_dim2"); + SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_, + "cache_last_channel_dim3"); + SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1"); + SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2"); + SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3"); + if (normalize_type_ == "NA") { normalize_type_ = ""; } + + InitStates(); + } + + void InitStates() { + std::array cache_last_channel_shape{1, cache_last_channel_dim1_, + cache_last_channel_dim2_, + cache_last_channel_dim3_}; + + 
cache_last_channel_ = Ort::Value::CreateTensor( + allocator_, cache_last_channel_shape.data(), + cache_last_channel_shape.size()); + + Fill(&cache_last_channel_, 0); + + std::array cache_last_time_shape{ + 1, cache_last_time_dim1_, cache_last_time_dim2_, cache_last_time_dim3_}; + + cache_last_time_ = Ort::Value::CreateTensor( + allocator_, cache_last_time_shape.data(), cache_last_time_shape.size()); + + Fill(&cache_last_time_, 0); + + int64_t shape = 1; + cache_last_channel_len_ = + Ort::Value::CreateTensor(allocator_, &shape, 1); + + cache_last_channel_len_.GetTensorMutableData()[0] = 0; } void InitDecoder(void *model_data, size_t model_data_length) { @@ -392,11 +377,24 @@ class OnlineTransducerNeMoModel::Impl { std::vector joiner_output_names_; std::vector joiner_output_names_ptr_; + int32_t window_size_; + int32_t chunk_shift_; int32_t vocab_size_ = 0; int32_t subsampling_factor_ = 8; std::string normalize_type_; int32_t pred_rnn_layers_ = -1; int32_t pred_hidden_ = -1; + + int32_t cache_last_channel_dim1_; + int32_t cache_last_channel_dim2_; + int32_t cache_last_channel_dim3_; + int32_t cache_last_time_dim1_; + int32_t cache_last_time_dim2_; + int32_t cache_last_time_dim3_; + + Ort::Value cache_last_channel_{nullptr}; + Ort::Value cache_last_time_{nullptr}; + Ort::Value cache_last_channel_len_{nullptr}; }; OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( @@ -411,9 +409,39 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; -int32_t ChunkLength() const { return window_size_; } +std::pair> +OnlineTransducerNeMoModel::RunEncoder(Ort::Value features, + std::vector states, + Ort::Value processed_frames) const { +return impl_->RunEncoder(std::move(features), std::move(states), std::move(processed_frames)); +} + +std::pair> +OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets, + Ort::Value targets_length, + std::vector states) const { + return impl_->RunDecoder(std::move(targets), 
std::move(targets_length), + std::move(states)); +} + +std::vector OnlineTransducerNeMoModel::GetDecoderInitStates( + int32_t batch_size) const { + return impl_->GetDecoderInitStates(batch_size); +} + +Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, + Ort::Value decoder_out) const { + return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); +} -int32_t ChunkShift() const { return chunk_shift_; } + +int32_t OnlineTransducerNeMoModel::ChunkSize() const { + return impl_->ChunkSize(); + } + +int32_t OnlineTransducerNeMoModel::ChunkShift() const { + return impl_->ChunkShift(); + } int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const { return impl_->SubsamplingFactor(); diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h index e502136d45..4b6270f190 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -32,7 +32,12 @@ class OnlineTransducerNeMoModel { #endif ~OnlineTransducerNeMoModel(); - + // A list of 3 tensors: + // - cache_last_channel + // - cache_last_time + // - cache_last_channel_len + std::vector GetInitStates() const; + /** Stack a list of individual states into a batch. * * It is the inverse operation of `UnStackStates`. 
From 2bb7d7ecaba6091182bb94495d344602dc7dd1bd Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Sun, 19 May 2024 17:39:38 +0200 Subject: [PATCH 03/14] cc file outline added --- ...e-transducer-greedy-search-nemo-decoder.cc | 156 ++++++++++++++++++ ...ne-transducer-greedy-search-nemo-decoder.h | 33 ++++ 2 files changed, 189 insertions(+) create mode 100644 sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc create mode 100644 sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc new file mode 100644 index 0000000000..d3cba31d71 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -0,0 +1,156 @@ +// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +static void UseCachedDecoderOut( + const std::vector &results, + Ort::Value *decoder_out) { + std::vector shape = + decoder_out->GetTensorTypeAndShapeInfo().GetShape(); + float *dst = decoder_out->GetTensorMutableData(); + for (const auto &r : results) { + if (r.decoder_out) { + const float *src = r.decoder_out.GetTensorData(); + std::copy(src, src + shape[1], dst); + } + dst += shape[1]; + } +} + +static void UpdateCachedDecoderOut( + OrtAllocator *allocator, const Ort::Value *decoder_out, + std::vector *results) { + std::vector shape = + decoder_out->GetTensorTypeAndShapeInfo().GetShape(); + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + std::array v_shape{1, shape[1]}; + + const float *src = decoder_out->GetTensorData(); + for (auto &r : *results) { + if (!r.decoder_out) { + 
r.decoder_out = Ort::Value::CreateTensor(allocator, v_shape.data(), + v_shape.size()); + } + + float *dst = r.decoder_out.GetTensorMutableData(); + std::copy(src, src + shape[1], dst); + src += shape[1]; + } +} + +static std::pair BuildDecoderInput( + int32_t token, OrtAllocator *allocator) { + std::array shape{1, 1}; + + Ort::Value decoder_input = + Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + + std::array length_shape{1}; + Ort::Value decoder_input_length = Ort::Value::CreateTensor( + allocator, length_shape.data(), length_shape.size()); + + int32_t *p = decoder_input.GetTensorMutableData(); + + int32_t *p_length = decoder_input_length.GetTensorMutableData(); + + p[0] = token; + + p_length[0] = 1; + + return {std::move(decoder_input), std::move(decoder_input_length)}; +} + +static OnlineTransducerDecoderResult DecodeOne( + const float *p, int32_t num_rows, int32_t num_cols, + OnlineTransducerNeMoModel *model, float blank_penalty) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + OnlineTransducerDecoderResult ans; + + int32_t vocab_size = model->VocabSize(); + int32_t blank_id = vocab_size - 1; + + auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); + + std::pair> decoder_output_pair = + model->RunDecoder(std::move(decoder_input_pair.first), + std::move(decoder_input_pair.second), + model->GetDecoderInitStates(1)); + + std::array encoder_shape{1, num_cols, 1}; + + for (int32_t t = 0; t != num_rows; ++t) { + Ort::Value cur_encoder_out = Ort::Value::CreateTensor( + memory_info, const_cast(p) + t * num_cols, num_cols, + encoder_shape.data(), encoder_shape.size()); + + Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), + View(&decoder_output_pair.first)); + + float *p_logit = logit.GetTensorMutableData(); + if (blank_penalty > 0) { + + p_logit[blank_id] -= blank_penalty; + } + + auto y = static_cast(std::distance( + static_cast(p_logit), + 
std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + + if (y != blank_id) { + ans.tokens.push_back(y); + ans.timestamps.push_back(t); + + decoder_input_pair = BuildDecoderInput(y, model->Allocator()); + + decoder_output_pair = + model->RunDecoder(std::move(decoder_input_pair.first), + std::move(decoder_input_pair.second), + std::move(decoder_output_pair.second)); + } // if (y != blank_id) + } // for (int32_t i = 0; i != num_rows; ++i) + + return ans; +} + +std::vector +OnlineTransducerGreedySearchNeMoDecoder::Decode( + Ort::Value encoder_out, + std::vector *result) { + auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); + + int32_t batch_size = static_cast(shape[0]); + int32_t dim1 = static_cast(shape[1]); + int32_t dim2 = static_cast(shape[2]); + + const float *p = encoder_out.GetTensorData(); + + // checking for non-null elements in results + + // create a new tensor with modified shape based on + // the first element of result and use cached decoder_out + // values if available. + + // For each frame (num of frames is given by dim2), compute logits, + // determine tokens, and update results, + // then regenerate decoder output + // if tokens are emitted. 
+ + // call UpdateCachedDecoderOut and update frame offset +} + +} // namespace sherpa_onnx \ No newline at end of file diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h new file mode 100644 index 0000000000..d465f94c3b --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/online-transducer-decoder.h" +#include "sherpa-onnx/csrc/online-transducer-nemo-model.h" + +namespace sherpa_onnx { + +class OnlineTransducerGreedySearchNeMoDecoder + : public OnlineTransducerDecoder { + public: + OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, + float blank_penalty) + : model_(model), blank_penalty_(blank_penalty) {} + + std::vector Decode( + Ort::Value encoder_out, Ort::Value encoder_out_length, + OnlineStream **ss = nullptr, int32_t n = 0) override; + + private: + OnlineTransducerNeMoModel *model_; // Not owned + float blank_penalty_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ From 44f8d8c190d1430e4e6c20c6d4b1a79b75563454 Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Wed, 22 May 2024 12:17:01 +0200 Subject: [PATCH 04/14] add transducer decoding script --- .../online-recognizer-transducer-nemo-impl.h | 201 ++++++++++++++++++ sherpa-onnx/csrc/online-stream.cc | 11 + sherpa-onnx/csrc/online-stream.h | 3 + .../csrc/online-transducer-nemo-model.cc | 8 +- .../csrc/online-transducer-nemo-model.h | 2 +- 5 files changed, 220 insertions(+), 5 deletions(-) create mode 100644 
sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h new file mode 100644 index 0000000000..324eaccafe --- /dev/null +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -0,0 +1,201 @@ +// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ + +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include "sherpa-onnx/csrc/online-recognizer.h" +#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h" +#include "sherpa-onnx/csrc/online-transducer-nemo-model.h" +#include "sherpa-onnx/csrc/pad-sequence.h" +#include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/transpose.h" +#include "sherpa-onnx/csrc/utils.h" + +namespace sherpa_onnx { + +// defined in ./online-recognizer-transducer-impl.h +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, + const SymbolTable &sym_table, + float frame_shift_ms, + int32_t subsampling_factor, + int32_t segment, + int32_t frames_since_start); + +class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerTransducerNeMoImpl( + const OnlineRecognizerConfig &config) + : config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique( + config_.model_config)) { + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding 
method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + PostInit(); + } + +#if __ANDROID_API__ >= 9 + explicit OnlineRecognizerTransducerNeMoImpl( + AAssetManager *mgr, const OnlineRecognizerConfig &config) + : config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique( + mgr, config_.model_config)) { + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + + PostInit(); + } +#endif + + std::unique_ptr CreateStream() const override { + auto stream = std::make_unique(config_.feat_config); + InitOnlineStream(stream.get()); + return stream; + } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feature_dim = ss[0]->FeatureDim(); + + std::vector results(n); + std::vector features_vec(n * chunk_size * feature_dim); + std::vector> states_vec(n); + std::vector all_processed_frames(n); + + for (int32_t i = 0; i != n; ++i) { + const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); + std::vector features = + ss[i]->GetFrames(num_processed_frames, chunk_size); + + // Question: should num_processed_frames include chunk_shift? 
+ ss[i]->GetNumProcessedFrames() += chunk_shift; + + std::copy(features.begin(), features.end(), + features_vec.data() + i * chunk_size * feature_dim); + + results[i] = std::move(ss[i]->GetResult()); + states_vec[i] = std::move(ss[i]->GetStates()); + all_processed_frames[i] = num_processed_frames; + } + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape{n, chunk_size, feature_dim}; + + Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), + features_vec.size(), x_shape.data(), + x_shape.size()); + + std::array processed_frames_shape{ + static_cast(all_processed_frames.size())}; + + Ort::Value processed_frames = Ort::Value::CreateTensor( + memory_info, all_processed_frames.data(), all_processed_frames.size(), + processed_frames_shape.data(), processed_frames_shape.size()); + + auto states = model_->StackStates(states_vec); + + auto [t, ns] = model_->RunEncoder(std::move(x), std::move(states), + std::move(processed_frames)); + + Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); + + // defined in online-transducer-greedy-search-nemo-decoder.h + auto results = decoder_-> Decode(std::move(encoder_out), std::move(t[1])); + + std::vector> next_states = + model_->UnStackStates(ns); + + for (int32_t i = 0; i != n; ++i) { + ss[i]->SetResult(results[i]); + ss[i]->SetNeMoDecoderStates(std::move(next_states[i])); + } + } + + void InitOnlineStream(OnlineStream *stream) const { + auto r = decoder_->GetEmptyResult(); + + stream->SetResult(r); + stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(batch_size_)); + } + + private: + void PostInit() { + config_.feat_config.nemo_normalize_type = + model_->FeatureNormalizationMethod(); + + config_.feat_config.low_freq = 0; + // config_.feat_config.high_freq = 8000; + config_.feat_config.is_librosa = true; + config_.feat_config.remove_dc_offset = false; + // config_.feat_config.window_type = "hann"; + config_.feat_config.dither = 0; + 
config_.feat_config.nemo_normalize_type = + model_->FeatureNormalizationMethod(); + + int32_t vocab_size = model_->VocabSize(); + + // check the blank ID + if (!symbol_table_.Contains("")) { + SHERPA_ONNX_LOGE("tokens.txt does not include the blank token "); + exit(-1); + } + + if (symbol_table_[""] != vocab_size - 1) { + SHERPA_ONNX_LOGE(" is not the last token!"); + exit(-1); + } + + if (symbol_table_.NumSymbols() != vocab_size) { + SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)", + symbol_table_.NumSymbols(), vocab_size); + exit(-1); + } + } + + private: + OnlineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; + + int32_t batch_size_ = 1; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ \ No newline at end of file diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 52cfb899f5..b1d35dc553 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -90,6 +90,12 @@ class OnlineStream::Impl { std::vector &GetStates() { return states_; } + void SetNeMoDecoderStates(std::vector decoder_states) { + decoder_states_ = std::move(decoder_states); + } + + std::vector &GetNeMoDecoderStates() { return decoder_states_; } + const ContextGraphPtr &GetContextGraph() const { return context_graph_; } std::vector &GetParaformerFeatCache() { @@ -129,6 +135,7 @@ class OnlineStream::Impl { TransducerKeywordResult empty_keyword_result_; OnlineCtcDecoderResult ctc_result_; std::vector states_; // states for transducer or ctc models + std::vector decoder_states_; // states for nemo transducer models std::vector paraformer_feat_cache_; std::vector paraformer_encoder_out_cache_; std::vector paraformer_alpha_cache_; @@ -218,6 +225,10 @@ std::vector &OnlineStream::GetStates() { return impl_->GetStates(); } +std::vector &OnlineStream::GetNeMoDecoderStates() { + return 
impl_->GetNeMoDecoderStates(); +} + const ContextGraphPtr &OnlineStream::GetContextGraph() const { return impl_->GetContextGraph(); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index 49b7f7402b..4e444366ee 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -91,6 +91,9 @@ class OnlineStream { void SetStates(std::vector states); std::vector &GetStates(); + void SetNeMoDecoderStates(std::vector decoder_states); + std::vector &GetNeMoDecoderStates(); + /** * Get the context graph corresponding to this stream. * diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index 27b3ecccfd..5f02255cd1 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -145,7 +145,7 @@ class OnlineTransducerNeMoModel::Impl { return ans; } - std::pair> RunEncoder(Ort::Value features, + std::pair, std::vector> RunEncoder(Ort::Value features, std::vector states, Ort::Value /* processed_frames */) { std::vector encoder_inputs; @@ -167,7 +167,7 @@ class OnlineTransducerNeMoModel::Impl { for (int32_t i = 1; i != static_cast(encoder_out.size()); ++i) { next_states.push_back(std::move(encoder_out[i])); } - return {std::move(encoder_out[0]), std::move(next_states)}; + return {std::move(encoder_out), std::move(next_states)}; } std::pair> RunDecoder( @@ -409,11 +409,11 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; -std::pair> +std::pair, std::vector> OnlineTransducerNeMoModel::RunEncoder(Ort::Value features, std::vector states, Ort::Value processed_frames) const { -return impl_->RunEncoder(std::move(features), std::move(states), std::move(processed_frames)); + return impl_->RunEncoder(std::move(features), std::move(states), std::move(processed_frames)); } std::pair> diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h 
b/sherpa-onnx/csrc/online-transducer-nemo-model.h index 4b6270f190..6f6779962a 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -75,7 +75,7 @@ class OnlineTransducerNeMoModel { * - encoder_out, a tensor of shape (N, T', encoder_out_dim) * - next_states Encoder state for the next chunk. */ - std::pair> RunEncoder( + std::pair, std::vector> RunEncoder( Ort::Value features, std::vector states, Ort::Value processed_frames) const; // NOLINT From afb10d46e393271e104118d85ac7159626aa6991 Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Thu, 23 May 2024 12:16:29 +0200 Subject: [PATCH 05/14] add support for nemo transducer --- sherpa-onnx/csrc/CMakeLists.txt | 1 + sherpa-onnx/csrc/online-recognizer-impl.cc | 29 ++++- .../online-recognizer-transducer-nemo-impl.h | 8 +- ...e-transducer-greedy-search-nemo-decoder.cc | 112 +++++++----------- ...ne-transducer-greedy-search-nemo-decoder.h | 17 +-- 5 files changed, 85 insertions(+), 82 deletions(-) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 5a7f165592..214fc94b45 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -75,6 +75,7 @@ set(sources online-transducer-model.cc online-transducer-modified-beam-search-decoder.cc online-transducer-nemo-model.cc + online-transducer-greedy-search-nemo-decoder.cc online-wenet-ctc-model-config.cc online-wenet-ctc-model.cc online-zipformer-transducer-model.cc diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 56da814f7a..42204baf84 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -7,13 +7,27 @@ #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" +#include 
"sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h" +#include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { std::unique_ptr OnlineRecognizerImpl::Create( const OnlineRecognizerConfig &config) { + if (!config.model_config.transducer.encoder.empty()) { - return std::make_unique(config); + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); + + auto decoder_model = ReadFile(config.model_config.transducer.decoder); + auto sess = std::make_unique(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); + + size_t node_count = sess->GetOutputCount(); + + if (node_count == 1) { + return std::make_unique(config); + } else { + return std::make_unique(config); + } } if (!config.model_config.paraformer.encoder.empty()) { @@ -34,7 +48,18 @@ std::unique_ptr OnlineRecognizerImpl::Create( std::unique_ptr OnlineRecognizerImpl::Create( AAssetManager *mgr, const OnlineRecognizerConfig &config) { if (!config.model_config.transducer.encoder.empty()) { - return std::make_unique(mgr, config); + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); + + auto decoder_model = ReadFile(config.model_config.transducer.decoder); + auto sess = std::make_unique(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); + + size_t node_count = sess->GetOutputCount(); + + if (node_count == 1) { + return std::make_unique(mgr, config); + } else { + return std::make_unique(mgr, config); + } } if (!config.model_config.paraformer.encoder.empty()) { diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 324eaccafe..8ac8e15f3a 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -131,11 +131,13 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { auto [t, ns] = model_->RunEncoder(std::move(x), std::move(states), std::move(processed_frames)); - + // t[0] encoder_out, float tensor, (batch_size, 
dim, T) + // t[1] encoder_out_length, int64 tensor, (batch_size,) + Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); // defined in online-transducer-greedy-search-nemo-decoder.h - auto results = decoder_-> Decode(std::move(encoder_out), std::move(t[1])); + std::vector results = decoder_-> Decode(std::move(encoder_out), std::move(t[1])); std::vector> next_states = model_->UnStackStates(ns); @@ -193,7 +195,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { std::unique_ptr model_; std::unique_ptr decoder_; - int32_t batch_size_ = 1; + int32_t batch_size_ = 1; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc index d3cba31d71..b8abb78ab0 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -1,55 +1,18 @@ // sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation #include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h" #include +#include #include -#include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { -static void UseCachedDecoderOut( - const std::vector &results, - Ort::Value *decoder_out) { - std::vector shape = - decoder_out->GetTensorTypeAndShapeInfo().GetShape(); - float *dst = decoder_out->GetTensorMutableData(); - for (const auto &r : results) { - if (r.decoder_out) { - const float *src = r.decoder_out.GetTensorData(); - std::copy(src, src + shape[1], dst); - } - dst += shape[1]; - } -} - -static void UpdateCachedDecoderOut( - OrtAllocator *allocator, const Ort::Value *decoder_out, - std::vector *results) { - std::vector shape = - decoder_out->GetTensorTypeAndShapeInfo().GetShape(); - auto memory_info = - 
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::array v_shape{1, shape[1]}; - - const float *src = decoder_out->GetTensorData(); - for (auto &r : *results) { - if (!r.decoder_out) { - r.decoder_out = Ort::Value::CreateTensor(allocator, v_shape.data(), - v_shape.size()); - } - - float *dst = r.decoder_out.GetTensorMutableData(); - std::copy(src, src + shape[1], dst); - src += shape[1]; - } -} - static std::pair BuildDecoderInput( int32_t token, OrtAllocator *allocator) { std::array shape{1, 1}; @@ -72,37 +35,44 @@ static std::pair BuildDecoderInput( return {std::move(decoder_input), std::move(decoder_input_length)}; } +OnlineTransducerGreedySearchNeMoDecoder::OnlineTransducerGreedySearchNeMoDecoder( + OnlineTransducerNeMoModel *model, float blank_penalty) + : model_(model), blank_penalty_(blank_penalty) { + // Initialize decoder state + auto init_states = model_->GetDecoderInitStates(1); + decoder_states_ = std::move(init_states); +} + static OnlineTransducerDecoderResult DecodeOne( - const float *p, int32_t num_rows, int32_t num_cols, - OnlineTransducerNeMoModel *model, float blank_penalty) { + const float *encoder_out, int32_t num_rows, int32_t num_cols, + OnlineTransducerNeMoModel *model, float blank_penalty, + std::vector& decoder_states) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - OnlineTransducerDecoderResult ans; - + OnlineTransducerDecoderResult result; int32_t vocab_size = model->VocabSize(); int32_t blank_id = vocab_size - 1; auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); - std::pair> decoder_output_pair = model->RunDecoder(std::move(decoder_input_pair.first), - std::move(decoder_input_pair.second), - model->GetDecoderInitStates(1)); + std::move(decoder_input_pair.second), + std::move(decoder_states)); std::array encoder_shape{1, num_cols, 1}; for (int32_t t = 0; t != num_rows; ++t) { Ort::Value cur_encoder_out = Ort::Value::CreateTensor( - memory_info, 
const_cast(p) + t * num_cols, num_cols, + memory_info, const_cast(encoder_out) + t * num_cols, num_cols, encoder_shape.data(), encoder_shape.size()); Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), - View(&decoder_output_pair.first)); + View(&decoder_output_pair.first)); float *p_logit = logit.GetTensorMutableData(); if (blank_penalty > 0) { - p_logit[blank_id] -= blank_penalty; } @@ -112,45 +82,47 @@ static OnlineTransducerDecoderResult DecodeOne( static_cast(p_logit) + vocab_size))); if (y != blank_id) { - ans.tokens.push_back(y); - ans.timestamps.push_back(t); + result.tokens.push_back(y); + result.timestamps.push_back(t); decoder_input_pair = BuildDecoderInput(y, model->Allocator()); decoder_output_pair = model->RunDecoder(std::move(decoder_input_pair.first), - std::move(decoder_input_pair.second), - std::move(decoder_output_pair.second)); - } // if (y != blank_id) - } // for (int32_t i = 0; i != num_rows; ++i) + std::move(decoder_input_pair.second), + std::move(decoder_output_pair.second)); + } // if (y != blank_id) + } // for (int32_t t = 0; t != num_rows; ++t) - return ans; + // Update the decoder states for the next chunk + decoder_states = std::move(decoder_output_pair.second); + + return result; } std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode( - Ort::Value encoder_out, - std::vector *result) { - auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); + Ort::Value encoder_out, Ort::Value encoder_out_length, + OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) { + auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); int32_t batch_size = static_cast(shape[0]); int32_t dim1 = static_cast(shape[1]); int32_t dim2 = static_cast(shape[2]); + const int64_t *p_length = encoder_out_length.GetTensorData(); const float *p = encoder_out.GetTensorData(); - - // checking for non-null elements in results - // create a new tensor with modified shape based on - // the first element of result and use cached decoder_out - // 
values if available. + std::vector ans(batch_size); - // For each frame (num of frames is given by dim2), compute logits, - // determine tokens, and update results, - // then regenerate decoder output - // if tokens are emitted. + for (int32_t i = 0; i != batch_size; ++i) { + const float *this_p = p + dim1 * dim2 * i; + int32_t this_len = p_length[i]; - // call UpdateCachedDecoderOut and update frame offset + ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states_); + } + + return ans; } -} // namespace sherpa_onnx \ No newline at end of file +} // namespace sherpa_onnx \ No newline at end of file diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h index d465f94c3b..1216cf0fa4 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -6,28 +6,31 @@ #define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ #include - #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-nemo-model.h" namespace sherpa_onnx { -class OnlineTransducerGreedySearchNeMoDecoder - : public OnlineTransducerDecoder { +class OnlineTransducerGreedySearchNeMoDecoder : public OnlineTransducerDecoder { public: OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, - float blank_penalty) - : model_(model), blank_penalty_(blank_penalty) {} + float blank_penalty); std::vector Decode( - Ort::Value encoder_out, Ort::Value encoder_out_length, - OnlineStream **ss = nullptr, int32_t n = 0) override; + Ort::Value encoder_out, Ort::Value encoder_out_length, + OnlineStream **ss = nullptr, int32_t n = 0); private: + OnlineTransducerDecoderResult DecodeChunkByChunk(const float *encoder_out, + int32_t num_rows, + int32_t num_cols); + OnlineTransducerNeMoModel *model_; // Not owned float blank_penalty_; + std::vector 
decoder_states_; // Decoder states to be maintained across chunks }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ + From 7800cc0ed827cf78a4513f662f44135194458b9d Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Thu, 23 May 2024 17:33:10 +0200 Subject: [PATCH 06/14] fixed deocder method to take states of previous chunks --- sherpa-onnx/csrc/online-recognizer-impl.cc | 2 +- .../online-recognizer-transducer-nemo-impl.h | 28 ++++--- sherpa-onnx/csrc/online-stream.cc | 4 + .../csrc/online-transducer-nemo-model.cc | 84 ++++++++++++++----- .../csrc/online-transducer-nemo-model.h | 24 +++--- 5 files changed, 93 insertions(+), 49 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 42204baf84..1ce1d54925 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -50,7 +50,7 @@ std::unique_ptr OnlineRecognizerImpl::Create( if (!config.model_config.transducer.encoder.empty()) { Ort::Env env(ORT_LOGGING_LEVEL_WARNING); - auto decoder_model = ReadFile(config.model_config.transducer.decoder); + auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder); auto sess = std::make_unique(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); size_t node_count = sess->GetOutputCount(); diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 8ac8e15f3a..9043260cac 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -1,6 +1,7 @@ // sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h // // Copyright (c) 2022-2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar #ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ #define 
SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ @@ -24,7 +25,6 @@ #include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h" #include "sherpa-onnx/csrc/online-transducer-nemo-model.h" -#include "sherpa-onnx/csrc/pad-sequence.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/transpose.h" #include "sherpa-onnx/csrc/utils.h" @@ -80,6 +80,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { std::unique_ptr CreateStream() const override { auto stream = std::make_unique(config_.feat_config); + stream->SetStates(model_->GetInitStates()); InitOnlineStream(stream.get()); return stream; } @@ -120,27 +121,27 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { features_vec.size(), x_shape.data(), x_shape.size()); - std::array processed_frames_shape{ - static_cast(all_processed_frames.size())}; - - Ort::Value processed_frames = Ort::Value::CreateTensor( - memory_info, all_processed_frames.data(), all_processed_frames.size(), - processed_frames_shape.data(), processed_frames_shape.size()); - auto states = model_->StackStates(states_vec); - - auto [t, ns] = model_->RunEncoder(std::move(x), std::move(states), - std::move(processed_frames)); + int32_t num_states = states.size(); + auto t = model_->RunEncoder(std::move(x), std::move(states)); // t[0] encoder_out, float tensor, (batch_size, dim, T) // t[1] encoder_out_length, int64 tensor, (batch_size,) + std::vector out_states; + out_states.reserve(num_states); + + for (int32_t k = 1; k != num_states + 1; ++k) { + out_states.push_back(std::move(t[k])); + } + Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); // defined in online-transducer-greedy-search-nemo-decoder.h - std::vector results = decoder_-> Decode(std::move(encoder_out), std::move(t[1])); + decoder_-> Decode(std::move(encoder_out), std::move(t[1]), + std::move(out_states), &results, ss, n); std::vector> 
next_states = - model_->UnStackStates(ns); + model_->UnStackStates(out_states); for (int32_t i = 0; i != n; ++i) { ss[i]->SetResult(results[i]); @@ -187,6 +188,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { symbol_table_.NumSymbols(), vocab_size); exit(-1); } + } private: diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index b1d35dc553..62d93999e0 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -225,6 +225,10 @@ std::vector &OnlineStream::GetStates() { return impl_->GetStates(); } +void OnlineStream::SetNeMoDecoderStates(std::vector decoder_states) { + return impl_->SetNeMoDecoderStates(std::move(decoder_states)); +} + std::vector &OnlineStream::GetNeMoDecoderStates() { return impl_->GetNeMoDecoderStates(); } diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index 5f02255cd1..9f9fe5b625 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -1,6 +1,7 @@ // sherpa-onnx/csrc/online-transducer-nemo-model.cc // // Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar #include "sherpa-onnx/csrc/online-transducer-nemo-model.h" @@ -145,29 +146,51 @@ class OnlineTransducerNeMoModel::Impl { return ans; } - std::pair, std::vector> RunEncoder(Ort::Value features, - std::vector states, - Ort::Value /* processed_frames */) { - std::vector encoder_inputs; - encoder_inputs.reserve(1 + states.size()); + std::vector RunEncoder(Ort::Value features, + std::vector states) { + Ort::Value &cache_last_channel = states[0]; + Ort::Value &cache_last_time = states[1]; + Ort::Value &cache_last_channel_len = states[2]; - encoder_inputs.push_back(std::move(features)); - for (auto &v : states) { - encoder_inputs.push_back(std::move(v)); - } + int32_t batch_size = features.GetTensorTypeAndShapeInfo().GetShape()[0]; + + std::array 
length_shape{batch_size}; + + Ort::Value length = Ort::Value::CreateTensor( + allocator_, length_shape.data(), length_shape.size()); + + int64_t *p_length = length.GetTensorMutableData(); + + std::fill(p_length, p_length + batch_size, ChunkSize()); - auto encoder_out = encoder_sess_->Run( - {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), - encoder_inputs.size(), encoder_output_names_ptr_.data(), - encoder_output_names_ptr_.size()); + // (B, T, C) -> (B, C, T) + features = Transpose12(allocator_, &features); - std::vector next_states; - next_states.reserve(states.size()); + std::array inputs = { + std::move(features), View(&length), std::move(cache_last_channel), + std::move(cache_last_time), std::move(cache_last_channel_len)}; + + auto out = + encoder_sess_->Run({}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(), + encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); + // out[0]: logit + // out[1] logit_length + // out[2:] states_next + // + // we need to remove out[1] + + std::vector ans; + ans.reserve(out.size() - 1); + + for (int32_t i = 0; i != out.size(); ++i) { + if (i == 1) { + continue; + } - for (int32_t i = 1; i != static_cast(encoder_out.size()); ++i) { - next_states.push_back(std::move(encoder_out[i])); + ans.push_back(std::move(out[i])); } - return {std::move(encoder_out), std::move(next_states)}; + + return ans; } std::pair> RunDecoder( @@ -250,6 +273,20 @@ class OnlineTransducerNeMoModel::Impl { std::string FeatureNormalizationMethod() const { return normalize_type_; } + // Return a vector containing 3 tensors + // - cache_last_channel + // - cache_last_time_ + // - cache_last_channel_len + std::vector GetInitStates() { + std::vector ans; + ans.reserve(3); + ans.push_back(View(&cache_last_channel_)); + ans.push_back(View(&cache_last_time_)); + ans.push_back(View(&cache_last_channel_len_)); + + return ans; + } + private: void InitEncoder(void *model_data, size_t model_data_length) { encoder_sess_ = 
std::make_unique( @@ -409,11 +446,10 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; -std::pair, std::vector> +std::vector OnlineTransducerNeMoModel::RunEncoder(Ort::Value features, - std::vector states, - Ort::Value processed_frames) const { - return impl_->RunEncoder(std::move(features), std::move(states), std::move(processed_frames)); + std::vector states) const { + return impl_->RunEncoder(std::move(features), std::move(states)); } std::pair> @@ -459,4 +495,8 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const { return impl_->FeatureNormalizationMethod(); } +std::vector OnlineTransducerNeMoModel::GetInitStates() const { + return impl_->GetInitStates(); +} + } // namespace sherpa_onnx \ No newline at end of file diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h index 6f6779962a..8d9f926738 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -1,6 +1,8 @@ // sherpa-onnx/csrc/online-transducer-nemo-model.h // // Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar + #ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ #define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ @@ -58,26 +60,22 @@ class OnlineTransducerNeMoModel { std::vector> UnStackStates( const std::vector &states) const; - // /** Get the initial encoder states. - // * - // * @return Return the initial encoder state. - // */ - // std::vector GetEncoderInitStates() = 0; + // A list of 3 tensors: + // - cache_last_channel + // - cache_last_time + // - cache_last_channel_len + std::vector GetInitStates() const; /** Run the encoder. * * @param features A tensor of shape (N, T, C). It is changed in-place. - * @param states Encoder state of the previous chunk. It is changed in-place. - * @param processed_frames Processed frames before subsampling. 
It is a 1-D - * tensor with data type int64_t. - * + * @param states It is from GetInitStates() or returned from this method. + * * @return Return a tuple containing: * - encoder_out, a tensor of shape (N, T', encoder_out_dim) - * - next_states Encoder state for the next chunk. */ - std::pair, std::vector> RunEncoder( - Ort::Value features, std::vector states, - Ort::Value processed_frames) const; // NOLINT + std::vector RunEncoder( + Ort::Value features, std::vector states) const; // NOLINT /** Run the decoder network. * From d47bf6fc32b265d1dfd8ee09f7838ebb59106dab Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Thu, 23 May 2024 18:19:35 +0200 Subject: [PATCH 07/14] minor changes --- .../online-recognizer-transducer-nemo-impl.h | 2 +- ...e-transducer-greedy-search-nemo-decoder.cc | 29 +++++++++++-------- ...ne-transducer-greedy-search-nemo-decoder.h | 7 +++-- .../csrc/online-transducer-nemo-model.h | 6 ---- 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 9043260cac..4e1867d3a6 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -195,7 +195,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { OnlineRecognizerConfig config_; SymbolTable symbol_table_; std::unique_ptr model_; - std::unique_ptr decoder_; + std::unique_ptr decoder_; int32_t batch_size_ = 1; }; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc index b8abb78ab0..030255db34 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -35,15 +35,15 @@ static std::pair BuildDecoderInput( return {std::move(decoder_input), 
std::move(decoder_input_length)}; } -OnlineTransducerGreedySearchNeMoDecoder::OnlineTransducerGreedySearchNeMoDecoder( - OnlineTransducerNeMoModel *model, float blank_penalty) - : model_(model), blank_penalty_(blank_penalty) { - // Initialize decoder state - auto init_states = model_->GetDecoderInitStates(1); - decoder_states_ = std::move(init_states); -} - -static OnlineTransducerDecoderResult DecodeOne( +// OnlineTransducerGreedySearchNeMoDecoder::OnlineTransducerGreedySearchNeMoDecoder( +// OnlineTransducerNeMoModel *model, float blank_penalty) +// : model_(model), blank_penalty_(blank_penalty) { +// // Initialize decoder state +// auto init_states = model_->GetDecoderInitStates(1); +// decoder_states_ = std::move(init_states); +// } + +std::pair> DecodeOne( const float *encoder_out, int32_t num_rows, int32_t num_cols, OnlineTransducerNeMoModel *model, float blank_penalty, std::vector& decoder_states) { @@ -97,12 +97,15 @@ static OnlineTransducerDecoderResult DecodeOne( // Update the decoder states for the next chunk decoder_states = std::move(decoder_output_pair.second); - return result; + return {result, decoder_states}; } std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode( - Ort::Value encoder_out, Ort::Value encoder_out_length, + Ort::Value encoder_out, + Ort::Value encoder_out_length, + std::vector decoder_states, + std::vector *results, OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) { auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); @@ -119,7 +122,9 @@ OnlineTransducerGreedySearchNeMoDecoder::Decode( const float *this_p = p + dim1 * dim2 * i; int32_t this_len = p_length[i]; - ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states_); + auto decode_result_pair = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states); + ans[i] = decode_result_pair.first; + decoder_states = std::move(decode_result_pair.second); // Update decoder states for next chunk } return ans; diff --git 
a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h index 1216cf0fa4..5b112bb44d 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -17,7 +17,10 @@ class OnlineTransducerGreedySearchNeMoDecoder : public OnlineTransducerDecoder { float blank_penalty); std::vector Decode( - Ort::Value encoder_out, Ort::Value encoder_out_length, + Ort::Value encoder_out, + Ort::Value encoder_out_length, + std::vector decoder_states, + std::vector *results, OnlineStream **ss = nullptr, int32_t n = 0); private: @@ -27,7 +30,7 @@ class OnlineTransducerGreedySearchNeMoDecoder : public OnlineTransducerDecoder { OnlineTransducerNeMoModel *model_; // Not owned float blank_penalty_; - std::vector decoder_states_; // Decoder states to be maintained across chunks + // std::vector decoder_states_; // Decoder states to be maintained across chunks }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h index 8d9f926738..f1f46932cc 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -60,12 +60,6 @@ class OnlineTransducerNeMoModel { std::vector> UnStackStates( const std::vector &states) const; - // A list of 3 tensors: - // - cache_last_channel - // - cache_last_time - // - cache_last_channel_len - std::vector GetInitStates() const; - /** Run the encoder. * * @param features A tensor of shape (N, T, C). It is changed in-place. 
From 7837a5d976951d66e6e560eb79b911c958c02e86 Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Fri, 24 May 2024 12:57:41 +0200 Subject: [PATCH 08/14] doc updated, model definitions modified --- .../online-recognizer-transducer-nemo-impl.h | 13 +++++------ ...e-transducer-greedy-search-nemo-decoder.cc | 22 ++++++++++++++----- ...ne-transducer-greedy-search-nemo-decoder.h | 1 - .../csrc/online-transducer-nemo-model.cc | 20 ++++++++++++----- .../csrc/online-transducer-nemo-model.h | 16 +++++--------- 5 files changed, 41 insertions(+), 31 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 4e1867d3a6..20539d2173 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -94,8 +94,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { std::vector results(n); std::vector features_vec(n * chunk_size * feature_dim); std::vector> states_vec(n); - std::vector all_processed_frames(n); - + for (int32_t i = 0; i != n; ++i) { const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); std::vector features = @@ -109,7 +108,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { results[i] = std::move(ss[i]->GetResult()); states_vec[i] = std::move(ss[i]->GetStates()); - all_processed_frames[i] = num_processed_frames; + } auto memory_info = @@ -125,7 +124,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { int32_t num_states = states.size(); auto t = model_->RunEncoder(std::move(x), std::move(states)); // t[0] encoder_out, float tensor, (batch_size, dim, T) - // t[1] encoder_out_length, int64 tensor, (batch_size,) + // t[1] next states std::vector out_states; out_states.reserve(num_states); @@ -137,8 +136,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { Ort::Value encoder_out = 
Transpose12(model_->Allocator(), &t[0]); // defined in online-transducer-greedy-search-nemo-decoder.h - decoder_-> Decode(std::move(encoder_out), std::move(t[1]), - std::move(out_states), &results, ss, n); + decoder_-> Decode(std::move(encoder_out), std::move(out_states), &results, ss, n); std::vector> next_states = model_->UnStackStates(out_states); @@ -153,7 +151,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { auto r = decoder_->GetEmptyResult(); stream->SetResult(r); - stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(batch_size_)); + stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1)); } private: @@ -197,7 +195,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { std::unique_ptr model_; std::unique_ptr decoder_; - int32_t batch_size_ = 1; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc index 030255db34..35c3aa9cad 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -56,9 +56,10 @@ std::pair> DecodeOne( int32_t blank_id = vocab_size - 1; auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); + // decoder_input_pair[0]: decoder_input + // decoder_input_pair[1]: decoder_input_length (discarded) std::pair> decoder_output_pair = model->RunDecoder(std::move(decoder_input_pair.first), - std::move(decoder_input_pair.second), std::move(decoder_states)); std::array encoder_shape{1, num_cols, 1}; @@ -89,7 +90,6 @@ std::pair> DecodeOne( decoder_output_pair = model->RunDecoder(std::move(decoder_input_pair.first), - std::move(decoder_input_pair.second), std::move(decoder_output_pair.second)); } // if (y != blank_id) } // for (int32_t t = 0; t != num_rows; ++t) @@ -103,15 +103,25 @@ std::pair> DecodeOne( std::vector 
OnlineTransducerGreedySearchNeMoDecoder::Decode( Ort::Value encoder_out, - Ort::Value encoder_out_length, std::vector decoder_states, std::vector *results, OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) { auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); - int32_t batch_size = static_cast(shape[0]); - int32_t dim1 = static_cast(shape[1]); - int32_t dim2 = static_cast(shape[2]); + int32_t batch_size = static_cast(shape[0]); // bs = 1 + int32_t dim1 = static_cast(shape[1]); // feature dimension + int32_t dim2 = static_cast(shape[2]); // frames + + // Define and initialize encoder_out_length + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + + int64_t length_value = 1; + std::vector length_shape = {1}; + + Ort::Value encoder_out_length = Ort::Value::CreateTensor( + memory_info, &length_value, 1, length_shape.data(), length_shape.size() + ); + const int64_t *p_length = encoder_out_length.GetTensorData(); const float *p = encoder_out.GetTensorData(); diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h index 5b112bb44d..7d2add5ae0 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -18,7 +18,6 @@ class OnlineTransducerGreedySearchNeMoDecoder : public OnlineTransducerDecoder { std::vector Decode( Ort::Value encoder_out, - Ort::Value encoder_out_length, std::vector decoder_states, std::vector *results, OnlineStream **ss = nullptr, int32_t n = 0); diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index 9f9fe5b625..5b91570d95 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -194,8 +194,18 @@ class OnlineTransducerNeMoModel::Impl { } std::pair> RunDecoder( - Ort::Value targets, Ort::Value 
targets_length, - std::vector states) { + Ort::Value targets, std::vector states) { + + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + + // Create the tensor with a single int32_t value of 1 + int32_t length_value = 1; + std::vector length_shape = {1}; + + Ort::Value targets_length = Ort::Value::CreateTensor( + memory_info, &length_value, 1, length_shape.data(), length_shape.size() + ); + std::vector decoder_inputs; decoder_inputs.reserve(2 + states.size()); @@ -215,7 +225,7 @@ class OnlineTransducerNeMoModel::Impl { states_next.reserve(states.size()); // decoder_out[0]: decoder_output - // decoder_out[1]: decoder_output_length + // decoder_out[1]: decoder_output_length (discarded) // decoder_out[2:] states_next for (int32_t i = 0; i != states.size(); ++i) { @@ -454,10 +464,8 @@ OnlineTransducerNeMoModel::RunEncoder(Ort::Value features, std::pair> OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets, - Ort::Value targets_length, std::vector states) const { - return impl_->RunDecoder(std::move(targets), std::move(targets_length), - std::move(states)); + return impl_->RunDecoder(std::move(targets), std::move(states)); } std::vector OnlineTransducerNeMoModel::GetDecoderInitStates( diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h index f1f46932cc..adeb41c4e7 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -30,7 +30,7 @@ class OnlineTransducerNeMoModel { #if __ANDROID_API__ >= 9 OnlineTransducerNeMoModel(AAssetManager *mgr, - const OfflineModelConfig &config); + const OnlineModelConfig &config); #endif ~OnlineTransducerNeMoModel(); @@ -50,7 +50,7 @@ class OnlineTransducerNeMoModel { std::vector StackStates( const std::vector> &states) const; - /** Unstack a batch state into a list of individual states. + /** Unstack a batched state into a list of individual states. 
* * It is the inverse operation of `StackStates`. * @@ -66,7 +66,8 @@ class OnlineTransducerNeMoModel { * @param states It is from GetInitStates() or returned from this method. * * @return Return a tuple containing: - * - encoder_out, a tensor of shape (N, T', encoder_out_dim) + * - ans[0]: encoder_out, a tensor of shape (N, T', encoder_out_dim) + * - ans[1:]: contains next states */ std::vector RunEncoder( Ort::Value features, std::vector states) const; // NOLINT @@ -74,16 +75,13 @@ class OnlineTransducerNeMoModel { /** Run the decoder network. * * @param targets A int32 tensor of shape (batch_size, 1) - * @param targets_length A int32 tensor of shape (batch_size,) * @param states The states for the decoder model. * @return Return a vector: * - ans[0] is the decoder_out (a float tensor) - * - ans[1] is the decoder_out_length (a int32 tensor) - * - ans[2:] is the states_next + * - ans[1:] is the next states */ std::pair> RunDecoder( - Ort::Value targets, Ort::Value targets_length, - std::vector states) const; + Ort::Value targets, std::vector states) const; std::vector GetDecoderInitStates(int32_t batch_size) const; @@ -96,8 +94,6 @@ class OnlineTransducerNeMoModel { virtual Ort::Value RunJoiner( Ort::Value encoder_out, Ort::Value decoder_out) const; - // cache_last_time_dim3 in the model meta_data - int32_t ContextSize() const; /** We send this number of feature frames to the encoder at a time. */ int32_t ChunkSize() const; From 4c3e741821c002d3b92210b8130f29d835b9bc26 Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Fri, 24 May 2024 20:11:31 +0200 Subject: [PATCH 09/14] more fixes, bugs... 
--- .../online-recognizer-transducer-nemo-impl.h | 25 +++++++++++-------- sherpa-onnx/csrc/online-transducer-decoder.h | 5 ++++ ...e-transducer-greedy-search-nemo-decoder.cc | 19 ++++++-------- ...ne-transducer-greedy-search-nemo-decoder.h | 20 +++++++-------- ...-transducer-modified-beam-search-decoder.h | 5 ++++ 5 files changed, 40 insertions(+), 34 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 20539d2173..5cc8d85738 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -44,15 +44,15 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { explicit OnlineRecognizerTransducerNeMoImpl( const OnlineRecognizerConfig &config) : config_(config), - symbol_table_(config_.model_config.tokens), + symbol_table_(config.model_config.tokens), model_(std::make_unique( - config_.model_config)) { - if (config_.decoding_method == "greedy_search") { + config.model_config)) { + if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( model_.get(), config_.blank_penalty); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", - config_.decoding_method.c_str()); + config.decoding_method.c_str()); exit(-1); } PostInit(); @@ -62,15 +62,15 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { explicit OnlineRecognizerTransducerNeMoImpl( AAssetManager *mgr, const OnlineRecognizerConfig &config) : config_(config), - symbol_table_(mgr, config_.model_config.tokens), + symbol_table_(mgr, config.model_config.tokens), model_(std::make_unique( - mgr, config_.model_config)) { - if (config_.decoding_method == "greedy_search") { + mgr, config.model_config)) { + if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( model_.get(), config_.blank_penalty); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", - 
config_.decoding_method.c_str()); + config.decoding_method.c_str()); exit(-1); } @@ -136,10 +136,13 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); // defined in online-transducer-greedy-search-nemo-decoder.h - decoder_-> Decode(std::move(encoder_out), std::move(out_states), &results, ss, n); + std::vector decoder_states = model_->GetDecoderInitStates(1); + decoder_states = decoder_->Decode(std::move(encoder_out), + std::move(decoder_states), + &results, ss, n); std::vector> next_states = - model_->UnStackStates(out_states); + model_->UnStackStates(decoder_states); for (int32_t i = 0; i != n; ++i) { ss[i]->SetResult(results[i]); @@ -193,7 +196,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { OnlineRecognizerConfig config_; SymbolTable symbol_table_; std::unique_ptr model_; - std::unique_ptr decoder_; + std::unique_ptr decoder_; }; diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index 6265366044..691b896a71 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -82,6 +82,11 @@ class OnlineTransducerDecoder { virtual void Decode(Ort::Value encoder_out, std::vector *result) = 0; + virtual std::vector Decode(Ort::Value encoder_out, + std::vector decoder_states, + std::vector *results, + OnlineStream **ss = nullptr, int32_t n = 0) = 0; + /** Run transducer beam search given the output from the encoder model. 
* * Note: Currently this interface is for contextual-biasing feature which diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc index 35c3aa9cad..7d20aafae5 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -35,14 +35,6 @@ static std::pair BuildDecoderInput( return {std::move(decoder_input), std::move(decoder_input_length)}; } -// OnlineTransducerGreedySearchNeMoDecoder::OnlineTransducerGreedySearchNeMoDecoder( -// OnlineTransducerNeMoModel *model, float blank_penalty) -// : model_(model), blank_penalty_(blank_penalty) { -// // Initialize decoder state -// auto init_states = model_->GetDecoderInitStates(1); -// decoder_states_ = std::move(init_states); -// } - std::pair> DecodeOne( const float *encoder_out, int32_t num_rows, int32_t num_cols, OnlineTransducerNeMoModel *model, float blank_penalty, @@ -100,8 +92,7 @@ std::pair> DecodeOne( return {result, decoder_states}; } -std::vector -OnlineTransducerGreedySearchNeMoDecoder::Decode( +std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode( Ort::Value encoder_out, std::vector decoder_states, std::vector *results, @@ -122,7 +113,6 @@ OnlineTransducerGreedySearchNeMoDecoder::Decode( memory_info, &length_value, 1, length_shape.data(), length_shape.size() ); - const int64_t *p_length = encoder_out_length.GetTensorData(); const float *p = encoder_out.GetTensorData(); @@ -135,9 +125,14 @@ OnlineTransducerGreedySearchNeMoDecoder::Decode( auto decode_result_pair = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states); ans[i] = decode_result_pair.first; decoder_states = std::move(decode_result_pair.second); // Update decoder states for next chunk + + if (results != nullptr && i < results->size()) { + (*results)[i] = ans[i]; + } } + + return decoder_states; - return ans; } } // namespace sherpa_onnx \ No newline 
at end of file diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h index 7d2add5ae0..f5fb98eb8f 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -14,22 +14,20 @@ namespace sherpa_onnx { class OnlineTransducerGreedySearchNeMoDecoder : public OnlineTransducerDecoder { public: OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, - float blank_penalty); + float blank_penalty) + : model_(model), + blank_penalty_(blank_penalty){} - std::vector Decode( - Ort::Value encoder_out, - std::vector decoder_states, - std::vector *results, - OnlineStream **ss = nullptr, int32_t n = 0); - private: - OnlineTransducerDecoderResult DecodeChunkByChunk(const float *encoder_out, - int32_t num_rows, - int32_t num_cols); + std::vector Decode( + Ort::Value encoder_out, + std::vector decoder_states, + std::vector *results, + OnlineStream **ss = nullptr, int32_t n = 0) override; + private: OnlineTransducerNeMoModel *model_; // Not owned float blank_penalty_; - // std::vector decoder_states_; // Decoder states to be maintained across chunks }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h index 839aa768a4..5cac1c0b20 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -39,6 +39,11 @@ class OnlineTransducerModifiedBeamSearchDecoder void Decode(Ort::Value encoder_out, std::vector *result) override; + std::vector Decode(Ort::Value encoder_out, + std::vector decoder_states, + std::vector *results, + OnlineStream **ss = nullptr, int32_t n = 0) override; + void Decode(Ort::Value encoder_out, OnlineStream **ss, std::vector *result) override; From 
a5c9cc8296ddba329ee18fef8d983f4eb7570494 Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Sat, 25 May 2024 01:40:50 +0200 Subject: [PATCH 10/14] revert changes --- .../csrc/online-transducer-modified-beam-search-decoder.cc | 2 +- .../csrc/online-transducer-modified-beam-search-decoder.h | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index ea3f78f4bc..07dde94337 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -250,4 +250,4 @@ void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut( result->decoder_out = model_->RunDecoder(std::move(decoder_input)); } -} // namespace sherpa_onnx +} // namespace sherpa_onnx \ No newline at end of file diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h index 5cac1c0b20..36324296b0 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -39,11 +39,6 @@ class OnlineTransducerModifiedBeamSearchDecoder void Decode(Ort::Value encoder_out, std::vector *result) override; - std::vector Decode(Ort::Value encoder_out, - std::vector decoder_states, - std::vector *results, - OnlineStream **ss = nullptr, int32_t n = 0) override; - void Decode(Ort::Value encoder_out, OnlineStream **ss, std::vector *result) override; @@ -62,4 +57,4 @@ class OnlineTransducerModifiedBeamSearchDecoder } // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ \ No newline at end of file From 6608ec3d73aab994146d7f8c56083f4178a9524c Mon Sep 17 00:00:00 2001 
From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Sat, 25 May 2024 02:04:03 +0200 Subject: [PATCH 11/14] Revert files to commit 7837a5d976951d66e6e560eb79b911c958c02e86 --- .../csrc/online-transducer-modified-beam-search-decoder.cc | 2 +- .../csrc/online-transducer-modified-beam-search-decoder.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index 07dde94337..ea3f78f4bc 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -250,4 +250,4 @@ void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut( result->decoder_out = model_->RunDecoder(std::move(decoder_input)); } -} // namespace sherpa_onnx \ No newline at end of file +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h index 36324296b0..839aa768a4 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -57,4 +57,4 @@ class OnlineTransducerModifiedBeamSearchDecoder } // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ \ No newline at end of file +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ From 72a45c2c7e07cdc9bebc4ceb7e888b33088231d2 Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Mon, 27 May 2024 10:06:37 +0200 Subject: [PATCH 12/14] add missing methods in online-recognizer-transducer-nemo-impl.h, other cosmetic changes --- .../online-recognizer-transducer-nemo-impl.h | 75 +++++++++++++++++-- sherpa-onnx/csrc/online-transducer-decoder.h | 4 +- ...e-transducer-greedy-search-nemo-decoder.cc | 26 +++++-- 
...ne-transducer-greedy-search-nemo-decoder.h | 14 ++-- .../csrc/online-transducer-nemo-model.h | 4 +- 5 files changed, 100 insertions(+), 23 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 5cc8d85738..6a12d65e51 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -32,7 +32,8 @@ namespace sherpa_onnx { // defined in ./online-recognizer-transducer-impl.h -OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, +// static may or may not be here? TODOs +static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, const SymbolTable &sym_table, float frame_shift_ms, int32_t subsampling_factor, @@ -45,6 +46,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { const OnlineRecognizerConfig &config) : config_(config), symbol_table_(config.model_config.tokens), + endpoint_(config_.endpoint_config), model_(std::make_unique( config.model_config)) { if (config.decoding_method == "greedy_search") { @@ -63,6 +65,7 @@ AAssetManager *mgr, const OnlineRecognizerConfig &config) : config_(config), symbol_table_(mgr, config.model_config.tokens), + endpoint_(config_.endpoint_config), model_(std::make_unique( mgr, config.model_config)) { if (config.decoding_method == "greedy_search") { @@ -85,13 +88,70 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { return stream; } + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkSize() < + s->NumFramesReady(); + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + OnlineTransducerDecoderResult decoder_result = s->GetResult(); + decoder_->StripLeadingBlanks(&decoder_result); + + // TODO(fangjun): Remember to change these constants if
needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) { + return false; + } + + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + // subsampling factor is 4 + int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + { + // segment is incremented only when the last + // result is not empty + const auto &r = s->GetResult(); + if (!r.tokens.empty() && r.tokens.back() != 0) { + s->GetCurrentSegment() += 1; + } + } + + // we keep the decoder_out + decoder_->UpdateDecoderOut(&s->GetResult()); + Ort::Value decoder_out = std::move(s->GetResult().decoder_out); + + auto r = decoder_->GetEmptyResult(); + + s->SetResult(r); + s->GetResult().decoder_out = std::move(decoder_out); + + // Note: We only update counters. The underlying audio samples + // are not discarded. 
+ s->Reset(); + } + void DecodeStreams(OnlineStream **ss, int32_t n) const override { int32_t chunk_size = model_->ChunkSize(); int32_t chunk_shift = model_->ChunkShift(); int32_t feature_dim = ss[0]->FeatureDim(); - std::vector results(n); + std::vector result(n); std::vector features_vec(n * chunk_size * feature_dim); std::vector> states_vec(n); @@ -106,7 +166,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { std::copy(features.begin(), features.end(), features_vec.data() + i * chunk_size * feature_dim); - results[i] = std::move(ss[i]->GetResult()); + result[i] = std::move(ss[i]->GetResult()); states_vec[i] = std::move(ss[i]->GetStates()); } @@ -137,15 +197,15 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { // defined in online-transducer-greedy-search-nemo-decoder.h std::vector decoder_states = model_->GetDecoderInitStates(1); - decoder_states = decoder_->Decode(std::move(encoder_out), + decoder_states = decoder_->Decode_me(std::move(encoder_out), std::move(decoder_states), - &results, ss, n); + &result, ss, n); std::vector> next_states = model_->UnStackStates(decoder_states); for (int32_t i = 0; i != n; ++i) { - ss[i]->SetResult(results[i]); + ss[i]->SetResult(result[i]); ss[i]->SetNeMoDecoderStates(std::move(next_states[i])); } } @@ -154,7 +214,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { auto r = decoder_->GetEmptyResult(); stream->SetResult(r); - stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1)); + // stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1)); } private: @@ -197,6 +257,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { SymbolTable symbol_table_; std::unique_ptr model_; std::unique_ptr decoder_; + Endpoint endpoint_; }; diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index 691b896a71..0b3ababb46 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ 
b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -82,9 +82,9 @@ class OnlineTransducerDecoder { virtual void Decode(Ort::Value encoder_out, std::vector *result) = 0; - virtual std::vector Decode(Ort::Value encoder_out, + virtual std::vector Decode_me(Ort::Value encoder_out, std::vector decoder_states, - std::vector *results, + std::vector *result, OnlineStream **ss = nullptr, int32_t n = 0) = 0; /** Run transducer beam search given the output from the encoder model. diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc index 7d20aafae5..7ac91695b9 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -1,6 +1,7 @@ // sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc // -// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar #include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h" @@ -35,11 +36,24 @@ static std::pair BuildDecoderInput( return {std::move(decoder_input), std::move(decoder_input_length)}; } +OnlineTransducerDecoderResult +OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const { + int32_t context_size = 8; + int32_t blank_id = 0; // always 0 + OnlineTransducerDecoderResult r; + r.tokens.resize(context_size, -1); + r.tokens.back() = blank_id; + + return r; +} + + std::pair> DecodeOne( const float *encoder_out, int32_t num_rows, int32_t num_cols, OnlineTransducerNeMoModel *model, float blank_penalty, std::vector& decoder_states) { + // num_rows = frames auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); @@ -89,13 +103,13 @@ std::pair> DecodeOne( // Update the decoder states for the next chunk decoder_states = std::move(decoder_output_pair.second); - return {result, decoder_states}; + return {result, std::move(decoder_states)}; } 
-std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode( +std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode_me( Ort::Value encoder_out, std::vector decoder_states, - std::vector *results, + std::vector *result, OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) { auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); @@ -126,8 +140,8 @@ std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode( ans[i] = decode_result_pair.first; decoder_states = std::move(decode_result_pair.second); // Update decoder states for next chunk - if (results != nullptr && i < results->size()) { - (*results)[i] = ans[i]; + if (result != nullptr && i < result->size()) { + (*result)[i] = ans[i]; } } diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h index f5fb98eb8f..d78f3f1250 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -1,6 +1,7 @@ // sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h // // Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar #ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ #define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ @@ -18,12 +19,13 @@ class OnlineTransducerGreedySearchNeMoDecoder : public OnlineTransducerDecoder { : model_(model), blank_penalty_(blank_penalty){} - - std::vector Decode( - Ort::Value encoder_out, - std::vector decoder_states, - std::vector *results, - OnlineStream **ss = nullptr, int32_t n = 0) override; + OnlineTransducerDecoderResult GetEmptyResult() const override; + + std::vector Decode_me( + Ort::Value encoder_out, + std::vector decoder_states, + std::vector *result, + OnlineStream **ss = nullptr, int32_t n = 0) override; private: OnlineTransducerNeMoModel *model_; // Not owned diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h 
b/sherpa-onnx/csrc/online-transducer-nemo-model.h index adeb41c4e7..e236221c5b 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -91,8 +91,8 @@ class OnlineTransducerNeMoModel { * @param decoder_out Output of the decoder network. * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. */ - virtual Ort::Value RunJoiner( Ort::Value encoder_out, - Ort::Value decoder_out) const; + Ort::Value RunJoiner(Ort::Value encoder_out, + Ort::Value decoder_out) const; /** We send this number of feature frames to the encoder at a time. */ From e1613b6d89550f567deb448915511ba9f4566b2d Mon Sep 17 00:00:00 2001 From: Sangeet Sagar <15uec053@lnmiit.ac.in> Date: Tue, 28 May 2024 10:51:29 +0200 Subject: [PATCH 13/14] remove methods not needed, online-recognizer-transducer-nemo-impl.h updated, compilation success, decoding not working yet --- sherpa-onnx/csrc/online-recognizer-impl.cc | 1 + .../online-recognizer-transducer-nemo-impl.h | 31 ++++--- sherpa-onnx/csrc/online-transducer-decoder.h | 5 - ...e-transducer-greedy-search-nemo-decoder.cc | 92 ++++++++++++++----- ...ne-transducer-greedy-search-nemo-decoder.h | 12 ++- .../csrc/online-transducer-nemo-model.cc | 69 -------------- .../csrc/online-transducer-nemo-model.h | 22 +---- 7 files changed, 95 insertions(+), 137 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 1ce1d54925..a6f3980166 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -26,6 +26,7 @@ std::unique_ptr OnlineRecognizerImpl::Create( if (node_count == 1) { return std::make_unique(config); } else { + SHERPA_ONNX_LOGE("Running streaming Nemo transducer model"); return std::make_unique(config); } } diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 6a12d65e51..716f12c7fd 100644 
--- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -99,7 +99,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { // TODO(fangjun): Remember to change these constants if needed int32_t frame_shift_ms = 10; - int32_t subsampling_factor = 4; + int32_t subsampling_factor = 8; return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor, s->GetCurrentSegment(), s->GetNumFramesSinceStart()); } @@ -114,8 +114,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { // frame shift is 10 milliseconds float frame_shift_in_seconds = 0.01; - // subsampling factor is 4 - int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4; + // subsampling factor is 8 + int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8; return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, frame_shift_in_seconds); @@ -180,7 +180,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { features_vec.size(), x_shape.data(), x_shape.size()); - auto states = model_->StackStates(states_vec); + // Batch size is 1 + auto states = std::move(states_vec[0]); int32_t num_states = states.size(); auto t = model_->RunEncoder(std::move(x), std::move(states)); // t[0] encoder_out, float tensor, (batch_size, dim, T) @@ -196,25 +197,27 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); // defined in online-transducer-greedy-search-nemo-decoder.h - std::vector decoder_states = model_->GetDecoderInitStates(1); - decoder_states = decoder_->Decode_me(std::move(encoder_out), + // get intial states of decoder. + std::vector &decoder_states = ss[0]->GetNeMoDecoderStates(); + + // Subsequent decoder states (for each chunks) are updated inside the Decode method. + // This returns the decoder state from the LAST chunk. 
We probably dont need it. So we can discard it. + decoder_states = decoder_->Decode(std::move(encoder_out), std::move(decoder_states), &result, ss, n); - std::vector> next_states = - model_->UnStackStates(decoder_states); - for (int32_t i = 0; i != n; ++i) { - ss[i]->SetResult(result[i]); - ss[i]->SetNeMoDecoderStates(std::move(next_states[i])); - } + ss[0]->SetResult(result[0]); + + // We probably dont need it. Will discard it. + ss[0]->SetStates(std::move(decoder_states)); } void InitOnlineStream(OnlineStream *stream) const { auto r = decoder_->GetEmptyResult(); stream->SetResult(r); - // stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1)); + stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1)); } private: @@ -256,7 +259,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { OnlineRecognizerConfig config_; SymbolTable symbol_table_; std::unique_ptr model_; - std::unique_ptr decoder_; + std::unique_ptr decoder_; Endpoint endpoint_; }; diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h index 0b3ababb46..6265366044 100644 --- a/sherpa-onnx/csrc/online-transducer-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -82,11 +82,6 @@ class OnlineTransducerDecoder { virtual void Decode(Ort::Value encoder_out, std::vector *result) = 0; - virtual std::vector Decode_me(Ort::Value encoder_out, - std::vector decoder_states, - std::vector *result, - OnlineStream **ss = nullptr, int32_t n = 0) = 0; - /** Run transducer beam search given the output from the encoder model. 
* * Note: Currently this interface is for contextual-biasing feature which diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc index 7ac91695b9..f049df7b3f 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -36,6 +36,7 @@ static std::pair BuildDecoderInput( return {std::move(decoder_input), std::move(decoder_input_length)}; } + OnlineTransducerDecoderResult OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const { int32_t context_size = 8; @@ -47,29 +48,58 @@ OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const { return r; } +static void UpdateCachedDecoderOut( + OrtAllocator *allocator, const Ort::Value *decoder_out, + std::vector *result) { + std::vector shape = + decoder_out->GetTensorTypeAndShapeInfo().GetShape(); + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + std::array v_shape{1, shape[1]}; + + const float *src = decoder_out->GetTensorData(); + for (auto &r : *result) { + if (!r.decoder_out) { + r.decoder_out = Ort::Value::CreateTensor(allocator, v_shape.data(), + v_shape.size()); + } + + float *dst = r.decoder_out.GetTensorMutableData(); + std::copy(src, src + shape[1], dst); + src += shape[1]; + } +} -std::pair> DecodeOne( +std::vector DecodeOne( const float *encoder_out, int32_t num_rows, int32_t num_cols, OnlineTransducerNeMoModel *model, float blank_penalty, - std::vector& decoder_states) { + std::vector& decoder_states, + std::vector *result) { - // num_rows = frames auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - OnlineTransducerDecoderResult result; + // OnlineTransducerDecoderResult result; int32_t vocab_size = model->VocabSize(); int32_t blank_id = vocab_size - 1; + + auto &r = (*result)[0]; + Ort::Value decoder_out{nullptr}; auto decoder_input_pair = 
BuildDecoderInput(blank_id, model->Allocator()); // decoder_input_pair[0]: decoder_input // decoder_input_pair[1]: decoder_input_length (discarded) + + // decoder_output_pair.second returns the next decoder state std::pair> decoder_output_pair = model->RunDecoder(std::move(decoder_input_pair.first), std::move(decoder_states)); std::array encoder_shape{1, num_cols, 1}; + decoder_states = std::move(decoder_output_pair.second); + + // start with each chunks in the input sequence. Is this loop really meant for that? for (int32_t t = 0; t != num_rows; ++t) { Ort::Value cur_encoder_out = Ort::Value::CreateTensor( memory_info, const_cast(encoder_out) + t * num_cols, num_cols, @@ -89,33 +119,52 @@ std::pair> DecodeOne( static_cast(p_logit) + vocab_size))); if (y != blank_id) { - result.tokens.push_back(y); - result.timestamps.push_back(t); + r.tokens.push_back(y); + r.timestamps.push_back(t + r.frame_offset); decoder_input_pair = BuildDecoderInput(y, model->Allocator()); + // last decoder state becomes the current state for the first chunk decoder_output_pair = model->RunDecoder(std::move(decoder_input_pair.first), - std::move(decoder_output_pair.second)); - } // if (y != blank_id) - } // for (int32_t t = 0; t != num_rows; ++t) + std::move(decoder_states)); + } + + // Update the decoder states for the next chunk + decoder_states = std::move(decoder_output_pair.second); + } - // Update the decoder states for the next chunk - decoder_states = std::move(decoder_output_pair.second); + decoder_out = std::move(decoder_output_pair.first); + // UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result); - return {result, std::move(decoder_states)}; + // Update frame_offset + for (auto &r : *result) { + r.frame_offset += num_rows; + } + + return std::move(decoder_states); } -std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode_me( + +std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode( Ort::Value encoder_out, std::vector decoder_states, std::vector *result, 
OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) { auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); + + if (shape[0] != result->size()) { + SHERPA_ONNX_LOGE( + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d", + static_cast(shape[0]), + static_cast(result->size())); + exit(-1); + } + int32_t batch_size = static_cast(shape[0]); // bs = 1 - int32_t dim1 = static_cast(shape[1]); // feature dimension - int32_t dim2 = static_cast(shape[2]); // frames + int32_t dim1 = static_cast(shape[1]); + int32_t dim2 = static_cast(shape[2]); // Define and initialize encoder_out_length Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); @@ -130,19 +179,16 @@ std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode_me( const int64_t *p_length = encoder_out_length.GetTensorData(); const float *p = encoder_out.GetTensorData(); - std::vector ans(batch_size); + // std::vector ans(batch_size); for (int32_t i = 0; i != batch_size; ++i) { const float *this_p = p + dim1 * dim2 * i; int32_t this_len = p_length[i]; - auto decode_result_pair = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states); - ans[i] = decode_result_pair.first; - decoder_states = std::move(decode_result_pair.second); // Update decoder states for next chunk - - if (result != nullptr && i < result->size()) { - (*result)[i] = ans[i]; - } + // outputs the decoder state from last chunk. 
+ auto last_decoder_states = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states, result); + // ans[i] = decode_result_pair.first; + decoder_states = std::move(last_decoder_states); } return decoder_states; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h index d78f3f1250..d5a7a078c9 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -12,20 +12,22 @@ namespace sherpa_onnx { -class OnlineTransducerGreedySearchNeMoDecoder : public OnlineTransducerDecoder { +class OnlineTransducerGreedySearchNeMoDecoder { public: OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, float blank_penalty) : model_(model), - blank_penalty_(blank_penalty){} + blank_penalty_(blank_penalty) {} - OnlineTransducerDecoderResult GetEmptyResult() const override; + OnlineTransducerDecoderResult GetEmptyResult() const; + void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {} + void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {} - std::vector Decode_me( + std::vector Decode( Ort::Value encoder_out, std::vector decoder_states, std::vector *result, - OnlineStream **ss = nullptr, int32_t n = 0) override; + OnlineStream **ss = nullptr, int32_t n = 0); private: OnlineTransducerNeMoModel *model_; // Not owned diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index 5b91570d95..d887618e76 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -78,74 +78,6 @@ class OnlineTransducerNeMoModel::Impl { } #endif - std::vector StackStates( - std::vector> states) const { - int32_t batch_size = static_cast(states.size()); - if (batch_size == 1) { - return std::move(states[0]); - } - - std::vector ans; - - // stack 
cache_last_channel - std::vector buf(batch_size); - - // there are 3 states to be stacked - for (int32_t i = 0; i != 3; ++i) { - buf.clear(); - buf.reserve(batch_size); - - for (int32_t b = 0; b != batch_size; ++b) { - assert(states[b].size() == 3); - buf.push_back(&states[b][i]); - } - - Ort::Value c{nullptr}; - if (i == 2) { - c = Cat(allocator_, buf, 0); - } else { - c = Cat(allocator_, buf, 0); - } - - ans.push_back(std::move(c)); - } - - return ans; - } - - std::vector> UnStackStates( - std::vector states) const { - assert(states.size() == 3); - - std::vector> ans; - - auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape(); - int32_t batch_size = shape[0]; - ans.resize(batch_size); - - if (batch_size == 1) { - ans[0] = std::move(states); - return ans; - } - - for (int32_t i = 0; i != 3; ++i) { - std::vector v; - if (i == 2) { - v = Unbind(allocator_, &states[i], 0); - } else { - v = Unbind(allocator_, &states[i], 0); - } - - assert(v.size() == batch_size); - - for (int32_t b = 0; b != batch_size; ++b) { - ans[b].push_back(std::move(v[b])); - } - } - - return ans; - } - std::vector RunEncoder(Ort::Value features, std::vector states) { Ort::Value &cache_last_channel = states[0]; @@ -270,7 +202,6 @@ class OnlineTransducerNeMoModel::Impl { return states; } - int32_t ChunkSize() const { return window_size_; } int32_t ChunkShift() const { return chunk_shift_; } diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h index e236221c5b..97a632f507 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -40,26 +40,6 @@ class OnlineTransducerNeMoModel { // - cache_last_channel_len std::vector GetInitStates() const; - /** Stack a list of individual states into a batch. - * - * It is the inverse operation of `UnStackStates`. - * - * @param states states[i] contains the state for the i-th utterance. 
- * @return Return a single value representing the batched state. - */ - std::vector StackStates( - const std::vector> &states) const; - - /** Unstack a batched state into a list of individual states. - * - * It is the inverse operation of `StackStates`. - * - * @param states A batched state. - * @return ans[i] contains the state for the i-th utterance. - */ - std::vector> UnStackStates( - const std::vector &states) const; - /** Run the encoder. * * @param features A tensor of shape (N, T, C). It is changed in-place. @@ -91,7 +71,7 @@ class OnlineTransducerNeMoModel { * @param decoder_out Output of the decoder network. * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. */ - Ort::Value RunJoiner(Ort::Value encoder_out, + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const; From f9633f61e6716945374f8faf2a152b42598d568b Mon Sep 17 00:00:00 2001 From: sangeet2020 <15uec053@gmail.com> Date: Wed, 29 May 2024 15:43:06 +0200 Subject: [PATCH 14/14] Decoding works. but results are not perfect. For first few frames, its good, then incorrect predictions. 
--- .../csrc/online-recognizer-transducer-impl.h | 1 + .../online-recognizer-transducer-nemo-impl.h | 14 ++++++-------- ...ne-transducer-greedy-search-nemo-decoder.cc | 18 +++++++++--------- .../csrc/online-transducer-nemo-model.cc | 2 +- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 402346fa84..b0c6d36f21 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -46,6 +46,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, r.timestamps.reserve(src.tokens.size()); for (auto i : src.tokens) { + if (i == -1) continue; auto sym = sym_table[i]; r.text.append(sym); diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 716f12c7fd..9193b292f9 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -153,7 +153,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { std::vector result(n); std::vector features_vec(n * chunk_size * feature_dim); - std::vector> states_vec(n); + std::vector> encoder_states(n); for (int32_t i = 0; i != n; ++i) { const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); @@ -167,7 +167,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { features_vec.data() + i * chunk_size * feature_dim); result[i] = std::move(ss[i]->GetResult()); - states_vec[i] = std::move(ss[i]->GetStates()); + encoder_states[i] = std::move(ss[i]->GetStates()); } @@ -181,8 +181,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { x_shape.size()); // Batch size is 1 - auto states = std::move(states_vec[0]); - int32_t num_states = states.size(); + auto states = std::move(encoder_states[0]); + int32_t 
num_states = states.size(); // num_states = 3 auto t = model_->RunEncoder(std::move(x), std::move(states)); // t[0] encoder_out, float tensor, (batch_size, dim, T) // t[1] next states @@ -203,14 +203,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { // Subsequent decoder states (for each chunks) are updated inside the Decode method. // This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it. decoder_states = decoder_->Decode(std::move(encoder_out), - std::move(decoder_states), + std::move(decoder_states), &result, ss, n); - ss[0]->SetResult(result[0]); - // We probably dont need it. Will discard it. - ss[0]->SetStates(std::move(decoder_states)); + ss[0]->SetStates(std::move(out_states)); } void InitOnlineStream(OnlineStream *stream) const { diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc index f049df7b3f..8f95215f78 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -93,13 +93,13 @@ std::vector DecodeOne( // decoder_output_pair.second returns the next decoder state std::pair> decoder_output_pair = model->RunDecoder(std::move(decoder_input_pair.first), - std::move(decoder_states)); + std::move(decoder_states)); // here decoder_states = {len=0, cap=0}. But decoder_output_pair= {first, second: {len=2, cap=2}} // ATTN std::array encoder_shape{1, num_cols, 1}; decoder_states = std::move(decoder_output_pair.second); - // start with each chunks in the input sequence. Is this loop really meant for that? + // TODO: Inside this loop, I need to framewise decoding. 
for (int32_t t = 0; t != num_rows; ++t) { Ort::Value cur_encoder_out = Ort::Value::CreateTensor( memory_info, const_cast(encoder_out) + t * num_cols, num_cols, @@ -117,7 +117,7 @@ std::vector DecodeOne( static_cast(p_logit), std::max_element(static_cast(p_logit), static_cast(p_logit) + vocab_size))); - + SHERPA_ONNX_LOGE("y=%d", y); if (y != blank_id) { r.tokens.push_back(y); r.timestamps.push_back(t + r.frame_offset); @@ -128,14 +128,14 @@ std::vector DecodeOne( decoder_output_pair = model->RunDecoder(std::move(decoder_input_pair.first), std::move(decoder_states)); + + // Update the decoder states for the next chunk + decoder_states = std::move(decoder_output_pair.second); } - - // Update the decoder states for the next chunk - decoder_states = std::move(decoder_output_pair.second); } decoder_out = std::move(decoder_output_pair.first); - // UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result); +// UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result); // Update frame_offset for (auto &r : *result) { @@ -163,8 +163,8 @@ std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode( } int32_t batch_size = static_cast(shape[0]); // bs = 1 - int32_t dim1 = static_cast(shape[1]); - int32_t dim2 = static_cast(shape[2]); + int32_t dim1 = static_cast(shape[1]); // 2 + int32_t dim2 = static_cast(shape[2]); // 512 // Define and initialize encoder_out_length Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index d887618e76..b054e3b727 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -80,7 +80,7 @@ class OnlineTransducerNeMoModel::Impl { std::vector RunEncoder(Ort::Value features, std::vector states) { - Ort::Value &cache_last_channel = states[0]; + Ort::Value &cache_last_channel = states[0]; Ort::Value &cache_last_time = 
states[1]; Ort::Value &cache_last_channel_len = states[2];