diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py
index 704bec6..bc33772 100644
--- a/model2vec/distill/distillation.py
+++ b/model2vec/distill/distillation.py
@@ -13,7 +13,7 @@
 from model2vec.distill.inference import create_embeddings
 from model2vec.distill.tokenizer import replace_vocabulary
-from model2vec.distill.utils import select_optimal_device
+from model2vec.distill.utils import Token, select_optimal_device
 from model2vec.model import StaticModel
 from model2vec.quantization import DType, quantize_embeddings
diff --git a/model2vec/distill/inference.py b/model2vec/distill/inference.py
index 78d99c9..e94695a 100644
--- a/model2vec/distill/inference.py
+++ b/model2vec/distill/inference.py
@@ -14,7 +14,7 @@
 from transformers import PreTrainedModel, PreTrainedTokenizerFast
 from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

-from model2vec.distill.utils import filter_vocabulary_by_regex
+from model2vec.distill.utils import Token, filter_vocabulary_by_regex

 logger = logging.getLogger(__name__)
@@ -35,7 +35,7 @@ def create_embeddings(
     tokens: list[str],
     device: str,
     token_remove_regex: re.Pattern | None,
-) -> tuple[list[str], np.ndarray]:
+) -> tuple[list[Token], np.ndarray]:
     """
     Create output embeddings for a bunch of tokens using a pretrained model.
@@ -55,7 +55,7 @@ def create_embeddings(
     out_weights: np.ndarray
     intermediate_weights: list[np.ndarray] = []
-    out_tokens = []
+    out_tokens: list[Token] = []
     tokenized: list[torch.Tensor] = []
     pad_token = tokenizer.special_tokens_map.get("pad_token")
     pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
@@ -89,7 +89,8 @@ def create_embeddings(
         eos = torch.full([len(ids)], fill_value=eos_token_id)
         tokenized.extend(torch.stack([bos, ids, eos], dim=1))
-        out_tokens.extend(tokenizer.convert_ids_to_tokens(ids))
+        subword_tokens = [Token(x, True) for x in tokenizer.convert_ids_to_tokens(ids.tolist())]
+        out_tokens.extend(subword_tokens)

     tokenized.extend([tokenizer.encode_plus(token, return_tensors="pt")["input_ids"][0] for token in tokens])
@@ -119,7 +120,7 @@ def create_embeddings(
     # Sort the output back to the original order
     intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
-    out_tokens.extend(tokens)
+    out_tokens.extend([Token(x, False) for x in tokens])
     out_weights = np.stack(intermediate_weights)

     return out_tokens, out_weights
diff --git a/model2vec/distill/tokenizer.py b/model2vec/distill/tokenizer.py
index 6f21e85..69a82e0 100644
--- a/model2vec/distill/tokenizer.py
+++ b/model2vec/distill/tokenizer.py
@@ -6,6 +6,8 @@
 from tokenizers import Tokenizer

+from model2vec.distill.utils import Token
+
 logger = logging.getLogger(__name__)
@@ -17,7 +19,7 @@
 }


-def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[str]) -> list[str]:
+def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[str]:
     """
     Apply pre-tokenization to vocabulary tokens if a pre-tokenizer is present.
@@ -33,14 +35,14 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[str]) -> list[st
     if tokenizer.pre_tokenizer is not None:
         for token in tokens:
-            if token in current_tokenizer_vocab:
-                pre_tokenized_tokens.append(token)
+            if token.is_subword:
+                pre_tokenized_tokens.append(token.form)
             else:
                 # We know 100% sure that all pretokenized tokens will have length 1.
-                pretokenized_tokens, _ = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(f" {token}"))
+                pretokenized_tokens, _ = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(f" {token.form}"))
                 pre_tokenized_tokens.append(pretokenized_tokens[-1])
     else:
-        pre_tokenized_tokens = tokens
+        pre_tokenized_tokens = [token.form for token in tokens]

     return pre_tokenized_tokens
@@ -106,7 +108,7 @@ def _make_new_merges_from_vocab(


 def replace_vocabulary(
-    tokenizer: Tokenizer, new_vocabulary: list[str], unk_token: str | None, pad_token: str | None
+    tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
 ) -> Tokenizer:
     """Replace the vocabulary of a tokenizer with a new one."""
     tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
@@ -139,8 +141,8 @@ def replace_vocabulary(
         vocab = tokenizer_json["model"]["vocab"]
         unk_token = vocab[unk_id][0] if unk_id is not None else None
         current_probas = dict(tokenizer_json["model"]["vocab"])
-        lowest_proba = min(current_probas.values())
-        new_probas = {word: current_probas.get(word, lowest_proba) for word in pre_tokenized_tokens}
+        avg_proba = sum(current_probas.values()) / len(current_probas)
+        new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
         tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)

         tokens, _ = zip(*tokenizer_json["model"]["vocab"])
diff --git a/model2vec/distill/utils.py b/model2vec/distill/utils.py
index 5073903..4dd3c4a 100644
--- a/model2vec/distill/utils.py
+++ b/model2vec/distill/utils.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import re
+from dataclasses import dataclass
 from logging import getLogger

 import torch
@@ -8,6 +9,14 @@
 logger = getLogger(__name__)


+@dataclass
+class Token:
+    """A class to represent a token."""
+
+    form: str
+    is_subword: bool
+
+
 def select_optimal_device(device: str | None) -> str:
     """
     Guess what your optimal device should be based on backend availability.
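
For reviewers, a minimal self-contained sketch of how the new `Token` dataclass threads through the changes above: `create_embeddings` tags subword tokens with `is_subword=True` and user-supplied words with `is_subword=False`, and `_pre_tokenize_vocabulary` branches on that flag instead of checking vocabulary membership. The `toy_pre_tokenize` helper and the example vocabulary are hypothetical stand-ins (not part of this PR) for `pre_tokenizer.pre_tokenize_str`:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Token:
    """A vocabulary entry: its surface form and whether it came from the subword vocabulary."""

    form: str
    is_subword: bool


def toy_pre_tokenize(text: str) -> list[tuple[str, tuple[int, int]]]:
    """Hypothetical stand-in for pre_tokenizer.pre_tokenize_str: whitespace split with offsets."""
    pieces, start = [], 0
    for piece in text.split():
        idx = text.index(piece, start)
        pieces.append((piece, (idx, idx + len(piece))))
        start = idx + len(piece)
    return pieces


def pre_tokenize_vocabulary(tokens: list[Token]) -> list[str]:
    """Mirror the branching in _pre_tokenize_vocabulary after this change."""
    pre_tokenized: list[str] = []
    for token in tokens:
        if token.is_subword:
            # Subword tokens originate from the tokenizer itself, so keep them verbatim.
            pre_tokenized.append(token.form)
        else:
            # Added word-level tokens are pre-tokenized with a leading space; keep the last piece.
            pieces, _ = zip(*toy_pre_tokenize(f" {token.form}"))
            pre_tokenized.append(pieces[-1])
    return pre_tokenized


if __name__ == "__main__":
    # Subword tokens leave create_embeddings as Token(form, True);
    # user-supplied vocabulary words leave it as Token(form, False).
    vocabulary = [Token("##ing", True), Token("hello", False), Token("world", False)]
    print(pre_tokenize_vocabulary(vocabulary))  # ['##ing', 'hello', 'world']
```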