Skip to content

Partial sync of codebase #381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion tiktoken/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import functools
from concurrent.futures import ThreadPoolExecutor
from typing import AbstractSet, Collection, Literal, NoReturn, Sequence
from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence

import regex

from tiktoken import _tiktoken

if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt


class Encoding:
def __init__(
Expand Down Expand Up @@ -128,6 +132,32 @@ def encode(
text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
return self._core_bpe.encode(text, allowed_special)

def encode_to_numpy(
    self,
    text: str,
    *,
    allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
    disallowed_special: Literal["all"] | Collection[str] = "all",
) -> npt.NDArray[np.uint32]:
    """Encodes a string into tokens, returning a numpy array.

    Avoids the overhead of copying the token buffer into a Python list.
    """
    # Resolve the "all" sentinels against this encoding's special-token set.
    if allowed_special == "all":
        allowed_special = self.special_tokens_set
    if disallowed_special == "all":
        disallowed_special = self.special_tokens_set - allowed_special

    # Reject any disallowed special token occurring literally in the input.
    if disallowed_special:
        disallowed = (
            disallowed_special
            if isinstance(disallowed_special, frozenset)
            else frozenset(disallowed_special)
        )
        hit = _special_token_regex(disallowed).search(text)
        if hit is not None:
            raise_disallowed_special_token(hit.group())

    # Imported lazily so numpy is only required when this method is used.
    import numpy as np

    raw = self._core_bpe.encode_to_tiktoken_buffer(text, self.special_tokens_set)
    # Zero-copy view over the Rust-owned token buffer (one uint32 per token).
    return np.frombuffer(raw, dtype=np.uint32)

def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
"""Encodes a list of strings into tokens, in parallel, ignoring special tokens.
Expand Down Expand Up @@ -332,6 +362,10 @@ def eot_token(self) -> int:
def special_tokens_set(self) -> set[str]:
return set(self._special_tokens.keys())

def is_special_token(self, token: int) -> bool:
assert isinstance(token, int)
return token in self._special_token_values

@property
def n_vocab(self) -> int:
"""For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""
Expand Down
2 changes: 1 addition & 1 deletion tiktoken/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,5 @@ def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None)
token, rank = line.split()
ret[base64.b64decode(token)] = int(rank)
except Exception as e:
raise ValueError(f"Error parsing line {line} in {tiktoken_bpe_file}") from e
raise ValueError(f"Error parsing line {line!r} in {tiktoken_bpe_file}") from e
return ret
5 changes: 5 additions & 0 deletions tiktoken/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,25 @@
# TODO: these will likely be replaced by an API endpoint
# Model-name prefix -> encoding name, for model names that carry a variable
# suffix (dates, sizes, deployment tags) and so cannot be listed exactly.
MODEL_PREFIX_TO_ENCODING: dict[str, str] = {
    # reasoning
    "o1-": "o200k_base",
    "o3-": "o200k_base",
    # chat
    "chatgpt-4o-": "o200k_base",
    "gpt-4o-": "o200k_base",  # e.g., gpt-4o-2024-05-13
    "gpt-4-": "cl100k_base",  # e.g., gpt-4-0314, etc., plus gpt-4-32k
    "gpt-3.5-turbo-": "cl100k_base",  # e.g., gpt-3.5-turbo-0301, -0401, etc.
    "gpt-35-turbo-": "cl100k_base",  # Azure deployment name
    # fine-tuned
    "ft:gpt-4o": "o200k_base",
    "ft:gpt-4": "cl100k_base",
    "ft:gpt-3.5-turbo": "cl100k_base",
    "ft:davinci-002": "cl100k_base",
    "ft:babbage-002": "cl100k_base",
}

MODEL_TO_ENCODING: dict[str, str] = {
# reasoning
"o1": "o200k_base",
"o3": "o200k_base",
# chat
"gpt-4o": "o200k_base",
"gpt-4": "cl100k_base",
Expand Down
Loading