feat: add supertokenizers #236

Merged
40 commits merged on May 26, 2025
Changes from 36 commits

Commits (40)
b172b6d
remove multiword warning
stephantul Apr 9, 2025
d24c387
add superbpe tokenizers
stephantul Apr 10, 2025
27e856f
Merge branch 'main' into add-superbpe
stephantul Apr 20, 2025
5f36097
merge
stephantul Apr 24, 2025
c9e7d14
fix issue with mwe
stephantul Apr 24, 2025
ecc89b8
merge
stephantul Apr 30, 2025
9d301d1
form
stephantul Apr 30, 2025
f4d6a82
Merge branch 'main' into add-superbpe
stephantul Apr 30, 2025
1666dd2
working version
stephantul May 2, 2025
8611ad5
first pass
stephantul May 4, 2025
59502a1
small fixes, many comments
stephantul May 4, 2025
e06c5d9
fix e5 bug
stephantul May 4, 2025
13a95dc
Adjust arcane formulae
stephantul May 4, 2025
e85e292
fix: logging
stephantul May 5, 2025
1c57d40
Merge branch 'main' into add-superbpe
stephantul May 6, 2025
b05c669
wip
stephantul May 7, 2025
3a7408a
wip
stephantul May 7, 2025
5275fca
wip
stephantul May 7, 2025
12d9ff2
lower complexity
stephantul May 12, 2025
c52ab40
add lock file
stephantul May 12, 2025
077a550
fix: metaspace pretokenizer
stephantul May 12, 2025
cff4035
fix: bug in vocab
stephantul May 12, 2025
a972c10
feat: spaces/commas etc.
stephantul May 12, 2025
e2789ba
turn tokenizer into package
stephantul May 13, 2025
d19ab92
add annotations
stephantul May 13, 2025
0bd32c0
feat: turn tokenizer into package
stephantul May 13, 2025
b48bd60
fix: future
stephantul May 13, 2025
abcf903
add tokenizer function
stephantul May 13, 2025
a201e0a
update lockfile
stephantul May 13, 2025
345a701
feat: improve segmentation of unigram
stephantul May 14, 2025
1350195
Merge branch 'main' into add-superbpe
stephantul May 18, 2025
cf005aa
merge
stephantul May 22, 2025
9301f00
fix: broken merge
stephantul May 22, 2025
796e18f
fix interpunct tokens
stephantul May 22, 2025
3aff31b
fix tests, make tokenizer changes better
stephantul May 22, 2025
bae0193
update lock file
stephantul May 22, 2025
336655e
fix comment, add additional check for pad token
stephantul May 23, 2025
f6a27a4
Merge branch 'main' into add-superbpe
stephantul May 26, 2025
02f5591
tests: add a lot of tests
stephantul May 26, 2025
98546da
fix: 3.9 error
stephantul May 26, 2025
151 changes: 29 additions & 122 deletions model2vec/distill/distillation.py
@@ -3,26 +3,21 @@
import logging
import os
import re
from typing import Literal, Union
from typing import cast
Inline review comment (Member): 😮


import numpy as np
from huggingface_hub import model_info
from sklearn.decomposition import PCA
from tokenizers import Tokenizer
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast

from model2vec.distill.inference import create_embeddings
from model2vec.distill.tokenizer import replace_vocabulary
from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
from model2vec.distill.utils import select_optimal_device
from model2vec.model import StaticModel
from model2vec.quantization import DType, quantize_embeddings
from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids

logger = logging.getLogger(__name__)


PCADimType = Union[int, None, float, Literal["auto"]]


def distill_from_model(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerFast,
@@ -60,6 +55,7 @@
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
:return: A StaticModel
:raises: ValueError if the vocabulary is empty after preprocessing.

"""
if use_subword is not None:
@@ -74,35 +70,37 @@
vocabulary = []

device = select_optimal_device(device)
# Make a base list of tokens.
subword_vocab: dict[str, int] = tokenizer.get_vocab()
subword_tokens: list[str] = [k for k, _ in sorted(subword_vocab.items(), key=lambda x: x[1])]

n_tokens_before = len(vocabulary)
# Clean the vocabulary by removing duplicate tokens and tokens that are in the subword vocabulary.
cleaned_vocabulary = _clean_vocabulary(tokenizer.backend_tokenizer, vocabulary, subword_tokens)
n_tokens_after = len(cleaned_vocabulary)
logger.info(
f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
# Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
all_tokens, backend_tokenizer = clean_and_create_vocabulary(
tokenizer, vocabulary, token_remove_regex=token_remove_regex
)
n_tokens_after = len([token for token in all_tokens if not token.is_internal])
if n_tokens_before:
logger.info(
f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
)

if not all_tokens:
raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.")

Codecov / codecov/patch warning: added line model2vec/distill/distillation.py#L86 was not covered by tests.

# Create the embeddings.
all_tokens, embeddings = create_embeddings(
model=model,
tokenizer=tokenizer,
tokens=cleaned_vocabulary,
device=device,
token_remove_regex=token_remove_regex,
)
unk_token: str | None = tokenizer.special_tokens_map.get("unk_token")
pad_token: str | None = tokenizer.special_tokens_map.get("pad_token")

unk_token = tokenizer.special_tokens_map.get("unk_token")
pad_token = tokenizer.special_tokens_map.get("pad_token")
# Add the cleaned vocabulary to the tokenizer.
backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)

# Post process the embeddings by applying PCA and Zipf weighting.
embeddings = _post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
# Convert tokens to IDs
token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)

embeddings = create_embeddings(
tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
)

# Post process the embeddings by applying PCA and Zipf weighting.
embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
# Quantize the embeddings.
embeddings = quantize_embeddings(embeddings, quantize_to)

@@ -227,7 +225,10 @@

"""
model: PreTrainedModel = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
tokenizer = cast(
PreTrainedTokenizerFast,
AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code, use_fast=True),
)

return distill_from_model(
model=model,
@@ -241,97 +242,3 @@
quantize_to=quantize_to,
use_subword=use_subword,
)


def _post_process_embeddings(
embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
) -> np.ndarray:
"""Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law."""
if pca_dims is not None:
if pca_dims == "auto":
pca_dims = embeddings.shape[1]
if pca_dims > embeddings.shape[1]:
logger.warning(
f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
"Applying PCA, but not reducing dimensionality. Is this is not desired, please set `pca_dims` to None. "
"Applying PCA will probably improve performance, so consider just leaving it."
)
pca_dims = embeddings.shape[1]
if pca_dims >= embeddings.shape[0]:
logger.warning(
f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
)
elif pca_dims <= embeddings.shape[1]:
if isinstance(pca_dims, float):
logger.info(f"Applying PCA with {pca_dims} explained variance.")
else:
logger.info(f"Applying PCA with n_components {pca_dims}")

orig_dims = embeddings.shape[1]
p = PCA(n_components=pca_dims, svd_solver="full")
embeddings = p.fit_transform(embeddings)

if embeddings.shape[1] < orig_dims:
explained_variance_ratio = np.sum(p.explained_variance_ratio_)
explained_variance = np.sum(p.explained_variance_)
logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
logger.info(f"Explained variance: {explained_variance:.3f}.")

if sif_coefficient is not None:
logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
proba = inv_rank / np.sum(inv_rank)
embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None]

return embeddings


def _clean_vocabulary(tokenizer: Tokenizer, vocabulary: list[str], added_tokens: list[str]) -> list[str]:
"""Cleans a vocabulary by removing duplicates and tokens that were already in the vocabulary."""
added_tokens_set = set(added_tokens)
seen_tokens = set()
cleaned_vocabulary = []
n_empty = 0
n_duplicates = 0
n_multiword = 0
for token in vocabulary:
normalizer = tokenizer.normalizer
if normalizer is not None:
token = normalizer.normalize_str(token)

if not token:
n_empty += 1
continue

pre_tokenizer = tokenizer.pre_tokenizer
# We need to check whether the pretokenized token is a single word or not.
if pre_tokenizer is not None:
pretokenized_tokens = pre_tokenizer.pre_tokenize_str(token)
if len(pretokenized_tokens) != 1:
n_multiword += 1
continue
new_token = pretokenized_tokens[-1][0]
else:
new_token = token

# We need to check whether the pretokenized token is in the vocabulary.
# But we need to return the original token, because that will be tokenized
# again by the tokenizer during featurization.
if new_token in seen_tokens or new_token in added_tokens_set:
n_duplicates += 1
continue

# Add the possibly pretokenized token to _seen_
seen_tokens.add(new_token)
# Add the original string to the vocabulary.
cleaned_vocabulary.append(token)

if n_duplicates:
logger.warning(f"Removed {n_duplicates} duplicate tokens.")
if n_empty:
logger.warning(f"Removed {n_empty} empty tokens.")
if n_multiword:
logger.warning(f"Removed {n_multiword} multiword tokens.")

return cleaned_vocabulary
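
For anyone reviewing the refactor end to end, here is a minimal sketch of the pipeline that `distill_from_model` now follows, using the function names and keyword arguments visible in the hunks above. The model name, example vocabulary, device, PCA size, and quantization target are illustrative placeholders, not part of this PR:

import numpy as np
from transformers import AutoModel, AutoTokenizer

from model2vec.distill.inference import create_embeddings, post_process_embeddings
from model2vec.quantization import quantize_embeddings
from model2vec.tokenizer import (
    clean_and_create_vocabulary,
    replace_vocabulary,
    turn_tokens_into_ids,
)

# Placeholder model; any transformers encoder with a fast tokenizer should work.
model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5", use_fast=True)

# 1. Clean the user-supplied vocabulary and merge it with the tokenizer's internal vocabulary.
all_tokens, backend_tokenizer = clean_and_create_vocabulary(
    tokenizer, ["static embeddings", "example supertoken"], token_remove_regex=None
)

# 2. Swap the cleaned vocabulary into the backend tokenizer.
unk_token = tokenizer.special_tokens_map.get("unk_token")
pad_token = tokenizer.special_tokens_map.get("pad_token")
backend_tokenizer = replace_vocabulary(
    backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token
)

# 3. Convert tokens to ids and run the forward passes (assumes the tokenizer defines a pad token).
token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)
embeddings = create_embeddings(
    tokenized=token_ids, model=model, device="cpu", pad_token_id=tokenizer.get_vocab()[pad_token]
)

# 4. PCA + SIF weighting, then optional quantization ("float16" as a string per the docstring;
#    the exact accepted values depend on the DType enum).
embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims=256, sif_coefficient=1e-4)
embeddings = quantize_embeddings(embeddings, "float16")
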
109 changes: 59 additions & 50 deletions model2vec/distill/inference.py
@@ -3,23 +3,23 @@

import inspect
import logging
import re
from pathlib import Path
from typing import Protocol, Union
from typing import Literal, Protocol, Union

import numpy as np
import torch
from sklearn.decomposition import PCA
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizerFast
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

from model2vec.distill.utils import Token, filter_vocabulary_by_regex

logger = logging.getLogger(__name__)


PathLike = Union[Path, str]
PCADimType = Union[int, None, float, Literal["auto"]]


_DEFAULT_BATCH_SIZE = 256

@@ -30,63 +30,27 @@ class ModulewithWeights(Protocol):

def create_embeddings(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerFast,
tokens: list[str],
tokenized: list[list[int]],
device: str,
token_remove_regex: re.Pattern | None,
) -> tuple[list[Token], np.ndarray]:
pad_token_id: int,
) -> np.ndarray:
"""
Create output embeddings for a bunch of tokens using a pretrained model.

It does a forward pass for all token id sequences passed in `tokenized`.

:param model: The model to use.
This should be a transformers model.
:param tokenizer: The tokenizer to use.
:param tokens: The tokens to use.
:param tokenized: All tokenized tokens.
:param device: The torch device to use.
:param token_remove_regex: A regex pattern to remove tokens from the vocabulary.
:return: The tokens and output embeddings.
:param pad_token_id: The pad token id. Used to pad sequences.
:return: The output embeddings.
"""
model = model.to(device)

out_weights: np.ndarray
intermediate_weights: list[np.ndarray] = []

out_tokens: list[Token] = []
tokenized: list[torch.Tensor] = []
pad_token = tokenizer.special_tokens_map.get("pad_token")
# We need to use the pad token id for padding below.
pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
unk_token = tokenizer.special_tokens_map.get("unk_token")

# Empty set if no pad or unk token is set.
tokens_to_keep = {pad_token, unk_token} - {None}

if token_remove_regex is not None:
# Sort the vocabulary by id, important for zipf.
sorted_vocab = sorted(tokenizer.get_vocab().items(), key=lambda x: x[1])
id_list = filter_vocabulary_by_regex(token_remove_regex, sorted_vocab)
else:
# If the token remove regex is None, just use all tokens.
id_list = list(range(len(tokenizer.get_vocab())))

added_tokens_ids = [id for token, id in tokenizer.added_tokens_encoder.items() if token not in tokens_to_keep]
ids = torch.Tensor(sorted(set(id_list) - set(added_tokens_ids))).long()

if ids is not None:
dummy_encoding = tokenizer.encode("A")
bos_token_id, eos_token_id = dummy_encoding[0], dummy_encoding[-1]

bos = torch.full([len(ids)], fill_value=bos_token_id)
eos = torch.full([len(ids)], fill_value=eos_token_id)

tokenized.extend(torch.stack([bos, ids, eos], dim=1))
subword_tokens = [Token(x, True) for x in tokenizer.convert_ids_to_tokens(ids.tolist())]
out_tokens.extend(subword_tokens)

tokenized.extend([tokenizer.encode_plus(token, return_tensors="pt")["input_ids"][0] for token in tokens])

# Add token_type_ids only if the model supports it
add_token_type_ids = "token_type_ids" in inspect.getfullargspec(model.forward).args

@@ -98,7 +62,7 @@ def create_embeddings(
pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens")

for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
batch = sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]
batch = [torch.Tensor(x).long() for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]]

encoded = {}
encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
@@ -113,10 +77,11 @@

# Sort the output back to the original order
intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
out_tokens.extend([Token(x, False) for x in tokens])
out_weights = np.stack(intermediate_weights)

return out_tokens, out_weights
out_weights = np.nan_to_num(out_weights)

return out_weights


@torch.no_grad()
@@ -147,3 +112,47 @@ def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.
result = torch.bmm(mask[:, None, :].float(), out).squeeze(1)

return result


def post_process_embeddings(
embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
) -> np.ndarray:
"""Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law."""
if pca_dims is not None:
if pca_dims == "auto":
pca_dims = embeddings.shape[1]
if pca_dims > embeddings.shape[1]:
logger.warning(
f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
"Applying PCA, but not reducing dimensionality. Is this is not desired, please set `pca_dims` to None. "
"Applying PCA will probably improve performance, so consider just leaving it."
)
pca_dims = embeddings.shape[1]
if pca_dims >= embeddings.shape[0]:
logger.warning(
f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
)
elif pca_dims <= embeddings.shape[1]:
if isinstance(pca_dims, float):
logger.info(f"Applying PCA with {pca_dims} explained variance.")
else:
logger.info(f"Applying PCA with n_components {pca_dims}")

orig_dims = embeddings.shape[1]
p = PCA(n_components=pca_dims, svd_solver="full")
embeddings = p.fit_transform(embeddings)

if embeddings.shape[1] < orig_dims:
explained_variance_ratio = np.sum(p.explained_variance_ratio_)
explained_variance = np.sum(p.explained_variance_)
logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
logger.info(f"Explained variance: {explained_variance:.3f}.")

if sif_coefficient is not None:
logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
proba = inv_rank / np.sum(inv_rank)
embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None]

return embeddings
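
A side note on the `inv_rank` arithmetic above: with embedding rows ordered by token id, 1-indexed rank $i$, vocabulary size $V$, and $a$ the `sif_coefficient` (default $10^{-4}$), the SIF weighting in `post_process_embeddings` amounts to

$$p_i = \frac{1/(i+1)}{\sum_{j=1}^{V} 1/(j+1)}, \qquad e_i \leftarrow \frac{a}{a + p_i}\, e_i,$$

i.e. token frequencies are approximated by Zipf's law over vocabulary rank, and the more frequent (lower-rank) tokens are down-weighted the most.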