From 20d7c8ce41593608c1a7569b202d932fd642b9f8 Mon Sep 17 00:00:00 2001
From: Yaroslav Sokolov
Date: Wed, 17 May 2023 15:01:39 +0200
Subject: [PATCH] Now we don't add an extra space token at the beginning of
 the sequence

---
 pyproject.toml                      | 2 +-
 tests/unit_tests/test_cli.py        | 3 ---
 tests/unit_tests/test_manual.py     | 6 +++---
 tests/unit_tests/test_python_api.py | 2 +-
 youtokentome/cpp/bpe.cpp            | 4 +++-
 5 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 155098c..a4da3be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "youtokentome"
-version = "1.0.8"
+version = "1.0.9"
 description = "FL version of YTTM"
 license = "MIT"
 authors = ["Ivan Belonogov"]
diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py
index 7d7680c..3c312a8 100644
--- a/tests/unit_tests/test_cli.py
+++ b/tests/unit_tests/test_cli.py
@@ -186,7 +186,6 @@ def test_decode():
     )
 
     with open("decode_text_out.txt", "r") as fin:
-        fin.readline()
         text_out = fin.readline()
         assert text_in == text_out[:-1]
 
@@ -224,8 +223,6 @@ def test_decode():
     )
 
     with open("decode_text_out.txt", "r") as fin:
-        # It is necessary to skip the first line, since everything in BPE starts from a new line
-        fin.readline()
         text_out = fin.readline()
         assert text_in == text_out[:-1]
 
diff --git a/tests/unit_tests/test_manual.py b/tests/unit_tests/test_manual.py
index 0d0bf13..d9e2462 100644
--- a/tests/unit_tests/test_manual.py
+++ b/tests/unit_tests/test_manual.py
@@ -39,7 +39,7 @@ def test_english():
     chronocoulometry
     """
 
-    test_text = "chronocline synchroscope "
+    test_text = "chronocline synchroscope \n"
 
     TRAIN_DATA_PATH = "train_data.txt"
     MODEL_PATH = "model.yttm"
@@ -48,7 +48,7 @@ def test_english():
     model = yttm.BPE.train(TRAIN_DATA_PATH, MODEL_PATH, 200, n_threads=1)
     tokenized_text = model.encode([test_text], output_type=yttm.OutputType.SUBWORD)
     expected_result = [
-        ["\n", "chrono", "c", "l", "i", "n", "e", " ", "s", "y", "n", "ch", "r", "o", "s", "co", "p", "e", " "]
+        ["chrono", "c", "l", "i", "n", "e", " ", "s", "y", "n", "ch", "r", "o", "s", "co", "p", "e", " ", "\n"]
     ]
     assert tokenized_text == expected_result
     print(tokenized_text)
@@ -73,7 +73,7 @@ def test_japanese():
     model = yttm.BPE.train(TRAIN_DATA_PATH, MODEL_PATH, 100)
     tokenized_text = model.encode([test_text], output_type=yttm.OutputType.SUBWORD)
     expected_result = [
-        ["\n", " ", "おばあさん ", "が", " ", "川", " ", "で", " ", "せ", "ん", " "]
+        [" ", "おばあさん ", "が", " ", "川", " ", "で", " ", "せ", "ん", " "]
     ]
     assert tokenized_text == expected_result
     print(tokenized_text)
diff --git a/tests/unit_tests/test_python_api.py b/tests/unit_tests/test_python_api.py
index e0c581b..495e925 100644
--- a/tests/unit_tests/test_python_api.py
+++ b/tests/unit_tests/test_python_api.py
@@ -30,7 +30,7 @@ def test_encode_decode():
     text_in = [" ".join("".join([random.choice("abcd ") for _ in range(50)]).split())]
     ids = bpe.encode(text_in, yttm.OutputType.ID)
     # It is necessary to add first empty line, since everything in BPE starts from a new line
-    text_in[0] = "\n" + text_in[0]
+    text_in[0] = text_in[0]
     assert text_in == bpe.decode(ids)
     ids_bos_eos = bpe.encode(text_in, yttm.OutputType.ID, bos=True, eos=True)
     assert text_in == bpe.decode(ids_bos_eos, ignore_ids=[BOS_ID, EOS_ID])
diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp
index 7cd3212..de87b93 100644
--- a/youtokentome/cpp/bpe.cpp
+++ b/youtokentome/cpp/bpe.cpp
@@ -1539,10 +1539,12 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
     auto begin_of_word = std::find_if_not(it_text, text.end(), is_space);
     auto end_of_word = std::find_if(begin_of_word, text.end(), is_space);
+    if (begin_of_word != it_text) {
+      list.emplace_back(bpe_state.char2id.at(SPACE_TOKEN), 0);
+    }
 
     it_text = end_of_word;
 
     uint32_t new_token_cur = new_tokens_start;
-    list.emplace_back(bpe_state.char2id.at(SPACE_TOKEN), 0);
 
     for (auto it_char_in_word = begin_of_word; it_char_in_word < end_of_word;) {
       if (bpe_state.char2id.count(*it_char_in_word) == 0) {
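
Note (illustrative sketch, not part of the patch): the user-visible effect of this
change through the Python API, assuming a model trained on the same word list that
test_english() in tests/unit_tests/test_manual.py writes to train_data.txt; the
exact subword merges depend on that training corpus, and the commented output
mirrors the updated test expectation.

    import youtokentome as yttm

    # Train a small BPE model the same way test_english() does
    # (train_data.txt and model.yttm are the paths used by that test).
    model = yttm.BPE.train(data="train_data.txt", model="model.yttm",
                           vocab_size=200, n_threads=1)

    # Encode one sentence into subwords.
    subwords = model.encode(["chronocline synchroscope \n"],
                            output_type=yttm.OutputType.SUBWORD)

    # Before this change the output began with an extra space/new-line token;
    # with the patch applied, tokenization starts at the first word, e.g.:
    # [["chrono", "c", "l", "i", "n", "e", " ", "s", "y", "n", "ch", "r", "o",
    #   "s", "co", "p", "e", " ", "\n"]]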