Skip to content

Commit a388818

Browse files
authored
Support customize scores for hotwords (#926)
* Support customize scores for hotwords * Skip blank lines
1 parent a689249 commit a388818

File tree

6 files changed

+103
-35
lines changed

6 files changed

+103
-35
lines changed

sherpa-onnx/csrc/context-graph.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,9 @@ class ContextGraph {
6161
}
6262

6363
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
64-
float context_score, const std::vector<float> &scores = {},
65-
const std::vector<std::string> &phrases = {})
66-
: ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
67-
std::vector<float>()) {}
64+
float context_score, const std::vector<float> &scores = {})
65+
: ContextGraph(token_ids, context_score, 0.0f, scores,
66+
std::vector<std::string>(), std::vector<float>()) {}
6867

6968
std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
7069
const ContextState *state, int32_t token_id,

sherpa-onnx/csrc/offline-recognizer-transducer-impl.h

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,35 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
145145
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
146146
std::istringstream is(hws);
147147
std::vector<std::vector<int32_t>> current;
148+
std::vector<float> current_scores;
148149
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
149-
bpe_encoder_.get(), &current)) {
150+
bpe_encoder_.get(), &current, &current_scores)) {
150151
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
151152
hotwords.c_str());
152153
}
154+
155+
int32_t num_default_hws = hotwords_.size();
156+
int32_t num_hws = current.size();
157+
153158
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
154159

155-
auto context_graph =
156-
std::make_shared<ContextGraph>(current, config_.hotwords_score);
160+
if (!current_scores.empty() && !boost_scores_.empty()) {
161+
current_scores.insert(current_scores.end(), boost_scores_.begin(),
162+
boost_scores_.end());
163+
} else if (!current_scores.empty() && boost_scores_.empty()) {
164+
current_scores.insert(current_scores.end(), num_default_hws,
165+
config_.hotwords_score);
166+
} else if (current_scores.empty() && !boost_scores_.empty()) {
167+
current_scores.insert(current_scores.end(), num_hws,
168+
config_.hotwords_score);
169+
current_scores.insert(current_scores.end(), boost_scores_.begin(),
170+
boost_scores_.end());
171+
} else {
172+
// Do nothing.
173+
}
174+
175+
auto context_graph = std::make_shared<ContextGraph>(
176+
current, config_.hotwords_score, current_scores);
157177
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
158178
}
159179

@@ -226,13 +246,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
226246
}
227247

228248
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
229-
bpe_encoder_.get(), &hotwords_)) {
249+
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
230250
SHERPA_ONNX_LOGE(
231251
"Failed to encode some hotwords, skip them already, see logs above "
232252
"for details.");
233253
}
234-
hotwords_graph_ =
235-
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
254+
hotwords_graph_ = std::make_shared<ContextGraph>(
255+
hotwords_, config_.hotwords_score, boost_scores_);
236256
}
237257

238258
#if __ANDROID_API__ >= 9
@@ -250,20 +270,21 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
250270
}
251271

252272
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
253-
bpe_encoder_.get(), &hotwords_)) {
273+
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
254274
SHERPA_ONNX_LOGE(
255275
"Failed to encode some hotwords, skip them already, see logs above "
256276
"for details.");
257277
}
258-
hotwords_graph_ =
259-
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
278+
hotwords_graph_ = std::make_shared<ContextGraph>(
279+
hotwords_, config_.hotwords_score, boost_scores_);
260280
}
261281
#endif
262282

263283
private:
264284
OfflineRecognizerConfig config_;
265285
SymbolTable symbol_table_;
266286
std::vector<std::vector<int32_t>> hotwords_;
287+
std::vector<float> boost_scores_;
267288
ContextGraphPtr hotwords_graph_;
268289
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
269290
std::unique_ptr<OfflineTransducerModel> model_;

