Skip to content

online-transducer: reset the encoder together with 2 previous output symbols (non-blank) #2129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
}

// reset encoder states
// s->SetStates(model_->GetEncoderInitStates());

auto r = decoder_->GetEmptyResult();
auto last_result = s->GetResult();
// if last result is not empty, then
// truncate all last hyps and save as the context for next result

if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
// if last result is not empty, then
// truncate all last hyps and save as the 'ys' context for next result
// (the encoder state buffers are kept)
for (const auto &it : last_result.hyps) {
auto h = it.second;
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
Expand All @@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {

r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
last_result.tokens.end());
} else {
if(config_.reset_encoder) {
// reset encoder states, use blanks as 'ys' context
s->SetStates(model_->GetEncoderInitStates());
}
}

// but reset all contextual biasing graph states to root
Expand Down
7 changes: 6 additions & 1 deletion sherpa-onnx/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"rule-fars", &rule_fars,
"If not empty, it specifies fst archives for inverse text normalization. "
"If there are multiple archives, they are separated by a comma.");

// Command-line flag bound to `reset_encoder`. Adjacent string literals are
// concatenated without any separator, so the first literal ends with a space
// to keep the two sentences of the help text apart.
po->Register("reset-encoder", &reset_encoder,
"True to reset encoder_state on an endpoint after empty segment. "
"Done in `Reset()` method, after an endpoint was detected.");
}

bool OnlineRecognizerConfig::Validate() const {
Expand Down Expand Up @@ -198,7 +202,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "blank_penalty=" << blank_penalty << ", ";
os << "temperature_scale=" << temperature_scale << ", ";
os << "rule_fsts=\"" << rule_fsts << "\", ";
os << "rule_fars=\"" << rule_fars << "\")";
os << "rule_fars=\"" << rule_fars << "\", ";
os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")";

return os.str();
}
Expand Down
12 changes: 10 additions & 2 deletions sherpa-onnx/csrc/online-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ struct OnlineRecognizerConfig {
OnlineLMConfig lm_config;
EndpointConfig endpoint_config;
OnlineCtcFstDecoderConfig ctc_fst_decoder_config;

bool enable_endpoint = true;

std::string decoding_method = "greedy_search";
Expand All @@ -101,6 +102,11 @@ struct OnlineRecognizerConfig {
// If there are multiple FST archives, they are applied from left to right.
std::string rule_fars;

// True to reset encoder_state on an endpoint after empty segment.
// Done in `Reset()` method, after an endpoint was detected,
// currently only in `OnlineRecognizerTransducerImpl`.
bool reset_encoder = false;

/// used only for modified_beam_search, if hotwords_buf is non-empty,
/// the hotwords will be loaded from the buffered string instead of from the
/// "hotwords_file"
Expand All @@ -116,7 +122,8 @@ struct OnlineRecognizerConfig {
bool enable_endpoint, const std::string &decoding_method,
int32_t max_active_paths, const std::string &hotwords_file,
float hotwords_score, float blank_penalty, float temperature_scale,
const std::string &rule_fsts, const std::string &rule_fars)
const std::string &rule_fsts, const std::string &rule_fars,
bool reset_encoder)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
Expand All @@ -130,7 +137,8 @@ struct OnlineRecognizerConfig {
blank_penalty(blank_penalty),
temperature_scale(temperature_scale),
rule_fsts(rule_fsts),
rule_fars(rule_fars) {}
rule_fars(rule_fars),
reset_encoder(reset_encoder) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down
5 changes: 3 additions & 2 deletions sherpa-onnx/python/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool,
const std::string &, int32_t, const std::string &, float,
float, float, const std::string &, const std::string &>(),
float, float, const std::string &, const std::string &, bool>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
Expand All @@ -67,7 +67,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
py::arg("rule_fars") = "")
py::arg("rule_fars") = "", py::arg("reset_encoder"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
py::arg("rule_fars") = "", py::arg("reset_encoder"))
py::arg("rule_fars") = "", py::arg("reset_encoder") = false)

Can you give it a default value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, default value is in place now

.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
Expand All @@ -82,6 +82,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("temperature_scale", &PyClass::temperature_scale)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
.def_readwrite("reset_encoder", &PyClass::reset_encoder)
.def("__str__", &PyClass::ToString);
}

Expand Down
19 changes: 19 additions & 0 deletions sherpa-onnx/python/csrc/sherpa-onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ PYBIND11_MODULE(_sherpa_onnx, m) {

#if SHERPA_ONNX_ENABLE_TTS == 1
PybindOfflineTts(&m);
#else
/* Define "empty" TTS symbols */
m.attr("OfflineTtsKokoroModelConfig") = py::none();
m.attr("OfflineTtsMatchaModelConfig") = py::none();
m.attr("OfflineTtsModelConfig") = py::none();
m.attr("OfflineTtsVitsModelConfig") = py::none();
m.attr("GeneratedAudio") = py::none();
m.attr("OfflineTtsConfig") = py::none();
m.attr("OfflineTts") = py::none();
#endif

PybindSpeakerEmbeddingExtractor(&m);
Expand All @@ -85,6 +94,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindFastClustering(&m);
PybindOfflineSpeakerDiarizationResult(&m);
PybindOfflineSpeakerDiarization(&m);
#else
/* Define "empty" diarization symbols */
m.attr("FastClusteringConfig") = py::none();
m.attr("FastClustering") = py::none();
m.attr("OfflineSpeakerDiarizationSegment") = py::none();
m.attr("OfflineSpeakerDiarizationResult") = py::none();
m.attr("OfflineSpeakerSegmentationPyannoteModelConfig") = py::none();
m.attr("OfflineSpeakerSegmentationModelConfig") = py::none();
m.attr("OfflineSpeakerDiarizationConfig") = py::none();
m.attr("OfflineSpeakerDiarization") = py::none();
#endif

PybindAlsa(&m);
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def from_transducer(
lm_scale: float = 0.1,
lm_shallow_fusion: bool = True,
temperature_scale: float = 2.0,
reset_encoder: bool = False,
debug: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
Expand Down Expand Up @@ -162,6 +163,10 @@ def from_transducer(
Temperature scaling for output symbol confidence estimation.
It affects only confidence values, the decoding uses the original
logits without temperature.
reset_encoder:
True to reset `encoder_state` on an endpoint after empty segment.
Done in `Reset()` method, after an endpoint was detected,
currently only in `OnlineRecognizerTransducerImpl`.
model_type:
Online transducer model type. Valid values are: conformer, lstm,
zipformer, zipformer2. All other values lead to loading the model twice.
Expand Down Expand Up @@ -305,6 +310,7 @@ def from_transducer(
temperature_scale=temperature_scale,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
reset_encoder=reset_encoder,
)

self.recognizer = _Recognizer(recognizer_config)
Expand Down
Loading