-
Notifications
You must be signed in to change notification settings - Fork 738
Add C++ runtime for *streaming* faster conformer transducer from NeMo. #889
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
cdca4e6
ca4bfe8
2bb7d7e
44f8d8c
afb10d4
7800cc0
d47bf6f
7837a5d
4c3e741
a5c9cc8
6608ec3
72a45c2
f5f7b27
e1613b6
f9633f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
// sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h | ||
// | ||
// Copyright (c) 2022-2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Sangeet Sagar | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ | ||
|
||
#include <fstream> | ||
#include <ios> | ||
#include <memory> | ||
#include <regex> // NOLINT | ||
#include <sstream> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#if __ANDROID_API__ >= 9 | ||
#include "android/asset_manager.h" | ||
#include "android/asset_manager_jni.h" | ||
#endif | ||
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
#include "sherpa-onnx/csrc/online-recognizer-impl.h" | ||
#include "sherpa-onnx/csrc/online-recognizer.h" | ||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.h" | ||
#include "sherpa-onnx/csrc/online-transducer-nemo-model.h" | ||
#include "sherpa-onnx/csrc/symbol-table.h" | ||
#include "sherpa-onnx/csrc/transpose.h" | ||
#include "sherpa-onnx/csrc/utils.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
// Converts a raw transducer decoding result into a user-facing recognizer
// result (text, tokens, timestamps). Implemented in
// ./online-recognizer-transducer-impl.h; declared here so this impl can
// reuse it.
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
                               const SymbolTable &sym_table,
                               float frame_shift_ms, int32_t subsampling_factor,
                               int32_t segment, int32_t frames_since_start);
