Commit 6115ea0

Support different tokenizers
1 parent 722f6e2 commit 6115ea0

6 files changed, +601 -26 lines

6 files changed

+601
-26
lines changed

.ci/docker/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ tabulate
 wandb
 fsspec
 tyro
+tokenizers >= 0.15.0
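The new `tokenizers >= 0.15.0` dependency is what reads the JSON-based tokenizer files the download script now fetches. As a minimal sketch (not part of this commit; the file path below is a hypothetical example assuming the download script's default output directory), such a file can be loaded directly with the `tokenizers` library:

```python
# Sketch only: load a downloaded tokenizer.json with the `tokenizers` library.
from tokenizers import Tokenizer

# Hypothetical path, assuming the script was run with its defaults.
tok = Tokenizer.from_file("assets/tokenizer/DeepSeek-V3/tokenizer.json")

encoding = tok.encode("Hello world!")
print(encoding.ids)               # token IDs
print(tok.decode(encoding.ids))   # round-trip back to text
```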

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -13,7 +13,11 @@ out
 wandb
 
 torchtitan/datasets/**/*.model
+
+# tokenizer models
 assets/**/*.model
+assets/**/*.json
+assets/**/*.txt
 torchtitan/experiments/flux/assets/*
 
 # temp files

README.md

Lines changed: 5 additions & 2 deletions
@@ -103,8 +103,11 @@ Once you have confirmed access, you can run the following command to download th
 ```bash
 # Get your HF token from https://huggingface.co/settings/tokens
 
-# Llama 3.1 tokenizer.model
-python scripts/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3.1-8B --tokenizer_path "original" --hf_token=...
+# Llama 3.1 tokenizer (automatically downloads original/tokenizer.model)
+python scripts/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3.1-8B --hf_token=...
+
+# DeepSeek tokenizer (automatically downloads tokenizer.json and tokenizer_config.json)
+python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3
 ```
 
 ### Start a training run
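
With this change, both commands place their files under `assets/tokenizer/<model_name>/`; the updated script (below) derives the subdirectory name from the part of the repo ID after the slash.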

scripts/download_tokenizer.py

Lines changed: 107 additions & 24 deletions
@@ -9,23 +9,105 @@
 from requests.exceptions import HTTPError
 
 
-def hf_download(
-    repo_id: str, tokenizer_path: str, local_dir: str, hf_token: Optional[str] = None
+def hf_download_tokenizer(
+    repo_id: str, local_dir: str, hf_token: Optional[str] = None
 ) -> None:
-    from huggingface_hub import hf_hub_download
+    """
+    Download relevant tokenizer files from HuggingFace Hub repository.
 
-    tokenizer_path = (
-        f"{tokenizer_path}/tokenizer.model" if tokenizer_path else "tokenizer.model"
-    )
+    This function recursively searches through the HuggingFace Hub repository
+    and downloads all tokenizer-related files to enable tokenizer
+    loading with the load_tokenizer() function.
+
+    Files downloaded:
+    - tokenizer.json - Modern HuggingFace tokenizers (complete definition)
+    - tokenizer_config.json - Tokenizer configuration and metadata
+    - tokenizer.model - SentencePiece model files (Llama, T5, etc.)
+    - vocab.txt - Plain text vocabulary files
+    - vocab.json - JSON vocabulary files
+    - merges.txt - BPE merge rules (GPT-2, RoBERTa style)
+    - special_tokens_map.json - Special token mappings
+
+    Args:
+        repo_id (str): HuggingFace repository ID (e.g., "meta-llama/Meta-Llama-3.1-8B")
+        local_dir (str): Local directory to save tokenizer files. A subdirectory
+            named after the model will be created automatically.
+        hf_token (Optional[str]): HuggingFace API token for accessing private repositories.
+            Required for gated models like Llama.
+    """
+    import os
+
+    from huggingface_hub import hf_hub_download, list_repo_files
+
+    # Extract model name from repo_id (part after "/")
+    model_name = repo_id.split("/")[-1]
+    model_dir = os.path.join(local_dir, model_name)
+
+    # Tokenizer file patterns to match (case-insensitive)
+    tokenizer_patterns = [
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "tokenizer.model",
+        "vocab.txt",
+        "vocab.json",
+        "merges.txt",
+        "special_tokens_map.json",
+    ]
+
+    def is_tokenizer_file(filename: str) -> bool:
+        """Check if a file is a tokenizer-related file."""
+        filename_lower = filename.lower()
+        basename = os.path.basename(filename_lower)
+
+        # Check exact matches
+        if basename in [pattern.lower() for pattern in tokenizer_patterns]:
+            return True
+
+        return False
 
     try:
-        hf_hub_download(
-            repo_id=repo_id,
-            filename=tokenizer_path,
-            local_dir=local_dir,
-            local_dir_use_symlinks=False,
-            token=hf_token,
-        )
+        # Get list of available files in the repo
+        print(f"Scanning repository {repo_id} for tokenizer files...")
+        available_files = list_repo_files(repo_id=repo_id, token=hf_token)
+
+        # Filter for tokenizer files
+        tokenizer_files_found = [f for f in available_files if is_tokenizer_file(f)]
+
+        if not tokenizer_files_found:
+            print(f"Warning: No tokenizer files found in {repo_id}")
+            print(f"Available files: {available_files[:10]}...")
+            return
+
+        print(f"Found {len(tokenizer_files_found)} tokenizer files:")
+        for f in tokenizer_files_found:
+            print(f" - {f}")
+
+        downloaded_files = []
+        for filename in tokenizer_files_found:
+            try:
+                hf_hub_download(
+                    repo_id=repo_id,
+                    filename=filename,
+                    local_dir=model_dir,
+                    token=hf_token,
+                )
+                file_path = os.path.join(model_dir, filename)
+                print(f"Successfully downloaded {filename} to {file_path}")
+                downloaded_files.append(filename)
+            except HTTPError as e:
+                if e.response.status_code == 404:
+                    print(f"File {filename} not found, skipping...")
+                    continue
+                else:
+                    raise e
+
+        if downloaded_files:
+            print(
+                f"\nSuccessfully downloaded {len(downloaded_files)} tokenizer files to: {model_dir}"
+            )
+        else:
+            print(f"Warning: No tokenizer files could be downloaded from {repo_id}")
+
     except HTTPError as e:
         if e.response.status_code == 401:
             print(
@@ -38,28 +120,29 @@ def hf_download
 if __name__ == "__main__":
     import argparse
 
-    parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
+    parser = argparse.ArgumentParser(
+        description="Download tokenizer files from HuggingFace Hub. "
+        "Automatically detects and downloads common tokenizer files (tokenizer.json, "
+        "tokenizer_config.json, tokenizer.model, ...) that work with Tokenizer."
+    )
     parser.add_argument(
         "--repo_id",
         type=str,
-        default="meta-llama/Meta-Llama-3.1-8B",
-        help="Repository ID to download from. default to Llama-3.1-8B",
+        required=True,
+        help="Repository ID to download from (e.g., 'meta-llama/Meta-Llama-3.1-8B', 'deepseek-ai/DeepSeek-V3')",
     )
     parser.add_argument(
-        "--tokenizer_path",
+        "--hf_token",
         type=str,
-        default="original",
-        help="the tokenizer.model path relative to repo_id",
-    )
-    parser.add_argument(
-        "--hf_token", type=str, default=None, help="HuggingFace API token"
+        default=None,
+        help="HuggingFace API token (required for private repos)",
    )
     parser.add_argument(
         "--local_dir",
         type=str,
         default="assets/tokenizer/",
-        help="local directory to save the tokenizer.model",
+        help="Local directory to save tokenizer files (default: assets/tokenizer/)",
     )
 
     args = parser.parse_args()
-    hf_download(args.repo_id, args.tokenizer_path, args.local_dir, args.hf_token)
+    hf_download_tokenizer(args.repo_id, args.local_dir, args.hf_token)
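
For illustration, a minimal sketch of using the new function programmatically, mirroring what the unit test below does (assumes it is run from the repo root so `scripts` and `torchtitan` are importable; the repo ID and directory are example values):

```python
# Sketch only: download tokenizer files for a public model, then load them
# with torchtitan's load_tokenizer, as the unit test below does.
import os

from scripts.download_tokenizer import hf_download_tokenizer
from torchtitan.components.tokenizer import load_tokenizer

repo_id = "deepseek-ai/DeepSeek-V3"  # example public repo
local_dir = "assets/tokenizer/"      # script default

# Downloads tokenizer.json, tokenizer_config.json, etc. into
# assets/tokenizer/DeepSeek-V3/
hf_download_tokenizer(repo_id, local_dir, hf_token=None)

tokenizer = load_tokenizer(os.path.join(local_dir, repo_id.split("/")[-1]))
ids = tokenizer.encode("Hello world!")
print(ids, tokenizer.decode(ids))
```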

tests/unit_tests/test_tokenizer.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import shutil
+import tempfile
+import unittest
+
+from scripts.download_tokenizer import hf_download_tokenizer
+
+from tokenizers import Tokenizer
+
+from torchtitan.components.tokenizer import load_tokenizer
+
+
+class TestTokenizerIntegration(unittest.TestCase):
+    """Test integration between download_tokenizer and load_tokenizer functions."""
+
+    def setUp(self):
+        """Create a temporary directory for test files."""
+        self.temp_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        """Clean up temporary directory."""
+        shutil.rmtree(self.temp_dir)
+
+    def test_download_and_load_tokenizer_integration(self):
+        """
+        Test downloading tokenizer files and loading them, comparing with official APIs.
+
+        This test:
+        1. Downloads tokenizer files using hf_download_tokenizer
+        2. Loads tokenizer using our load_tokenizer function
+        3. Compares behavior with official Tokenizer library
+        4. Compares with transformers AutoTokenizer (if available)
+        """
+        # Use a smaller, accessible model for testing
+        test_repo_id = "deepseek-ai/DeepSeek-V3"
+
+        # Step 1: Download tokenizer files
+        hf_download_tokenizer(
+            repo_id=test_repo_id,
+            local_dir=self.temp_dir,
+            hf_token=None,  # Public model, no token needed
+        )
+
+        # Step 2: Load tokenizer using our function
+        model_name = test_repo_id.split("/")[-1]
+        tokenizer_path = os.path.join(self.temp_dir, model_name)
+        our_tokenizer = load_tokenizer(tokenizer_path)
+
+        # Step 3: Load tokenizer using official Tokenizer library
+        official_tokenizer = Tokenizer.from_pretrained(test_repo_id)
+
+        # Step 4: Load tokenizer using transformers AutoTokenizer (if available)
+        transformers_tokenizer = None
+        try:
+            from transformers import AutoTokenizer
+
+            transformers_tokenizer = AutoTokenizer.from_pretrained(test_repo_id)
+        except Exception:
+            pass  # Skip transformers comparison if not available
+
+        # Step 5: Compare underlying tokenizer attributes
+        # Test that our_tokenizer.tokenizer has the same attributes as official_tokenizer
+
+        # Get the underlying tokenizer from our wrapper
+        our_underlying_tokenizer = our_tokenizer.tokenizer
+
+        # Compare key attributes that should be identical
+        # Vocabulary size
+        self.assertEqual(
+            our_underlying_tokenizer.get_vocab_size(),
+            official_tokenizer.get_vocab_size(),
+            "Vocabulary sizes should match",
+        )
+
+        # Compare vocabularies (this might be large, so we'll sample some tokens)
+        our_vocab = our_underlying_tokenizer.get_vocab()
+        official_vocab = official_tokenizer.get_vocab()
+
+        # Test a few common tokens to ensure vocabularies match
+        common_test_tokens = ["hello", "world", "the", "and", "is", "a"]
+        for token in common_test_tokens:
+            if token in our_vocab and token in official_vocab:
+                self.assertEqual(
+                    our_vocab[token],
+                    official_vocab[token],
+                    f"Token '{token}' should have the same ID in both tokenizers",
+                )
+
+        # Compare special tokens if they exist
+        # Get added tokens from both tokenizers
+        our_added_tokens = our_underlying_tokenizer.get_added_tokens_decoder()
+        official_added_tokens = official_tokenizer.get_added_tokens_decoder()
+
+        # Compare the number of added tokens
+        self.assertEqual(
+            len(our_added_tokens),
+            len(official_added_tokens),
+            "Number of added special tokens should match",
+        )
+
+        # Compare each added token
+        for token_id, our_token in our_added_tokens.items():
+            if token_id in official_added_tokens:
+                official_token = official_added_tokens[token_id]
+                self.assertEqual(
+                    our_token.content,
+                    official_token.content,
+                    f"Special token content should match for ID {token_id}",
+                )
+                # Compare token properties if they exist
+                if hasattr(our_token, "special") and hasattr(official_token, "special"):
+                    self.assertEqual(
+                        our_token.special,
+                        official_token.special,
+                        f"Special token 'special' property should match for token '{our_token.content}'",
+                    )
+
+        # Step 6: Compare with transformers tokenizer if available
+        if transformers_tokenizer:
+            # Test text encoding/decoding with transformers tokenizer
+            text = "Hello world! This is a test."
+
+            # Get tokens from our tokenizer (using the wrapper's encode method)
+            our_tokens = our_tokenizer.encode(text)
+            our_decoded_text = our_tokenizer.decode(our_tokens)
+
+            # Verify our tokenizer produces expected output
+            self.assertIsInstance(our_tokens, list)
+            self.assertEqual(our_decoded_text, text)
+
+            # Get tokens from transformers tokenizer
+            transformers_tokens = transformers_tokenizer.encode(text)
+            transformers_decoded = transformers_tokenizer.decode(transformers_tokens)
+
+            # Compare our tokens with transformers tokens
+            self.assertEqual(
+                our_tokens,
+                transformers_tokens,
+                f"Tokens should match between our tokenizer and transformers tokenizer for input: '{text}'",
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
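
Note that this integration test hits the HuggingFace Hub at runtime: it downloads the DeepSeek-V3 tokenizer files into a temporary directory, and it only exercises the transformers cross-check when that package is installed.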
