@@ -98,7 +98,7 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
98
98
int32_t dot_index = -1 ;
99
99
int32_t comma_index = -1 ;
100
100
101
- for (int32_t m = this_punctuations.size () - 1 ; m >= 1 ; --m) {
101
+ for (int32_t m = this_punctuations.size () - 2 ; m >= 1 ; --m) {
102
102
int32_t punct_id = this_punctuations[m];
103
103
104
104
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id ) {
@@ -126,27 +126,20 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
126
126
}
127
127
} else {
128
128
last = this_start + dot_index + 1 ;
129
+ }
129
130
131
+ if (dot_index != 1 ) {
130
132
punctuations.insert (punctuations.end (), this_punctuations.begin (),
131
133
this_punctuations.begin () + (dot_index + 1 ));
132
134
}
133
135
} // for (int32_t i = 0; i != num_segments; ++i)
134
136
135
- if (punctuations.size () != token_ids.size () &&
136
- punctuations.size () + 1 == token_ids.size ()) {
137
- punctuations.push_back (meta_data.dot_id );
138
- }
139
-
140
- if (punctuations.size () != token_ids.size ()) {
141
- SHERPA_ONNX_LOGE (" %s, %d, %d. Some unexpected things happened" ,
142
- text.c_str (), static_cast <int32_t >(punctuations.size ()),
143
- static_cast <int32_t >(token_ids.size ()));
144
- return text;
145
- }
146
-
147
137
std::string ans;
148
138
149
139
for (int32_t i = 0 ; i != static_cast <int32_t >(punctuations.size ()); ++i) {
140
+ if (i > tokens.size ()) {
141
+ break ;
142
+ }
150
143
const std::string &w = tokens[i];
151
144
if (i > 0 && !(ans.back () & 0x80 ) && !(w[0 ] & 0x80 )) {
152
145
ans.push_back (' ' );
@@ -156,6 +149,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
156
149
ans.append (meta_data.id2punct [punctuations[i]]);
157
150
}
158
151
}
152
+ if (ans.back () != meta_data.dot_id && ans.back () != meta_data.quest_id ) {
153
+ ans.push_back (meta_data.dot_id );
154
+ }
159
155
160
156
return ans;
161
157
}
0 commit comments