diff --git a/.github/scripts/test-online-transducer.sh b/.github/scripts/test-online-transducer.sh index e894358a56..7616b18e99 100755 --- a/.github/scripts/test-online-transducer.sh +++ b/.github/scripts/test-online-transducer.sh @@ -15,6 +15,45 @@ echo "PATH: $PATH" which $EXE +log "------------------------------------------------------------" +log "Run NeMo transducer (English)" +log "------------------------------------------------------------" +repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2 +curl -SL -O $repo_url +tar xvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2 +rm sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2 +repo=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms + +log "Start testing ${repo_url}" + +waves=( +$repo/test_wavs/0.wav +$repo/test_wavs/1.wav +$repo/test_wavs/8k.wav +) + +for wave in ${waves[@]}; do + time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder.onnx \ + --decoder=$repo/decoder.onnx \ + --joiner=$repo/joiner.onnx \ + --num-threads=2 \ + $wave +done + +time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder.onnx \ + --decoder=$repo/decoder.onnx \ + --joiner=$repo/joiner.onnx \ + --num-threads=2 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + +rm -rf $repo + log "------------------------------------------------------------" log "Run LSTM transducer (English)" log "------------------------------------------------------------" diff --git a/.github/workflows/aarch64-linux-gnu-shared.yaml b/.github/workflows/aarch64-linux-gnu-shared.yaml index 7f907cf995..50ba2c236e 100644 --- a/.github/workflows/aarch64-linux-gnu-shared.yaml +++ b/.github/workflows/aarch64-linux-gnu-shared.yaml @@ -196,7 +196,6 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - git lfs pull mkdir -p aarch64 cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64 diff --git a/.github/workflows/aarch64-linux-gnu-static.yaml b/.github/workflows/aarch64-linux-gnu-static.yaml index 94bcfdb00b..13edc9c17c 100644 --- a/.github/workflows/aarch64-linux-gnu-static.yaml +++ b/.github/workflows/aarch64-linux-gnu-static.yaml @@ -187,7 +187,6 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - git lfs pull mkdir -p aarch64 cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64 diff --git a/.github/workflows/android.yaml b/.github/workflows/android.yaml index 9fb400d4c8..e15288e124 100644 --- a/.github/workflows/android.yaml +++ b/.github/workflows/android.yaml @@ -124,7 +124,6 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - git lfs pull cp -v ../sherpa-onnx-*-android.tar.bz2 ./ diff --git a/.github/workflows/arm-linux-gnueabihf.yaml b/.github/workflows/arm-linux-gnueabihf.yaml index 40ed439091..269456815d 100644 --- a/.github/workflows/arm-linux-gnueabihf.yaml +++ b/.github/workflows/arm-linux-gnueabihf.yaml @@ -209,7 +209,6 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - git lfs pull mkdir -p arm32 cp -v ../sherpa-onnx-*.tar.bz2 ./arm32 diff --git a/.github/workflows/build-xcframework.yaml b/.github/workflows/build-xcframework.yaml index 97eb33515d..f6d0dce735 100644 --- a/.github/workflows/build-xcframework.yaml +++ b/.github/workflows/build-xcframework.yaml @@ -138,7 +138,6 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - git lfs pull cp -v ../sherpa-onnx-*.tar.bz2 ./ diff --git a/.github/workflows/riscv64-linux.yaml b/.github/workflows/riscv64-linux.yaml index daa50fb18a..5393e97358 100644 --- a/.github/workflows/riscv64-linux.yaml +++ b/.github/workflows/riscv64-linux.yaml @@ -242,7 +242,6 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - git lfs pull mkdir -p riscv64 cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64 diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 7d26157b76..d1e0a2d4f9 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -219,7 +219,6 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - git lfs pull mkdir -p win64 cp -v ../sherpa-onnx-*.tar.bz2 ./win64 diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index c140dad8ad..1230b20a46 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -221,7 +221,6 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - git lfs pull mkdir -p win32 cp -v ../sherpa-onnx-*.tar.bz2 ./win32 diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index a6f3980166..2de905772f 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -14,19 +14,18 @@ namespace sherpa_onnx { std::unique_ptr OnlineRecognizerImpl::Create( const OnlineRecognizerConfig &config) { - if (!config.model_config.transducer.encoder.empty()) { Ort::Env env(ORT_LOGGING_LEVEL_WARNING); - + auto decoder_model = ReadFile(config.model_config.transducer.decoder); - auto sess = std::make_unique(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); - + auto sess = std::make_unique( + env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); + size_t node_count = sess->GetOutputCount(); - + if (node_count == 1) { return std::make_unique(config); } else { - SHERPA_ONNX_LOGE("Running streaming Nemo transducer model"); return std::make_unique(config); } } @@ -50,12 +49,13 @@ std::unique_ptr OnlineRecognizerImpl::Create( AAssetManager *mgr, const OnlineRecognizerConfig &config) { if (!config.model_config.transducer.encoder.empty()) { Ort::Env env(ORT_LOGGING_LEVEL_WARNING); - + auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder); - auto sess = std::make_unique(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); - + auto sess = std::make_unique( + env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{}); + size_t node_count = sess->GetOutputCount(); - + if (node_count == 1) { return std::make_unique(mgr, config); } else { diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 8fa12d94fd..60e3aa2b90 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -35,18 +35,15 @@ namespace sherpa_onnx { -static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, - const SymbolTable &sym_table, - float frame_shift_ms, - int32_t subsampling_factor, - int32_t segment, - int32_t frames_since_start) { +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, + const SymbolTable &sym_table, + float frame_shift_ms, int32_t subsampling_factor, + int32_t segment, int32_t frames_since_start) { OnlineRecognizerResult r; r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.tokens.size()); for (auto i : src.tokens) { - if (i == -1) continue; auto sym = sym_table[i]; r.text.append(sym); diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 9193b292f9..812bf13e74 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -6,6 +6,7 @@ #ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ #define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ +#include #include #include #include @@ -32,23 +33,20 @@ namespace sherpa_onnx { // defined in ./online-recognizer-transducer-impl.h -// static may or may not be here? TODDOs -static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, - const SymbolTable &sym_table, - float frame_shift_ms, - int32_t subsampling_factor, - int32_t segment, - int32_t frames_since_start); +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, + const SymbolTable &sym_table, + float frame_shift_ms, int32_t subsampling_factor, + int32_t segment, int32_t frames_since_start); class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { - public: + public: explicit OnlineRecognizerTransducerNeMoImpl( const OnlineRecognizerConfig &config) : config_(config), symbol_table_(config.model_config.tokens), endpoint_(config_.endpoint_config), - model_(std::make_unique( - config.model_config)) { + model_( + std::make_unique(config.model_config)) { if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( model_.get(), config_.blank_penalty); @@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { model_.get(), config_.blank_penalty); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", - config.decoding_method.c_str()); + config.decoding_method.c_str()); exit(-1); } @@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { std::unique_ptr CreateStream() const override { auto stream = std::make_unique(config_.feat_config); - stream->SetStates(model_->GetInitStates()); InitOnlineStream(stream.get()); return stream; } @@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { } OnlineRecognizerResult GetResult(OnlineStream *s) const override { - OnlineTransducerDecoderResult decoder_result = s->GetResult(); - decoder_->StripLeadingBlanks(&decoder_result); - // TODO(fangjun): Remember to change these constants if needed int32_t frame_shift_ms = 10; - int32_t subsampling_factor = 8; - return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor, - s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + int32_t subsampling_factor = model_->SubsamplingFactor(); + return Convert(s->GetResult(), symbol_table_, frame_shift_ms, + subsampling_factor, s->GetCurrentSegment(), + s->GetNumFramesSinceStart()); } bool IsEndpoint(OnlineStream *s) const override { @@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { // frame shift is 10 milliseconds float frame_shift_in_seconds = 0.01; - // subsampling factor is 8 - int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8; + int32_t trailing_silence_frames = + s->GetResult().num_trailing_blanks * model_->SubsamplingFactor(); return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, frame_shift_in_seconds); @@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { // segment is incremented only when the last // result is not empty const auto &r = s->GetResult(); - if (!r.tokens.empty() && r.tokens.back() != 0) { + if (!r.tokens.empty()) { s->GetCurrentSegment() += 1; } } - // we keep the decoder_out - decoder_->UpdateDecoderOut(&s->GetResult()); - Ort::Value decoder_out = std::move(s->GetResult().decoder_out); + s->SetResult({}); + + s->SetStates(model_->GetEncoderInitStates()); - auto r = decoder_->GetEmptyResult(); - - s->SetResult(r); - s->GetResult().decoder_out = std::move(decoder_out); + s->SetNeMoDecoderStates(model_->GetDecoderInitStates()); // Note: We only update counters. The underlying audio samples // are not discarded. @@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { int32_t feature_dim = ss[0]->FeatureDim(); - std::vector result(n); std::vector features_vec(n * chunk_size * feature_dim); std::vector> encoder_states(n); - + for (int32_t i = 0; i != n; ++i) { const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); std::vector features = @@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { std::copy(features.begin(), features.end(), features_vec.data() + i * chunk_size * feature_dim); - result[i] = std::move(ss[i]->GetResult()); encoder_states[i] = std::move(ss[i]->GetStates()); - } auto memory_info = @@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { features_vec.size(), x_shape.data(), x_shape.size()); - // Batch size is 1 - auto states = std::move(encoder_states[0]); - int32_t num_states = states.size(); // num_states = 3 + auto states = model_->StackStates(std::move(encoder_states)); + int32_t num_states = states.size(); // num_states = 3 auto t = model_->RunEncoder(std::move(x), std::move(states)); // t[0] encoder_out, float tensor, (batch_size, dim, T) // t[1] next states - + std::vector out_states; out_states.reserve(num_states); - + for (int32_t k = 1; k != num_states + 1; ++k) { out_states.push_back(std::move(t[k])); } + auto unstacked_states = model_->UnStackStates(std::move(out_states)); + for (int32_t i = 0; i != n; ++i) { + ss[i]->SetStates(std::move(unstacked_states[i])); + } + Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); - - // defined in online-transducer-greedy-search-nemo-decoder.h - // get intial states of decoder. - std::vector &decoder_states = ss[0]->GetNeMoDecoderStates(); - - // Subsequent decoder states (for each chunks) are updated inside the Decode method. - // This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it. - decoder_states = decoder_->Decode(std::move(encoder_out), - std::move(decoder_states), - &result, ss, n); - - ss[0]->SetResult(result[0]); - - ss[0]->SetStates(std::move(out_states)); + + decoder_->Decode(std::move(encoder_out), ss, n); } void InitOnlineStream(OnlineStream *stream) const { - auto r = decoder_->GetEmptyResult(); + // set encoder states + stream->SetStates(model_->GetEncoderInitStates()); - stream->SetResult(r); - stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1)); + // set decoder states + stream->SetNeMoDecoderStates(model_->GetDecoderInitStates()); } private: @@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { symbol_table_.NumSymbols(), vocab_size); exit(-1); } - } private: @@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { std::unique_ptr model_; std::unique_ptr decoder_; Endpoint endpoint_; - }; } // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ \ No newline at end of file +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 62d93999e0..0a30cf40c1 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -225,7 +225,8 @@ std::vector &OnlineStream::GetStates() { return impl_->GetStates(); } -void OnlineStream::SetNeMoDecoderStates(std::vector decoder_states) { +void OnlineStream::SetNeMoDecoderStates( + std::vector decoder_states) { return impl_->SetNeMoDecoderStates(std::move(decoder_states)); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index 4e444366ee..e9958d736f 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -91,8 +91,8 @@ class OnlineStream { void SetStates(std::vector states); std::vector &GetStates(); - void SetNeMoDecoderStates(std::vector decoder_states); - std::vector &GetNeMoDecoderStates(); + void SetNeMoDecoderStates(std::vector decoder_states); + std::vector &GetNeMoDecoderStates(); /** * Get the context graph corresponding to this stream. diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc index 8f95215f78..a76db5b727 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -10,103 +10,64 @@ #include #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-stream.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { -static std::pair BuildDecoderInput( - int32_t token, OrtAllocator *allocator) { +static Ort::Value BuildDecoderInput(int32_t token, OrtAllocator *allocator) { std::array shape{1, 1}; Ort::Value decoder_input = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); - std::array length_shape{1}; - Ort::Value decoder_input_length = Ort::Value::CreateTensor( - allocator, length_shape.data(), length_shape.size()); - int32_t *p = decoder_input.GetTensorMutableData(); - int32_t *p_length = decoder_input_length.GetTensorMutableData(); - p[0] = token; - p_length[0] = 1; - - return {std::move(decoder_input), std::move(decoder_input_length)}; -} - - -OnlineTransducerDecoderResult -OnlineTransducerGreedySearchNeMoDecoder::GetEmptyResult() const { - int32_t context_size = 8; - int32_t blank_id = 0; // always 0 - OnlineTransducerDecoderResult r; - r.tokens.resize(context_size, -1); - r.tokens.back() = blank_id; - - return r; + return decoder_input; } -static void UpdateCachedDecoderOut( - OrtAllocator *allocator, const Ort::Value *decoder_out, - std::vector *result) { - std::vector shape = - decoder_out->GetTensorTypeAndShapeInfo().GetShape(); +static void DecodeOne(const float *encoder_out, int32_t num_rows, + int32_t num_cols, OnlineTransducerNeMoModel *model, + float blank_penalty, OnlineStream *s) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - std::array v_shape{1, shape[1]}; - const float *src = decoder_out->GetTensorData(); - for (auto &r : *result) { - if (!r.decoder_out) { - r.decoder_out = Ort::Value::CreateTensor(allocator, v_shape.data(), - v_shape.size()); - } + int32_t vocab_size = model->VocabSize(); + int32_t blank_id = vocab_size - 1; - float *dst = r.decoder_out.GetTensorMutableData(); - std::copy(src, src + shape[1], dst); - src += shape[1]; - } -} + auto &r = s->GetResult(); -std::vector DecodeOne( - const float *encoder_out, int32_t num_rows, int32_t num_cols, - OnlineTransducerNeMoModel *model, float blank_penalty, - std::vector& decoder_states, - std::vector *result) { + Ort::Value decoder_out{nullptr}; - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + auto decoder_input = BuildDecoderInput( + r.tokens.empty() ? blank_id : r.tokens.back(), model->Allocator()); - // OnlineTransducerDecoderResult result; - int32_t vocab_size = model->VocabSize(); - int32_t blank_id = vocab_size - 1; - - auto &r = (*result)[0]; - Ort::Value decoder_out{nullptr}; + std::vector &last_decoder_states = s->GetNeMoDecoderStates(); - auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); - // decoder_input_pair[0]: decoder_input - // decoder_input_pair[1]: decoder_input_length (discarded) + std::vector tmp_decoder_states; + tmp_decoder_states.reserve(last_decoder_states.size()); + for (auto &v : last_decoder_states) { + tmp_decoder_states.push_back(View(&v)); + } // decoder_output_pair.second returns the next decoder state std::pair> decoder_output_pair = - model->RunDecoder(std::move(decoder_input_pair.first), - std::move(decoder_states)); // here decoder_states = {len=0, cap=0}. But decoder_output_pair= {first, second: {len=2, cap=2}} // ATTN + model->RunDecoder(std::move(decoder_input), + std::move(tmp_decoder_states)); std::array encoder_shape{1, num_cols, 1}; - decoder_states = std::move(decoder_output_pair.second); + bool emitted = false; - // TODO: Inside this loop, I need to framewise decoding. for (int32_t t = 0; t != num_rows; ++t) { Ort::Value cur_encoder_out = Ort::Value::CreateTensor( memory_info, const_cast(encoder_out) + t * num_cols, num_cols, encoder_shape.data(), encoder_shape.size()); Ort::Value logit = model->RunJoiner(std::move(cur_encoder_out), - View(&decoder_output_pair.first)); + View(&decoder_output_pair.first)); float *p_logit = logit.GetTensorMutableData(); if (blank_penalty > 0) { @@ -117,82 +78,52 @@ std::vector DecodeOne( static_cast(p_logit), std::max_element(static_cast(p_logit), static_cast(p_logit) + vocab_size))); - SHERPA_ONNX_LOGE("y=%d", y); + if (y != blank_id) { + emitted = true; r.tokens.push_back(y); r.timestamps.push_back(t + r.frame_offset); + r.num_trailing_blanks = 0; - decoder_input_pair = BuildDecoderInput(y, model->Allocator()); + decoder_input = BuildDecoderInput(y, model->Allocator()); // last decoder state becomes the current state for the first chunk - decoder_output_pair = - model->RunDecoder(std::move(decoder_input_pair.first), - std::move(decoder_states)); - - // Update the decoder states for the next chunk - decoder_states = std::move(decoder_output_pair.second); + decoder_output_pair = model->RunDecoder( + std::move(decoder_input), std::move(decoder_output_pair.second)); + } else { + ++r.num_trailing_blanks; } } - decoder_out = std::move(decoder_output_pair.first); -// UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result); - - // Update frame_offset - for (auto &r : *result) { - r.frame_offset += num_rows; + if (emitted) { + s->SetNeMoDecoderStates(std::move(decoder_output_pair.second)); } - return std::move(decoder_states); + r.frame_offset += num_rows; } - -std::vector OnlineTransducerGreedySearchNeMoDecoder::Decode( - Ort::Value encoder_out, - std::vector decoder_states, - std::vector *result, - OnlineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) { - +void OnlineTransducerGreedySearchNeMoDecoder::Decode(Ort::Value encoder_out, + OnlineStream **ss, + int32_t n) const { auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape(); + int32_t batch_size = static_cast(shape[0]); // bs = 1 - if (shape[0] != result->size()) { - SHERPA_ONNX_LOGE( - "Size mismatch! encoder_out.size(0) %d, result.size(0): %d", - static_cast(shape[0]), - static_cast(result->size())); + if (batch_size != n) { + SHERPA_ONNX_LOGE("Size mismatch! encoder_out.size(0) %d, n: %d", + static_cast(shape[0]), n); exit(-1); } - int32_t batch_size = static_cast(shape[0]); // bs = 1 - int32_t dim1 = static_cast(shape[1]); // 2 - int32_t dim2 = static_cast(shape[2]); // 512 - - // Define and initialize encoder_out_length - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - - int64_t length_value = 1; - std::vector length_shape = {1}; - - Ort::Value encoder_out_length = Ort::Value::CreateTensor( - memory_info, &length_value, 1, length_shape.data(), length_shape.size() - ); - - const int64_t *p_length = encoder_out_length.GetTensorData(); - const float *p = encoder_out.GetTensorData(); + int32_t dim1 = static_cast(shape[1]); // T + int32_t dim2 = static_cast(shape[2]); // encoder_out_dim - // std::vector ans(batch_size); + const float *p = encoder_out.GetTensorData(); for (int32_t i = 0; i != batch_size; ++i) { const float *this_p = p + dim1 * dim2 * i; - int32_t this_len = p_length[i]; - // outputs the decoder state from last chunk. - auto last_decoder_states = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_, decoder_states, result); - // ans[i] = decode_result_pair.first; - decoder_states = std::move(last_decoder_states); + DecodeOne(this_p, dim1, dim2, model_, blank_penalty_, ss[i]); } - - return decoder_states; - } -} // namespace sherpa_onnx \ No newline at end of file +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h index d5a7a078c9..212008fdd1 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -7,27 +7,22 @@ #define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ #include + #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/online-transducer-nemo-model.h" namespace sherpa_onnx { +class OnlineStream; + class OnlineTransducerGreedySearchNeMoDecoder { public: OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, float blank_penalty) - : model_(model), - blank_penalty_(blank_penalty) {} - - OnlineTransducerDecoderResult GetEmptyResult() const; - void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {} - void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {} - - std::vector Decode( - Ort::Value encoder_out, - std::vector decoder_states, - std::vector *result, - OnlineStream **ss = nullptr, int32_t n = 0); + : model_(model), blank_penalty_(blank_penalty) {} + + // @param n number of elements in ss + void Decode(Ort::Value encoder_out, OnlineStream **ss, int32_t n) const; private: OnlineTransducerNeMoModel *model_; // Not owned @@ -37,4 +32,3 @@ class OnlineTransducerGreedySearchNeMoDecoder { } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ - diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index b054e3b727..1b04fcaf68 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -54,7 +54,7 @@ class OnlineTransducerNeMoModel::Impl { InitJoiner(buf.data(), buf.size()); } } - + #if __ANDROID_API__ >= 9 Impl(AAssetManager *mgr, const OnlineModelConfig &config) : config_(config), @@ -79,7 +79,7 @@ class OnlineTransducerNeMoModel::Impl { #endif std::vector RunEncoder(Ort::Value features, - std::vector states) { + std::vector states) { Ort::Value &cache_last_channel = states[0]; Ort::Value &cache_last_time = states[1]; Ort::Value &cache_last_channel_len = states[2]; @@ -102,9 +102,9 @@ class OnlineTransducerNeMoModel::Impl { std::move(features), View(&length), std::move(cache_last_channel), std::move(cache_last_time), std::move(cache_last_channel_len)}; - auto out = - encoder_sess_->Run({}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(), - encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); + auto out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(), + encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); // out[0]: logit // out[1] logit_length // out[2:] states_next @@ -127,17 +127,19 @@ class OnlineTransducerNeMoModel::Impl { std::pair> RunDecoder( Ort::Value targets, std::vector states) { - - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::MemoryInfo memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + + auto shape = targets.GetTensorTypeAndShapeInfo().GetShape(); + int32_t batch_size = static_cast(shape[0]); - // Create the tensor with a single int32_t value of 1 - int32_t length_value = 1; - std::vector length_shape = {1}; + std::vector length_shape = {batch_size}; + std::vector length_value(batch_size, 1); Ort::Value targets_length = Ort::Value::CreateTensor( - memory_info, &length_value, 1, length_shape.data(), length_shape.size() - ); - + memory_info, length_value.data(), batch_size, length_shape.data(), + length_shape.size()); + std::vector decoder_inputs; decoder_inputs.reserve(2 + states.size()); @@ -171,35 +173,21 @@ class OnlineTransducerNeMoModel::Impl { Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { std::array joiner_input = {std::move(encoder_out), std::move(decoder_out)}; - auto logit = - joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), - joiner_input.size(), joiner_output_names_ptr_.data(), - joiner_output_names_ptr_.size()); + auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(), + joiner_input.data(), joiner_input.size(), + joiner_output_names_ptr_.data(), + joiner_output_names_ptr_.size()); return std::move(logit[0]); -} - - std::vector GetDecoderInitStates(int32_t batch_size) const { - std::array s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; - Ort::Value s0 = Ort::Value::CreateTensor(allocator_, s0_shape.data(), - s0_shape.size()); - - Fill(&s0, 0); - - std::array s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; - - Ort::Value s1 = Ort::Value::CreateTensor(allocator_, s1_shape.data(), - s1_shape.size()); - - Fill(&s1, 0); - - std::vector states; + } - states.reserve(2); - states.push_back(std::move(s0)); - states.push_back(std::move(s1)); + std::vector GetDecoderInitStates() { + std::vector ans; + ans.reserve(2); + ans.push_back(View(&lstm0_)); + ans.push_back(View(&lstm1_)); - return states; + return ans; } int32_t ChunkSize() const { return window_size_; } @@ -207,7 +195,7 @@ class OnlineTransducerNeMoModel::Impl { int32_t ChunkShift() const { return chunk_shift_; } int32_t SubsamplingFactor() const { return subsampling_factor_; } - + int32_t VocabSize() const { return vocab_size_; } OrtAllocator *Allocator() const { return allocator_; } @@ -218,7 +206,7 @@ class OnlineTransducerNeMoModel::Impl { // - cache_last_channel // - cache_last_time_ // - cache_last_channel_len - std::vector GetInitStates() { + std::vector GetEncoderInitStates() { std::vector ans; ans.reserve(3); ans.push_back(View(&cache_last_channel_)); @@ -228,7 +216,75 @@ class OnlineTransducerNeMoModel::Impl { return ans; } -private: + std::vector StackStates( + std::vector> states) const { + int32_t batch_size = static_cast(states.size()); + if (batch_size == 1) { + return std::move(states[0]); + } + + std::vector ans; + + // stack cache_last_channel + std::vector buf(batch_size); + + // there are 3 states to be stacked + for (int32_t i = 0; i != 3; ++i) { + buf.clear(); + buf.reserve(batch_size); + + for (int32_t b = 0; b != batch_size; ++b) { + assert(states[b].size() == 3); + buf.push_back(&states[b][i]); + } + + Ort::Value c{nullptr}; + if (i == 2) { + c = Cat(allocator_, buf, 0); + } else { + c = Cat(allocator_, buf, 0); + } + + ans.push_back(std::move(c)); + } + + return ans; + } + + std::vector> UnStackStates( + std::vector states) const { + assert(states.size() == 3); + + std::vector> ans; + + auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape(); + int32_t batch_size = shape[0]; + ans.resize(batch_size); + + if (batch_size == 1) { + ans[0] = std::move(states); + return ans; + } + + for (int32_t i = 0; i != 3; ++i) { + std::vector v; + if (i == 2) { + v = Unbind(allocator_, &states[i], 0); + } else { + v = Unbind(allocator_, &states[i], 0); + } + + assert(v.size() == batch_size); + + for (int32_t b = 0; b != batch_size; ++b) { + ans[b].push_back(std::move(v[b])); + } + } + + return ans; + } + + private: void InitEncoder(void *model_data, size_t model_data_length) { encoder_sess_ = std::make_unique( env_, model_data, model_data_length, sess_opts_); @@ -276,10 +332,10 @@ class OnlineTransducerNeMoModel::Impl { normalize_type_ = ""; } - InitStates(); + InitEncoderStates(); } - - void InitStates() { + + void InitEncoderStates() { std::array cache_last_channel_shape{1, cache_last_channel_dim1_, cache_last_channel_dim2_, cache_last_channel_dim3_}; @@ -313,7 +369,25 @@ class OnlineTransducerNeMoModel::Impl { &decoder_input_names_ptr_); GetOutputNames(decoder_sess_.get(), &decoder_output_names_, - &decoder_output_names_ptr_); + &decoder_output_names_ptr_); + + InitDecoderStates(); + } + + void InitDecoderStates() { + int32_t batch_size = 1; + std::array s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + lstm0_ = Ort::Value::CreateTensor(allocator_, s0_shape.data(), + s0_shape.size()); + + Fill(&lstm0_, 0); + + std::array s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + + lstm1_ = Ort::Value::CreateTensor(allocator_, s1_shape.data(), + s1_shape.size()); + + Fill(&lstm1_, 0); } void InitJoiner(void *model_data, size_t model_data_length) { @@ -324,7 +398,7 @@ class OnlineTransducerNeMoModel::Impl { &joiner_input_names_ptr_); GetOutputNames(joiner_sess_.get(), &joiner_output_names_, - &joiner_output_names_ptr_); + &joiner_output_names_ptr_); } private: @@ -363,6 +437,7 @@ class OnlineTransducerNeMoModel::Impl { int32_t pred_rnn_layers_ = -1; int32_t pred_hidden_ = -1; + // encoder states int32_t cache_last_channel_dim1_; int32_t cache_last_channel_dim2_; int32_t cache_last_channel_dim3_; @@ -370,9 +445,14 @@ class OnlineTransducerNeMoModel::Impl { int32_t cache_last_time_dim2_; int32_t cache_last_time_dim3_; + // init encoder states Ort::Value cache_last_channel_{nullptr}; Ort::Value cache_last_time_{nullptr}; Ort::Value cache_last_channel_len_{nullptr}; + + // init decoder states + Ort::Value lstm0_{nullptr}; + Ort::Value lstm1_{nullptr}; }; OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( @@ -387,10 +467,9 @@ OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; -std::vector -OnlineTransducerNeMoModel::RunEncoder(Ort::Value features, - std::vector states) const { - return impl_->RunEncoder(std::move(features), std::move(states)); +std::vector OnlineTransducerNeMoModel::RunEncoder( + Ort::Value features, std::vector states) const { + return impl_->RunEncoder(std::move(features), std::move(states)); } std::pair> @@ -399,9 +478,9 @@ OnlineTransducerNeMoModel::RunDecoder(Ort::Value targets, return impl_->RunDecoder(std::move(targets), std::move(states)); } -std::vector OnlineTransducerNeMoModel::GetDecoderInitStates( - int32_t batch_size) const { - return impl_->GetDecoderInitStates(batch_size); +std::vector OnlineTransducerNeMoModel::GetDecoderInitStates() + const { + return impl_->GetDecoderInitStates(); } Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, @@ -409,14 +488,13 @@ Ort::Value OnlineTransducerNeMoModel::RunJoiner(Ort::Value encoder_out, return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); } +int32_t OnlineTransducerNeMoModel::ChunkSize() const { + return impl_->ChunkSize(); +} -int32_t OnlineTransducerNeMoModel::ChunkSize() const { - return impl_->ChunkSize(); - } - -int32_t OnlineTransducerNeMoModel::ChunkShift() const { - return impl_->ChunkShift(); - } +int32_t OnlineTransducerNeMoModel::ChunkShift() const { + return impl_->ChunkShift(); +} int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const { return impl_->SubsamplingFactor(); @@ -434,8 +512,19 @@ std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const { return impl_->FeatureNormalizationMethod(); } -std::vector OnlineTransducerNeMoModel::GetInitStates() const { - return impl_->GetInitStates(); +std::vector OnlineTransducerNeMoModel::GetEncoderInitStates() + const { + return impl_->GetEncoderInitStates(); +} + +std::vector OnlineTransducerNeMoModel::StackStates( + std::vector> states) const { + return impl_->StackStates(std::move(states)); +} + +std::vector> OnlineTransducerNeMoModel::UnStackStates( + std::vector states) const { + return impl_->UnStackStates(std::move(states)); } -} // namespace sherpa_onnx \ No newline at end of file +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.h b/sherpa-onnx/csrc/online-transducer-nemo-model.h index 97a632f507..e12814cc06 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.h +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.h @@ -32,22 +32,31 @@ class OnlineTransducerNeMoModel { OnlineTransducerNeMoModel(AAssetManager *mgr, const OnlineModelConfig &config); #endif - + ~OnlineTransducerNeMoModel(); - // A list of 3 tensors: + // A list of 3 tensors: // - cache_last_channel // - cache_last_time // - cache_last_channel_len - std::vector GetInitStates() const; + std::vector GetEncoderInitStates() const; + + // stack encoder states + std::vector StackStates( + std::vector> states) const; + + // unstack encoder states + std::vector> UnStackStates( + std::vector states) const; /** Run the encoder. * * @param features A tensor of shape (N, T, C). It is changed in-place. - * @param states It is from GetInitStates() or returned from this method. - * + * @param states It is from GetEncoderInitStates() or returned from this + * method. + * * @return Return a tuple containing: - * - ans[0]: encoder_out, a tensor of shape (N, T', encoder_out_dim) - * - ans[1:]: contains next states + * - ans[0]: encoder_out, a tensor of shape (N, encoder_out_dim, T') + * - ans[1:]: contains next states */ std::vector RunEncoder( Ort::Value features, std::vector states) const; // NOLINT @@ -63,7 +72,7 @@ class OnlineTransducerNeMoModel { std::pair> RunDecoder( Ort::Value targets, std::vector states) const; - std::vector GetDecoderInitStates(int32_t batch_size) const; + std::vector GetDecoderInitStates() const; /** Run the joint network. * @@ -71,9 +80,7 @@ class OnlineTransducerNeMoModel { * @param decoder_out Output of the decoder network. * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. */ - Ort::Value RunJoiner(Ort::Value encoder_out, - Ort::Value decoder_out) const; - + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) const; /** We send this number of feature frames to the encoder at a time. */ int32_t ChunkSize() const; @@ -114,10 +121,10 @@ class OnlineTransducerNeMoModel { // for details std::string FeatureNormalizationMethod() const; - private: - class Impl; - std::unique_ptr impl_; - }; + private: + class Impl; + std::unique_ptr impl_; +}; } // namespace sherpa_onnx