|
||
class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
public: | ||
explicit OnlineRecognizerTransducerNeMoImpl( | ||
const OnlineRecognizerConfig &config) | ||
: config_(config), | ||
symbol_table_(config_.model_config.tokens), | ||
model_(std::make_unique<OnlineTransducerNeMoModel>( | ||
config_.model_config)) { | ||
if (config_.decoding_method == "greedy_search") { | ||
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>( | ||
model_.get(), config_.blank_penalty); | ||
} else { | ||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | ||
config_.decoding_method.c_str()); | ||
exit(-1); | ||
} | ||
PostInit(); | ||
} | ||
|
||
#if __ANDROID_API__ >= 9 | ||
explicit OnlineRecognizerTransducerNeMoImpl( | ||
AAssetManager *mgr, const OnlineRecognizerConfig &config) | ||
: config_(config), | ||
symbol_table_(mgr, config_.model_config.tokens), | ||
model_(std::make_unique<OnlineTransducerNeMoModel>( | ||
mgr, config_.model_config)) { | ||
if (config_.decoding_method == "greedy_search") { | ||
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>( | ||
model_.get(), config_.blank_penalty); | ||
} else { | ||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | ||
config_.decoding_method.c_str()); | ||
exit(-1); | ||
} | ||
|
||
PostInit(); | ||
} | ||
#endif | ||
|
||
std::unique_ptr<OnlineStream> CreateStream() const override { | ||
auto stream = std::make_unique<OnlineStream>(config_.feat_config); | ||
stream->SetStates(model_->GetInitStates()); | ||
InitOnlineStream(stream.get()); | ||
return stream; | ||
} | ||
|
||
void DecodeStreams(OnlineStream **ss, int32_t n) const override { | ||
int32_t chunk_size = model_->ChunkSize(); | ||
int32_t chunk_shift = model_->ChunkShift(); | ||
|
||
int32_t feature_dim = ss[0]->FeatureDim(); | ||
|
||
std::vector<OnlineTransducerDecoderResult> results(n); | ||
std::vector<float> features_vec(n * chunk_size * feature_dim); | ||
std::vector<std::vector<Ort::Value>> states_vec(n); | ||
std::vector<int64_t> all_processed_frames(n); | ||
sangeet2020 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
for (int32_t i = 0; i != n; ++i) { | ||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); | ||
std::vector<float> features = | ||
ss[i]->GetFrames(num_processed_frames, chunk_size); | ||
|
||
// Question: should num_processed_frames include chunk_shift? | ||
ss[i]->GetNumProcessedFrames() += chunk_shift; | ||
|
||
std::copy(features.begin(), features.end(), | ||
features_vec.data() + i * chunk_size * feature_dim); | ||
|
||
results[i] = std::move(ss[i]->GetResult()); | ||
states_vec[i] = std::move(ss[i]->GetStates()); | ||
all_processed_frames[i] = num_processed_frames; | ||
sangeet2020 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
auto memory_info = | ||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
|
||
std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim}; | ||
|
||
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), | ||
features_vec.size(), x_shape.data(), | ||
x_shape.size()); | ||
|
||
auto states = model_->StackStates(states_vec); | ||
int32_t num_states = states.size(); | ||
auto t = model_->RunEncoder(std::move(x), std::move(states)); | ||
// t[0] encoder_out, float tensor, (batch_size, dim, T) | ||
// t[1] encoder_out_length, int64 tensor, (batch_size,) | ||
sangeet2020 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
std::vector<Ort::Value> out_states; | ||
out_states.reserve(num_states); | ||
|
||
for (int32_t k = 1; k != num_states + 1; ++k) { | ||
out_states.push_back(std::move(t[k])); | ||
} | ||
|
||
Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]); | ||
|
||
// defined in online-transducer-greedy-search-nemo-decoder.h | ||
decoder_-> Decode(std::move(encoder_out), std::move(t[1]), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By the way, you don't need to pass the encoder model states to the greedy search decoder. Please pass the decoder model states to it instead. Please read carefully our python streaming transducer greedy search decoding example. I have posted the example here one more time in case you have not read it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. decoder_-> Decode(std::move(encoder_out), std::move(out_states), &results, ss, n); is this correct? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. Remember that We need to pass the LSTM states from the decoder model to the greedy search decoder. I suggest you again that you re-read the python decoding example and figure out how the decoding works. (you need to know how LSTM works) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. working on it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand that when initializing the model in beginning, init_cache_state intializes the initial states of the encoder model. Further when decoding begins, before the first chunk is decoded, decoder model comes into action and initializes the decoder states. After a chunk has been decoded, it emits the next states for the decoder model which becomes the current states for the next chunk. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
yes, so you don't need the encoder states during greedy search decoding. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. std::vector<Ort::Value> decoder_states = model_->GetDecoderInitStates(1);
decoder_-> Decode(std::move(encoder_out), std::move(decoder_states), &results, ss, n);
}
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please read
make sure you indeed understand the code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hi @csukuangfj , i see in model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_input_pair.second),
model->GetDecoderInitStates(1)); and I realize that above, I did it similar way. I am missing where exactly I am doing wrong. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I revised the code // defined in online-transducer-greedy-search-nemo-decoder.h
std::vector<Ort::Value> decoder_states = model_->GetDecoderInitStates(1);
// updated decoder states are returned
decoder_states = decoder_->Decode(std::move(encoder_out),
std::move(decoder_states),
&results, ss, n);
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(decoder_states); Is this correct? |
||
std::move(out_states), &results, ss, n); | ||
|
||
std::vector<std::vector<Ort::Value>> next_states = | ||
model_->UnStackStates(out_states); | ||
|
||
for (int32_t i = 0; i != n; ++i) { | ||
ss[i]->SetResult(results[i]); | ||
ss[i]->SetNeMoDecoderStates(std::move(next_states[i])); | ||
} | ||
} | ||
|
||
void InitOnlineStream(OnlineStream *stream) const { | ||
auto r = decoder_->GetEmptyResult(); | ||
|
||
stream->SetResult(r); | ||
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(batch_size_)); | ||
sangeet2020 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
private: | ||
void PostInit() { | ||
config_.feat_config.nemo_normalize_type = | ||
model_->FeatureNormalizationMethod(); | ||
|
||
config_.feat_config.low_freq = 0; | ||
// config_.feat_config.high_freq = 8000; | ||
config_.feat_config.is_librosa = true; | ||
config_.feat_config.remove_dc_offset = false; | ||
// config_.feat_config.window_type = "hann"; | ||
config_.feat_config.dither = 0; | ||
config_.feat_config.nemo_normalize_type = | ||
model_->FeatureNormalizationMethod(); | ||
|
||
int32_t vocab_size = model_->VocabSize(); | ||
|
||
// check the blank ID | ||
if (!symbol_table_.Contains("<blk>")) { | ||
SHERPA_ONNX_LOGE("tokens.txt does not include the blank token <blk>"); | ||
exit(-1); | ||
} | ||
|
||
if (symbol_table_["<blk>"] != vocab_size - 1) { | ||
SHERPA_ONNX_LOGE("<blk> is not the last token!"); | ||
exit(-1); | ||
} | ||
|
||
if (symbol_table_.NumSymbols() != vocab_size) { | ||
SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)", | ||
symbol_table_.NumSymbols(), vocab_size); | ||
exit(-1); | ||
} | ||
|
||
} | ||
|
||
private: | ||
OnlineRecognizerConfig config_; | ||
SymbolTable symbol_table_; | ||
std::unique_ptr<OnlineTransducerNeMoModel> model_; | ||
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_; | ||
|
||
int32_t batch_size_ = 1; | ||
sangeet2020 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ |
Uh oh!
There was an error while loading. Please reload this page.