diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py
index 454173a1..be196b82 100644
--- a/model2vec/distill/distillation.py
+++ b/model2vec/distill/distillation.py
@@ -304,24 +304,35 @@ def _clean_vocabulary(tokenizer: Tokenizer, vocabulary: list[str], added_tokens:
     n_duplicates = 0
     n_multiword = 0
     for token in vocabulary:
-        if tokenizer.normalizer is not None:
-            token = tokenizer.normalizer.normalize_str(token)
+        normalizer = tokenizer.normalizer
+        if normalizer is not None:
+            token = normalizer.normalize_str(token)
 
         if not token:
             n_empty += 1
             continue
-        if token in seen_tokens or token in added_tokens_set:
-            n_duplicates += 1
-            continue
 
         pre_tokenizer = tokenizer.pre_tokenizer
+        # We need to check whether the pretokenized token is a single word or not.
         if pre_tokenizer is not None:
             pretokenized_tokens = pre_tokenizer.pre_tokenize_str(token)
             if len(pretokenized_tokens) != 1:
                 n_multiword += 1
                 continue
+            new_token = pretokenized_tokens[-1][0]
+        else:
+            new_token = token
+
+        # We need to check whether the pretokenized token is in the vocabulary.
+        # But we need to return the original token, because that will be tokenized
+        # again by the tokenizer during featurization.
+        if new_token in seen_tokens or new_token in added_tokens_set:
+            n_duplicates += 1
+            continue
 
-        seen_tokens.add(token)
+        # Add the possibly pretokenized token to _seen_
+        seen_tokens.add(new_token)
+        # Add the original string to the vocabulary.
         cleaned_vocabulary.append(token)
 
     if n_duplicates:
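
For context on the `pretokenized_tokens[-1][0]` indexing: in the Hugging Face `tokenizers` API, `pre_tokenize_str` returns a list of `(token, (start, end))` tuples, so `[-1][0]` recovers the pre-tokenized string itself. A minimal sketch of the behavior the new dedup check relies on, assuming the `tokenizers` package and `bert-base-uncased` purely as an illustrative model (neither is pinned by this diff):

```python
from tokenizers import Tokenizer

# Hypothetical example model: any tokenizer with a normalizer and
# pre-tokenizer illustrates the same point.
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")

# normalize_str returns a plain string, e.g. lowercased with accents stripped.
print(tokenizer.normalizer.normalize_str("Héllo"))
# -> "hello"

# pre_tokenize_str returns a list of (token, (start, end)) tuples.
print(tokenizer.pre_tokenizer.pre_tokenize_str("hello"))
# -> [("hello", (0, 5))]: one entry, so pretokenized_tokens[-1][0] == "hello"

print(tokenizer.pre_tokenizer.pre_tokenize_str("hello world"))
# -> [("hello", (0, 5)), ("world", (6, 11))]: len != 1, counted as multiword
```

The upshot of the change: the dedup set now keys on the pre-tokenized form (`new_token`), while `cleaned_vocabulary` keeps the original string so it can be tokenized again during featurization.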