Fix nemo streaming transducer greedy search #944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
39 changes: 39 additions & 0 deletions .github/scripts/test-online-transducer.sh
@@ -15,6 +15,45 @@ echo "PATH: $PATH"

which $EXE

log "------------------------------------------------------------"
log "Run NeMo transducer (English)"
log "------------------------------------------------------------"
repo_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
curl -SL -O $repo_url
tar xvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
rm sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms.tar.bz2
repo=sherpa-onnx-nemo-streaming-fast-conformer-transducer-en-80ms

log "Start testing ${repo_url}"

waves=(
$repo/test_wavs/0.wav
$repo/test_wavs/1.wav
$repo/test_wavs/8k.wav
)

for wave in ${waves[@]}; do
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder.onnx \
--decoder=$repo/decoder.onnx \
--joiner=$repo/joiner.onnx \
--num-threads=2 \
$wave
done

time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder.onnx \
--decoder=$repo/decoder.onnx \
--joiner=$repo/joiner.onnx \
--num-threads=2 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

rm -rf $repo

log "------------------------------------------------------------"
log "Run LSTM transducer (English)"
log "------------------------------------------------------------"
1 change: 0 additions & 1 deletion .github/workflows/aarch64-linux-gnu-shared.yaml
@@ -196,7 +196,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
git lfs pull
mkdir -p aarch64

cp -v ../sherpa-onnx-*-shared.tar.bz2 ./aarch64
1 change: 0 additions & 1 deletion .github/workflows/aarch64-linux-gnu-static.yaml
@@ -187,7 +187,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
git lfs pull
mkdir -p aarch64

cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64
1 change: 0 additions & 1 deletion .github/workflows/android.yaml
@@ -124,7 +124,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
git lfs pull

cp -v ../sherpa-onnx-*-android.tar.bz2 ./

1 change: 0 additions & 1 deletion .github/workflows/arm-linux-gnueabihf.yaml
@@ -209,7 +209,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
git lfs pull
mkdir -p arm32

cp -v ../sherpa-onnx-*.tar.bz2 ./arm32
1 change: 0 additions & 1 deletion .github/workflows/build-xcframework.yaml
@@ -138,7 +138,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
git lfs pull

cp -v ../sherpa-onnx-*.tar.bz2 ./

1 change: 0 additions & 1 deletion .github/workflows/riscv64-linux.yaml
@@ -242,7 +242,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
git lfs pull
mkdir -p riscv64

cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64
1 change: 0 additions & 1 deletion .github/workflows/windows-x64.yaml
@@ -219,7 +219,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
git lfs pull
mkdir -p win64

cp -v ../sherpa-onnx-*.tar.bz2 ./win64
1 change: 0 additions & 1 deletion .github/workflows/windows-x86.yaml
@@ -221,7 +221,6 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
git lfs pull
mkdir -p win32

cp -v ../sherpa-onnx-*.tar.bz2 ./win32
20 changes: 10 additions & 10 deletions sherpa-onnx/csrc/online-recognizer-impl.cc
@@ -14,19 +14,18 @@ namespace sherpa_onnx {

std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
const OnlineRecognizerConfig &config) {

if (!config.model_config.transducer.encoder.empty()) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

auto decoder_model = ReadFile(config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

auto sess = std::make_unique<Ort::Session>(
env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
} else {
SHERPA_ONNX_LOGE("Running streaming Nemo transducer model");
return std::make_unique<OnlineRecognizerTransducerNeMoImpl>(config);
}
}
@@ -50,12 +49,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
AAssetManager *mgr, const OnlineRecognizerConfig &config) {
if (!config.model_config.transducer.encoder.empty()) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);

auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder);
auto sess = std::make_unique<Ort::Session>(env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

auto sess = std::make_unique<Ort::Session>(
env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

size_t node_count = sess->GetOutputCount();

if (node_count == 1) {
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
} else {
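The hunks above choose between the two transducer implementations by inspecting the decoder model itself rather than a config flag. Below is a minimal, self-contained sketch of that dispatch; it assumes onnxruntime's C++ API and that the decoder bytes are already in memory (sherpa-onnx's `ReadFile()` provides them), so the helper name and return type here are illustrative, not the project's exact signatures.

```cpp
// Sketch only: decide which OnlineRecognizer implementation to build by
// counting the outputs of the transducer decoder ONNX model.
#include <memory>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

// Assumption: `decoder_model` holds the raw bytes of decoder.onnx.
static bool LooksLikeNeMoDecoder(const std::vector<char> &decoder_model) {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  auto sess = std::make_unique<Ort::Session>(
      env, decoder_model.data(), decoder_model.size(), Ort::SessionOptions{});

  // A zipformer/icefall-style stateless decoder exposes a single output
  // (decoder_out); a NeMo streaming decoder also returns its recurrent
  // states, so it has more than one output.
  return sess->GetOutputCount() > 1;
}
```

Callers would then construct `OnlineRecognizerTransducerImpl` when this returns false and `OnlineRecognizerTransducerNeMoImpl` otherwise, mirroring the branch in `OnlineRecognizerImpl::Create`.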
11 changes: 4 additions & 7 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
@@ -35,18 +35,15 @@

namespace sherpa_onnx {

static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t frames_since_start) {
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms, int32_t subsampling_factor,
int32_t segment, int32_t frames_since_start) {
OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());

for (auto i : src.tokens) {
if (i == -1) continue;
auto sym = sym_table[i];

r.text.append(sym);
94 changes: 37 additions & 57 deletions sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
@@ -6,6 +6,7 @@
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_

#include <algorithm>
#include <fstream>
#include <ios>
#include <memory>
@@ -32,23 +33,20 @@
namespace sherpa_onnx {

// defined in ./online-recognizer-transducer-impl.h
// static may or may not be here? TODDOs
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t frames_since_start);
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms, int32_t subsampling_factor,
int32_t segment, int32_t frames_since_start);

class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
public:
public:
explicit OnlineRecognizerTransducerNeMoImpl(
const OnlineRecognizerConfig &config)
: config_(config),
symbol_table_(config.model_config.tokens),
endpoint_(config_.endpoint_config),
model_(std::make_unique<OnlineTransducerNeMoModel>(
config.model_config)) {
model_(
std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
model_.get(), config_.blank_penalty);
@@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
model_.get(), config_.blank_penalty);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
config.decoding_method.c_str());
exit(-1);
}

@@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {

std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
stream->SetStates(model_->GetInitStates());
InitOnlineStream(stream.get());
return stream;
}
@@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
}

OnlineRecognizerResult GetResult(OnlineStream *s) const override {
OnlineTransducerDecoderResult decoder_result = s->GetResult();
decoder_->StripLeadingBlanks(&decoder_result);

// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 8;
return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
int32_t subsampling_factor = model_->SubsamplingFactor();
return Convert(s->GetResult(), symbol_table_, frame_shift_ms,
subsampling_factor, s->GetCurrentSegment(),
s->GetNumFramesSinceStart());
}

bool IsEndpoint(OnlineStream *s) const override {
@@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// frame shift is 10 milliseconds
float frame_shift_in_seconds = 0.01;

// subsampling factor is 8
int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;
int32_t trailing_silence_frames =
s->GetResult().num_trailing_blanks * model_->SubsamplingFactor();

return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
frame_shift_in_seconds);
@@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// segment is incremented only when the last
// result is not empty
const auto &r = s->GetResult();
if (!r.tokens.empty() && r.tokens.back() != 0) {
if (!r.tokens.empty()) {
s->GetCurrentSegment() += 1;
}
}

// we keep the decoder_out
decoder_->UpdateDecoderOut(&s->GetResult());
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
s->SetResult({});

s->SetStates(model_->GetEncoderInitStates());

auto r = decoder_->GetEmptyResult();

s->SetResult(r);
s->GetResult().decoder_out = std::move(decoder_out);
s->SetNeMoDecoderStates(model_->GetDecoderInitStates());

// Note: We only update counters. The underlying audio samples
// are not discarded.
@@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {

int32_t feature_dim = ss[0]->FeatureDim();

std::vector<OnlineTransducerDecoderResult> result(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> encoder_states(n);

for (int32_t i = 0; i != n; ++i) {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
@@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std::copy(features.begin(), features.end(),
features_vec.data() + i * chunk_size * feature_dim);

result[i] = std::move(ss[i]->GetResult());
encoder_states[i] = std::move(ss[i]->GetStates());

}

auto memory_info =
@@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
features_vec.size(), x_shape.data(),
x_shape.size());

// Batch size is 1
auto states = std::move(encoder_states[0]);
int32_t num_states = states.size(); // num_states = 3
auto states = model_->StackStates(std::move(encoder_states));
int32_t num_states = states.size(); // num_states = 3
auto t = model_->RunEncoder(std::move(x), std::move(states));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
// t[1] next states

std::vector<Ort::Value> out_states;
out_states.reserve(num_states);

for (int32_t k = 1; k != num_states + 1; ++k) {
out_states.push_back(std::move(t[k]));
}

auto unstacked_states = model_->UnStackStates(std::move(out_states));
for (int32_t i = 0; i != n; ++i) {
ss[i]->SetStates(std::move(unstacked_states[i]));
}

Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);

// defined in online-transducer-greedy-search-nemo-decoder.h
// get intial states of decoder.
std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();

// Subsequent decoder states (for each chunks) are updated inside the Decode method.
// This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it.
decoder_states = decoder_->Decode(std::move(encoder_out),
std::move(decoder_states),
&result, ss, n);

ss[0]->SetResult(result[0]);

ss[0]->SetStates(std::move(out_states));

decoder_->Decode(std::move(encoder_out), ss, n);
}

void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
// set encoder states
stream->SetStates(model_->GetEncoderInitStates());

stream->SetResult(r);
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
// set decoder states
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates());
}

private:
@@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
symbol_table_.NumSymbols(), vocab_size);
exit(-1);
}

}

private:
Expand All @@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std::unique_ptr<OnlineTransducerNeMoModel> model_;
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
Endpoint endpoint_;

};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
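The changes in this file largely move the NeMo decoder state and the running result onto the `OnlineStream`, so that `DecodeStreams()` can stack and unstack encoder states across the batch and let the greedy-search decoder update each stream in place. For readers unfamiliar with the underlying algorithm, here is a minimal, self-contained sketch of one chunk of NeMo-style transducer greedy search. It is not the sherpa-onnx implementation: the `joiner`/`predictor` callables, `kBlankId = 0`, and the one-emission-per-frame simplification are stand-in assumptions.

```cpp
// Schematic greedy search over one chunk of encoder frames for a NeMo
// transducer. All model calls are hypothetical stand-ins.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

struct GreedyState {
  std::vector<float> decoder_out;   // prediction-network output for the last emitted token
  std::vector<int32_t> tokens;      // emitted (non-blank) token ids
  int32_t num_trailing_blanks = 0;  // consumed by endpoint detection
};

using Joiner = std::function<std::vector<float>(
    const std::vector<float> & /*enc_frame*/,
    const std::vector<float> & /*decoder_out*/)>;           // returns logits
using Predictor = std::function<std::vector<float>(int32_t /*token*/)>;

void GreedySearchChunk(const std::vector<std::vector<float>> &encoder_frames,
                       const Joiner &joiner, const Predictor &predictor,
                       GreedyState *state, float blank_penalty = 0.0f) {
  constexpr int32_t kBlankId = 0;  // assumption: blank is token 0

  for (const auto &enc : encoder_frames) {
    std::vector<float> logits = joiner(enc, state->decoder_out);
    logits[kBlankId] -= blank_penalty;  // optional penalty from the config

    auto best_it = std::max_element(logits.begin(), logits.end());
    auto best = static_cast<int32_t>(best_it - logits.begin());

    if (best == kBlankId) {
      // No emission on this frame; decoder_out is carried over unchanged,
      // which is why it must survive on the stream between chunks.
      ++state->num_trailing_blanks;
    } else {
      state->tokens.push_back(best);
      state->num_trailing_blanks = 0;
      // Advance the prediction network only when a token is emitted.
      state->decoder_out = predictor(best);
    }
  }
}
```

Keeping `decoder_out`, the emitted tokens, and the decoder's recurrent states per stream (rather than per batch call) is what lets the recognizer decode interleaved chunks from many streams without mixing their hypotheses.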
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/online-stream.cc
@@ -225,7 +225,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
return impl_->GetStates();
}

void OnlineStream::SetNeMoDecoderStates(std::vector<Ort::Value> decoder_states) {
void OnlineStream::SetNeMoDecoderStates(
std::vector<Ort::Value> decoder_states) {
return impl_->SetNeMoDecoderStates(std::move(decoder_states));
}