sherpa-onnx/csrc/online-recognizer-transducer-impl.h

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,35 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
182182
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
183183
std::istringstream is(hws);
184184
std::vector<std::vector<int32_t>> current;
185+
std::vector<float> current_scores;
185186
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
186-
bpe_encoder_.get(), &current)) {
187+
bpe_encoder_.get(), &current, &current_scores)) {
187188
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
188189
hotwords.c_str());
189190
}
191+
192+
int32_t num_default_hws = hotwords_.size();
193+
int32_t num_hws = current.size();
194+
190195
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
191-
auto context_graph =
192-
std::make_shared<ContextGraph>(current, config_.hotwords_score);
196+
197+
if (!current_scores.empty() && !boost_scores_.empty()) {
198+
current_scores.insert(current_scores.end(), boost_scores_.begin(),
199+
boost_scores_.end());
200+
} else if (!current_scores.empty() && boost_scores_.empty()) {
201+
current_scores.insert(current_scores.end(), num_default_hws,
202+
config_.hotwords_score);
203+
} else if (current_scores.empty() && !boost_scores_.empty()) {
204+
current_scores.insert(current_scores.end(), num_hws,
205+
config_.hotwords_score);
206+
current_scores.insert(current_scores.end(), boost_scores_.begin(),
207+
boost_scores_.end());
208+
} else {
209+
// Do nothing.
210+
}
211+
212+
auto context_graph = std::make_shared<ContextGraph>(
213+
current, config_.hotwords_score, current_scores);
193214
auto stream =
194215
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
195216
InitOnlineStream(stream.get());
@@ -376,13 +397,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
376397
}
377398

378399
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
379-
bpe_encoder_.get(), &hotwords_)) {
400+
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
380401
SHERPA_ONNX_LOGE(
381402
"Failed to encode some hotwords, skip them already, see logs above "
382403
"for details.");
383404
}
384-
hotwords_graph_ =
385-
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
405+
hotwords_graph_ = std::make_shared<ContextGraph>(
406+
hotwords_, config_.hotwords_score, boost_scores_);
386407
}
387408

388409
#if __ANDROID_API__ >= 9
@@ -400,13 +421,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
400421
}
401422

402423
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
403-
bpe_encoder_.get(), &hotwords_)) {
424+
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
404425
SHERPA_ONNX_LOGE(
405426
"Failed to encode some hotwords, skip them already, see logs above "
406427
"for details.");
407428
}
408-
hotwords_graph_ =
409-
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
429+
hotwords_graph_ = std::make_shared<ContextGraph>(
430+
hotwords_, config_.hotwords_score, boost_scores_);
410431
}
411432
#endif
412433

@@ -428,6 +449,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
428449
private:
429450
OnlineRecognizerConfig config_;
430451
std::vector<std::vector<int32_t>> hotwords_;
452+
std::vector<float> boost_scores_;
431453
ContextGraphPtr hotwords_graph_;
432454
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
433455
std::unique_ptr<OnlineTransducerModel> model_;

sherpa-onnx/csrc/text2token-test.cc

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,21 @@ TEST(TEXT2TOKEN, TEST_cjkchar) {
3535

3636
auto sym_table = SymbolTable(tokens);
3737

38-
std::string text = "世界人民大团结\n中国 V S 美国";
38+
std::string text =
39+
"世界人民大团结\n中国 V S 美国\n\n"; // Test blank lines also
3940

4041
std::istringstream iss(text);
4142

4243
std::vector<std::vector<int32_t>> ids;
44+
std::vector<float> scores;
4345

44-
auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids);
46+
auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores);
4547

