Skip to content

Commit 675fb15

Browse files
Chung-Ichungyi.li
and
chungyi.li
authored
offline transducer: treat unk as blank (#1005)
Co-authored-by: chungyi.li <chungyi.li@ailabs.tw>
1 parent a11c859 commit 675fb15

5 files changed

+25
-9
lines changed

sherpa-onnx/csrc/offline-recognizer-transducer-impl.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
7878
config_(config),
7979
symbol_table_(config_.model_config.tokens),
8080
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
81+
if (symbol_table_.Contains("<unk>")) {
82+
unk_id_ = symbol_table_["<unk>"];
83+
}
84+
8185
if (config_.decoding_method == "greedy_search") {
8286
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
83-
model_.get(), config_.blank_penalty);
87+
model_.get(), unk_id_, config_.blank_penalty);
8488
} else if (config_.decoding_method == "modified_beam_search") {
8589
if (!config_.lm_config.model.empty()) {
8690
lm_ = OfflineLM::Create(config.lm_config);
@@ -97,7 +101,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
97101

98102
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
99103
model_.get(), lm_.get(), config_.max_active_paths,
100-
config_.lm_config.scale, config_.blank_penalty);
104+
config_.lm_config.scale, unk_id_, config_.blank_penalty);
101105
} else {
102106
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
103107
config_.decoding_method.c_str());
@@ -113,9 +117,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
113117
symbol_table_(mgr, config_.model_config.tokens),
114118
model_(std::make_unique<OfflineTransducerModel>(mgr,
115119
config_.model_config)) {
120+
if (symbol_table_.Contains("<unk>")) {
121+
unk_id_ = symbol_table_["<unk>"];
122+
}
123+
116124
if (config_.decoding_method == "greedy_search") {
117125
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
118-
model_.get(), config_.blank_penalty);
126+
model_.get(), unk_id_, config_.blank_penalty);
119127
} else if (config_.decoding_method == "modified_beam_search") {
120128
if (!config_.lm_config.model.empty()) {
121129
lm_ = OfflineLM::Create(mgr, config.lm_config);
@@ -133,7 +141,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
133141

134142
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
135143
model_.get(), lm_.get(), config_.max_active_paths,
136-
config_.lm_config.scale, config_.blank_penalty);
144+
config_.lm_config.scale, unk_id_, config_.blank_penalty);
137145
} else {
138146
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
139147
config_.decoding_method.c_str());
@@ -293,6 +301,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
293301
std::unique_ptr<OfflineTransducerModel> model_;
294302
std::unique_ptr<OfflineTransducerDecoder> decoder_;
295303
std::unique_ptr<OfflineLM> lm_;
304+
int32_t unk_id_ = -1;
296305
};
297306

298307
} // namespace sherpa_onnx

sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
5757
std::max_element(static_cast<const float *>(p_logit),
5858
static_cast<const float *>(p_logit) + vocab_size)));
5959
p_logit += vocab_size;
60-
if (y != 0) {
60+
// blank id is hardcoded to 0
61+
// also, it treats unk as blank
62+
if (y != 0 && y != unk_id_) {
6163
ans[i].tokens.push_back(y);
6264
ans[i].timestamps.push_back(t);
6365
emitted = true;

sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ namespace sherpa_onnx {
1515
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
1616
public:
1717
OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
18+
int32_t unk_id,
1819
float blank_penalty)
19-
: model_(model), blank_penalty_(blank_penalty) {}
20+
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
2021

2122
std::vector<OfflineTransducerDecoderResult> Decode(
2223
Ort::Value encoder_out, Ort::Value encoder_out_length,
2324
OfflineStream **ss = nullptr, int32_t n = 0) override;
2425

2526
private:
2627
OfflineTransducerModel *model_; // Not owned
28+
int32_t unk_id_;
2729
float blank_penalty_;
2830
};
2931

sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
131131

132132
float context_score = 0;
133133
auto context_state = new_hyp.context_state;
134-
if (new_token != 0) {
135-
// blank id is fixed to 0
134+
// blank is hardcoded to 0
135+
// also, it treats unk as blank
136+
if (new_token != 0 && new_token != unk_id_) {
136137
new_hyp.ys.push_back(new_token);
137138
new_hyp.timestamps.push_back(t);
138139
if (context_graphs[i] != nullptr) {

sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
1919
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
2020
OfflineLM *lm,
2121
int32_t max_active_paths,
22-
float lm_scale,
22+
float lm_scale, int32_t unk_id,
2323
float blank_penalty)
2424
: model_(model),
2525
lm_(lm),
2626
max_active_paths_(max_active_paths),
2727
lm_scale_(lm_scale),
28+
unk_id_(unk_id),
2829
blank_penalty_(blank_penalty) {}
2930

3031
std::vector<OfflineTransducerDecoderResult> Decode(
@@ -37,6 +38,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
3738

3839
int32_t max_active_paths_;
3940
float lm_scale_; // used only when lm_ is not nullptr
41+
int32_t unk_id_;
4042
float blank_penalty_;
4143
};
4244

0 commit comments

Comments
 (0)