Skip to content

Commit 983df28

Browse files
authored
Fix a punctuation bug (#764)
1 parent b6ad043 commit 983df28

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
22
project(sherpa-onnx)
33

4-
set(SHERPA_ONNX_VERSION "1.9.18")
4+
set(SHERPA_ONNX_VERSION "1.9.19")
55

66
# Disable warning about
77
#

sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
9898
int32_t dot_index = -1;
9999
int32_t comma_index = -1;
100100

101-
for (int32_t m = this_punctuations.size() - 1; m >= 1; --m) {
101+
for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
102102
int32_t punct_id = this_punctuations[m];
103103

104104
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
@@ -126,27 +126,20 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
126126
}
127127
} else {
128128
last = this_start + dot_index + 1;
129+
}
129130

131+
if (dot_index != 1) {
130132
punctuations.insert(punctuations.end(), this_punctuations.begin(),
131133
this_punctuations.begin() + (dot_index + 1));
132134
}
133135
} // for (int32_t i = 0; i != num_segments; ++i)
134136

135-
if (punctuations.size() != token_ids.size() &&
136-
punctuations.size() + 1 == token_ids.size()) {
137-
punctuations.push_back(meta_data.dot_id);
138-
}
139-
140-
if (punctuations.size() != token_ids.size()) {
141-
SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened",
142-
text.c_str(), static_cast<int32_t>(punctuations.size()),
143-
static_cast<int32_t>(token_ids.size()));
144-
return text;
145-
}
146-
147137
std::string ans;
148138

149139
for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
140+
if (i > tokens.size()) {
141+
break;
142+
}
150143
const std::string &w = tokens[i];
151144
if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
152145
ans.push_back(' ');
@@ -156,6 +149,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
156149
ans.append(meta_data.id2punct[punctuations[i]]);
157150
}
158151
}
152+
if (ans.back() != meta_data.dot_id && ans.back() != meta_data.quest_id) {
153+
ans.push_back(meta_data.dot_id);
154+
}
159155

160156
return ans;
161157
}

0 commit comments

Comments
 (0)