Commit 3ee6c35

Add support for checking hash of downloaded files before use. (#230)
We use tiktoken in various production scenarios, and the download of `.tiktoken` files (e.g., `cl100k_base.tiktoken`) sometimes gets interrupted or fails, leaving the cached file corrupted. In those cases the encoder returns incorrect results, which could be damaging to our production instances. More often, when this happens, `Encoder.encode()` throws an exception such as

```
pyo3_runtime.PanicException: no entry found for key
```

which turns out to be quite hard to track down.

To make tiktoken more robust for production use, this PR adds the `sha256` hash of each downloaded file to `openai_public.py` and augments `read_file_cached` to check that hash, when one is provided, both when the file is read from the cache and when it is downloaded directly. Errors are therefore flagged at file load time rather than when the files are used, with a more meaningful error message indicating what might have gone wrong. This also protects users of tiktoken from scenarios where a network issue or a MITM attack corrupts these files in transit.
1 parent: 9e79899 · commit: 3ee6c35
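The check itself reduces to a SHA-256 hex-digest comparison, as the `check_hash` helper in the diff below shows. As a minimal standalone sketch (not part of the commit; the local file path is hypothetical, and the digest is the `cl100k_base` value pinned below):

```python
import hashlib


def check_hash(data: bytes, expected_hash: str) -> bool:
    # Compare the SHA-256 hex digest of the payload against the pinned value.
    return hashlib.sha256(data).hexdigest() == expected_hash


# Digest for cl100k_base.tiktoken, as pinned in openai_public.py by this commit.
CL100K_SHA256 = "223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7"

# Hypothetical cached copy; a truncated or tampered download fails the check.
with open("cl100k_base.tiktoken", "rb") as f:
    assert check_hash(f.read(), CL100K_SHA256), "cached file is corrupted"
```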

File tree: 2 files changed (+34, -11 lines)


tiktoken/load.py (24 additions, 7 deletions)
```diff
@@ -6,6 +6,7 @@
 import os
 import tempfile
 import uuid
+from typing import Optional

 import requests

@@ -26,7 +27,12 @@ def read_file(blobpath: str) -> bytes:
     return resp.content


-def read_file_cached(blobpath: str) -> bytes:
+def check_hash(data: bytes, hash: str) -> bool:
+    data_hash = hashlib.sha256(data).hexdigest()
+    return data_hash == hash
+
+
+def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -45,9 +51,20 @@ def read_file_cached(blobpath: str) -> bytes:
     cache_path = os.path.join(cache_dir, cache_key)
     if os.path.exists(cache_path):
         with open(cache_path, "rb") as f:
-            return f.read()
+            data = f.read()
+        if expected_hash and not check_hash(data, expected_hash):
+            raise ValueError(
+                f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
+                f"Please delete the cache file at {cache_path} and try again."
+            )
+        return data

     contents = read_file(blobpath)
+    if expected_hash and not check_hash(contents, expected_hash):
+        raise ValueError(
+            f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
+            f"This may indicate a corrupted download. Please try again."
+        )

     try:
         os.makedirs(cache_dir, exist_ok=True)
@@ -64,7 +81,7 @@ def read_file_cached(blobpath: str) -> bytes:


 def data_gym_to_mergeable_bpe_ranks(
-    vocab_bpe_file: str, encoder_json_file: str
+    vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str] = None, encoder_json_hash: Optional[str] = None
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -79,7 +96,7 @@ def data_gym_to_mergeable_bpe_ranks(
     assert len(rank_to_intbyte) == 2**8

     # vocab_bpe contains the merges along with associated ranks
-    vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
+    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
     bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

     def decode_data_gym(value: str) -> bytes:
@@ -96,7 +113,7 @@ def decode_data_gym(value: str) -> bytes:
     # check that the encoder file matches the merges file
     # this sanity check is important since tiktoken assumes that ranks are ordered the same
     # as merge priority
-    encoder_json = json.loads(read_file_cached(encoder_json_file))
+    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
     encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
     # drop these two special tokens if present, since they're not mergeable bpe tokens
     encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -118,9 +135,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
         f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


-def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str] = None) -> dict[bytes, int]:
     # NB: do not add caching to this function
-    contents = read_file_cached(tiktoken_bpe_file)
+    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
     return {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in contents.splitlines() if line)
```
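With this change a bad file surfaces as a `ValueError` at load time instead of a `PanicException` deep inside `Encoder.encode()`. A usage sketch against the new signature (assumes network access; the all-zeros digest is deliberately wrong, purely for demonstration):

```python
from tiktoken.load import load_tiktoken_bpe

url = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"

# Correct pin: downloads (or reads the cached copy), verifies, and returns the ranks.
ranks = load_tiktoken_bpe(
    url, expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7"
)

# Wrong pin: fails fast with a descriptive error instead of yielding bad tokens later.
try:
    load_tiktoken_bpe(url, expected_hash="0" * 64)
except ValueError as e:
    print(e)  # Hash mismatch for ... data from <url> (expected 000...0). ...
```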

tiktoken_ext/openai_public.py (10 additions, 4 deletions)
```diff
@@ -11,6 +11,8 @@ def gpt2():
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
         vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
         encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
+        vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
+        encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
     )
     return {
         "name": "gpt2",
@@ -23,7 +25,8 @@ def gpt2():

 def r50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
+        expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
     )
     return {
         "name": "r50k_base",
@@ -36,7 +39,8 @@ def r50k_base():

 def p50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     return {
         "name": "p50k_base",
@@ -49,7 +53,8 @@ def p50k_base():

 def p50k_edit():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
@@ -62,7 +67,8 @@ def p50k_edit():

 def cl100k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
+        expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
     )
     special_tokens = {
         ENDOFTEXT: 100257,
```
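For reference, the pinned `expected_hash` values can be reproduced, and computed for any new encoding file, by hashing the downloaded bytes. A small sketch to that effect, not part of this commit, assuming network access:

```python
import hashlib

import requests

# Re-derive the expected_hash constants pinned in openai_public.py.
for name in ("r50k_base", "p50k_base", "cl100k_base"):
    url = f"https://openaipublic.blob.core.windows.net/encodings/{name}.tiktoken"
    resp = requests.get(url)
    resp.raise_for_status()
    print(name, hashlib.sha256(resp.content).hexdigest())
```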
