diff --git a/.github/workflows/sanitizer.yaml b/.github/workflows/sanitizer.yaml index ef348a3809..2d4abf49c4 100644 --- a/.github/workflows/sanitizer.yaml +++ b/.github/workflows/sanitizer.yaml @@ -76,6 +76,14 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test offline punctuation + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-punctuation + + .github/scripts/test-offline-punctuation.sh + - name: Test offline transducer shell: bash run: | @@ -92,13 +100,7 @@ jobs: .github/scripts/test-online-ctc.sh - - name: Test offline punctuation - shell: bash - run: | - export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx-offline-punctuation - .github/scripts/test-offline-punctuation.sh - name: Test C API shell: bash diff --git a/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h b/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h index 4d05fb5036..eb2c46d6a7 100644 --- a/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h +++ b/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h @@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { std::vector punctuations; int32_t last = -1; for (int32_t i = 0; i != num_segments; ++i) { - int32_t this_start = i * segment_size; // inclusive - int32_t this_end = this_start + segment_size; // exclusive + int32_t this_start = i * segment_size; // included + int32_t this_end = this_start + segment_size; // not included if (this_end > static_cast(token_ids.size())) { this_end = token_ids.size(); } @@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { int32_t dot_index = -1; int32_t comma_index = -1; - for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) { + for (int32_t m = static_cast(this_punctuations.size()) - 2; + m >= 1; --m) { int32_t punct_id = this_punctuations[m]; if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { @@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { } if (i == num_segments - 1) { - dot_index = token_ids.size() - 1; + dot_index = static_cast(this_punctuations.size()) - 1; } } else { last = this_start + dot_index + 1; } - if (dot_index != 1) { + if (dot_index != -1) { punctuations.insert(punctuations.end(), this_punctuations.begin(), this_punctuations.begin() + (dot_index + 1)); }