4648
std::vector<std::vector<int32_t>> expected_ids(
4749
{{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
4850
EXPECT_EQ(ids, expected_ids);
51+
52+
EXPECT_EQ(scores.size(), 0);
4953
}
5054

5155
TEST(TEXT2TOKEN, TEST_bpe) {
@@ -68,17 +72,22 @@ TEST(TEXT2TOKEN, TEST_bpe) {
6872
auto sym_table = SymbolTable(tokens);
6973
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
7074

71-
std::string text = "HELLO WORLD\nI LOVE YOU";
75+
std::string text = "HELLO WORLD\nI LOVE YOU :2.0";
7276

7377
std::istringstream iss(text);
7478

7579
std::vector<std::vector<int32_t>> ids;
80+
std::vector<float> scores;
7681

77-
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
82+
auto r =
83+
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
7884

7985
std::vector<std::vector<int32_t>> expected_ids(
8086
{{22, 58, 24, 425}, {19, 370, 47}});
8187
EXPECT_EQ(ids, expected_ids);
88+
89+
std::vector<float> expected_scores({0, 2.0});
90+
EXPECT_EQ(scores, expected_scores);
8291
}
8392

8493
TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
@@ -101,19 +110,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
101110
auto sym_table = SymbolTable(tokens);
102111
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
103112

104-
std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国";
113+
std::string text = "世界人民 GOES TOGETHER :1.5\n中国 GOES WITH 美国 :0.5";
105114

106115
std::istringstream iss(text);
107116

108117
std::vector<std::vector<int32_t>> ids;
118+
std::vector<float> scores;
109119

110-
auto r =
111-
EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids);
120+
auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(),
121+
&ids, &scores);
112122

113123
std::vector<std::vector<int32_t>> expected_ids(
114124
{{1368, 1392, 557, 680, 275, 178, 475},
115125
{685, 736, 275, 178, 179, 921, 736}});
116126
EXPECT_EQ(ids, expected_ids);
127+
128+
std::vector<float> expected_scores({1.5, 0.5});
129+
EXPECT_EQ(scores, expected_scores);
117130
}
118131

119132
TEST(TEXT2TOKEN, TEST_bbpe) {
@@ -136,17 +149,22 @@ TEST(TEXT2TOKEN, TEST_bbpe) {
136149
auto sym_table = SymbolTable(tokens);
137150
auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
138151

139-
std::string text = "频繁\n李鞑靼";
152+
std::string text = "频繁 :1.0\n李鞑靼";
140153

141154
std::istringstream iss(text);
142155

143156
std::vector<std::vector<int32_t>> ids;
157+
std::vector<float> scores;
144158

145-
auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
159+
auto r =
160+
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
146161

147162
std::vector<std::vector<int32_t>> expected_ids(
148163
{{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
149164
EXPECT_EQ(ids, expected_ids);
165+
166+
std::vector<float> expected_scores({1.0, 0});
167+
EXPECT_EQ(scores, expected_scores);
150168
}
151169

152170
} // namespace sherpa_onnx

sherpa-onnx/csrc/utils.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ static bool EncodeBase(const std::vector<std::string> &lines,
103103
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
104104
const SymbolTable &symbol_table,
105105
const ssentencepiece::Ssentencepiece *bpe_encoder,
106-
std::vector<std::vector<int32_t>> *hotwords) {
106+
std::vector<std::vector<int32_t>> *hotwords,
107+
std::vector<float> *boost_scores) {
107108
std::vector<std::string> lines;
108109
std::string line;
109110
std::string word;
@@ -131,7 +132,12 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
131132
break;
132133
}
133134
}
134-
phrase = oss.str().substr(1);
135+
phrase = oss.str();
136+
if (phrase.empty()) {
137+
continue;
138+
} else {
139+
phrase = phrase.substr(1);
140+
}
135141
std::istringstream piss(phrase);
136142
oss.clear();
137143
oss.str("");
@@ -177,7 +183,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
177183
}
178184
lines.push_back(oss.str());
179185
}
180-
return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr);
186+
return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores,
187+
nullptr);
181188
}
182189

183190
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,

sherpa-onnx/csrc/utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ namespace sherpa_onnx {
2929
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
3030
const SymbolTable &symbol_table,
3131
const ssentencepiece::Ssentencepiece *bpe_encoder_,
32-
std::vector<std::vector<int32_t>> *hotwords_id);
32+
std::vector<std::vector<int32_t>> *hotwords_id,
33+
std::vector<float> *boost_scores);
3334

3435
/* Encode the keywords in an input stream to be tokens ids.
3536
*

0 commit comments

Comments
 (0)