
Commit 4813c49

Support different tokenizers

1 parent 722f6e2 commit 4813c49

6 files changed, +872 -26 lines changed


.ci/docker/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@ tabulate
 wandb
 fsspec
 tyro
+tokenizers >= 0.15.0
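This is the only dependency change; the pinned `tokenizers` package is what consumes the downloaded `tokenizer.json` files. A minimal sketch of loading one, assuming the default `assets/tokenizer/` layout produced by the updated script (the DeepSeek path is illustrative, not part of this commit):

```python
from tokenizers import Tokenizer

# Hypothetical path: the script saves files under <local_dir>/<model-name>/.
tokenizer = Tokenizer.from_file("assets/tokenizer/DeepSeek-V3/tokenizer.json")

# Round-trip a string to sanity-check the vocabulary.
ids = tokenizer.encode("Hello, world!").ids
print(ids)
print(tokenizer.decode(ids))
```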

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -13,7 +13,11 @@ out
 wandb
 
 torchtitan/datasets/**/*.model
+
+# tokenizer models
 assets/**/*.model
+assets/**/*.json
+assets/**/*.txt
 torchtitan/experiments/flux/assets/*
 
 # temp files

README.md

Lines changed: 5 additions & 2 deletions
@@ -103,8 +103,11 @@ Once you have confirmed access, you can run the following command to download th
 ```bash
 # Get your HF token from https://huggingface.co/settings/tokens
 
-# Llama 3.1 tokenizer.model
-python scripts/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3.1-8B --tokenizer_path "original" --hf_token=...
+# Llama 3.1 tokenizer (automatically downloads original/tokenizer.model)
+python scripts/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3.1-8B --hf_token=...
+
+# DeepSeek tokenizer (automatically downloads tokenizer.json and tokenizer_config.json)
+python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3
 ```
 
 ### Start a training run
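For reference, the script writes each repo's files into a subdirectory named after the model (the part of the repo ID after "/"). A sketch of what the second command above should leave behind, assuming the default `--local_dir`:

```bash
# Illustrative, not from this commit: inspect the per-model subdirectory
# created under the default --local_dir of assets/tokenizer/.
ls assets/tokenizer/DeepSeek-V3/
# tokenizer.json  tokenizer_config.json
```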

scripts/download_tokenizer.py

Lines changed: 129 additions & 24 deletions
@@ -9,23 +9,118 @@
 from requests.exceptions import HTTPError
 
 
-def hf_download(
-    repo_id: str, tokenizer_path: str, local_dir: str, hf_token: Optional[str] = None
+def download_hf_tokenizer_files(
+    repo_id: str,
+    local_dir: str,
+    hf_token: Optional[str] = None,
+    additional_patterns: Optional[list] = None,
 ) -> None:
-    from huggingface_hub import hf_hub_download
+    """
+    Download relevant tokenizer files from HuggingFace Hub repository.
 
-    tokenizer_path = (
-        f"{tokenizer_path}/tokenizer.model" if tokenizer_path else "tokenizer.model"
-    )
+    This function recursively searches through the HuggingFace Hub repository
+    and downloads all tokenizer-related files to enable tokenizer
+    loading with the build_hf_tokenizer() function.
 
-    try:
-        hf_hub_download(
-            repo_id=repo_id,
-            filename=tokenizer_path,
-            local_dir=local_dir,
-            local_dir_use_symlinks=False,
-            token=hf_token,
+    Files downloaded:
+    - tokenizer.json - Modern HuggingFace tokenizers (complete definition)
+    - tokenizer_config.json - Tokenizer configuration and metadata
+    - tokenizer.model - SentencePiece model files (Llama, T5, etc.)
+    - vocab.txt - Plain text vocabulary files
+    - vocab.json - JSON vocabulary files
+    - merges.txt - BPE merge rules (GPT-2, RoBERTa style)
+    - special_tokens_map.json - Special token mappings
+
+    Args:
+        repo_id (str): HuggingFace repository ID (e.g., "meta-llama/Meta-Llama-3.1-8B")
+        local_dir (str): Local directory to save tokenizer files. A subdirectory
+            named after the model will be created automatically.
+        hf_token (Optional[str]): HuggingFace API token for accessing private repositories.
+            Required for gated models like Llama.
+        additional_patterns (Optional[list]): Additional file patterns to search for and download
+            from the HuggingFace Hub repository.
+    """
+    import os
+
+    from huggingface_hub import hf_hub_download, list_repo_files
+
+    # Extract model name from repo_id (part after "/")
+    if "/" not in repo_id:
+        raise ValueError(
+            f"Invalid repo_id format: '{repo_id}'. Expected format: 'organization/model-name'"
         )
+    model_name = repo_id.split("/")[-1].strip()
+    model_dir = os.path.join(local_dir, model_name)
+
+    # Tokenizer file patterns to match (case-insensitive)
+    tokenizer_patterns = [
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "tokenizer.model",
+        "vocab.txt",
+        "vocab.json",
+        "merges.txt",
+        "special_tokens_map.json",
+    ]
+
+    # Add additional files if provided
+    if additional_patterns:
+        tokenizer_patterns.extend(additional_patterns)
+
+    def is_tokenizer_file(filename: str) -> bool:
+        """Check if a file is a tokenizer-related file."""
+        filename_lower = filename.lower()
+        basename = os.path.basename(filename_lower)
+
+        # Check exact matches
+        if basename in [pattern.lower() for pattern in tokenizer_patterns]:
+            return True
+
+        return False
+
+    try:
+        # Get list of available files in the repo
+        print(f"Scanning repository {repo_id} for tokenizer files...")
+        available_files = list_repo_files(repo_id=repo_id, token=hf_token)
+
+        # Filter for tokenizer files
+        tokenizer_files_found = [f for f in available_files if is_tokenizer_file(f)]
+
+        if not tokenizer_files_found:
+            print(f"Warning: No tokenizer files found in {repo_id}")
+            print(f"Available files: {available_files[:10]}...")
+            return
+
+        print(f"Found {len(tokenizer_files_found)} tokenizer files:")
+        for f in tokenizer_files_found:
+            print(f"  - {f}")
+
+        downloaded_files = []
+        for filename in tokenizer_files_found:
+            try:
+                hf_hub_download(
+                    repo_id=repo_id,
+                    filename=filename,
+                    local_dir=model_dir,
+                    token=hf_token,
+                )
+                file_path = os.path.join(model_dir, filename)
+                print(f"Successfully downloaded {filename} to {file_path}")
+                downloaded_files.append(filename)
+            except HTTPError as e:
+                if e.response.status_code == 404:
+                    print(f"File {filename} not found, skipping...")
+                    continue
+                else:
+                    raise e
+
+        if downloaded_files:
+            print(
+                f"\nSuccessfully downloaded {len(downloaded_files)} tokenizer files to: {model_dir}"
+            )
+        else:
+            print(f"Warning: No tokenizer files could be downloaded from {repo_id}")
 
     except HTTPError as e:
         if e.response.status_code == 401:
             print(

@@ -38,28 +133,38 @@ def hf_download(
 if __name__ == "__main__":
     import argparse
 
-    parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
+    parser = argparse.ArgumentParser(
+        description="Download tokenizer files from HuggingFace Hub. "
+        "Automatically detects and downloads common tokenizer files (tokenizer.json, "
+        "tokenizer_config.json, tokenizer.model, ...) that work with Tokenizer."
+    )
     parser.add_argument(
         "--repo_id",
         type=str,
-        default="meta-llama/Meta-Llama-3.1-8B",
-        help="Repository ID to download from. default to Llama-3.1-8B",
+        required=True,
+        help="Repository ID to download from (e.g., 'meta-llama/Meta-Llama-3.1-8B', 'deepseek-ai/DeepSeek-V3')",
    )
     parser.add_argument(
-        "--tokenizer_path",
+        "--hf_token",
         type=str,
-        default="original",
-        help="the tokenizer.model path relative to repo_id",
-    )
-    parser.add_argument(
-        "--hf_token", type=str, default=None, help="HuggingFace API token"
+        default=None,
+        help="HuggingFace API token (required for private repos)",
     )
     parser.add_argument(
         "--local_dir",
         type=str,
         default="assets/tokenizer/",
-        help="local directory to save the tokenizer.model",
+        help="Local directory to save tokenizer files (default: assets/tokenizer/)",
+    )
+    parser.add_argument(
+        "--additional_patterns",
+        type=str,
+        nargs="*",
+        default=None,
+        help="Additional file patterns to search for and download from the HuggingFace Hub repository",
     )
 
     args = parser.parse_args()
-    hf_download(args.repo_id, args.tokenizer_path, args.local_dir, args.hf_token)
+    download_hf_tokenizer_files(
+        args.repo_id, args.local_dir, args.hf_token, args.additional_patterns
+    )
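For completeness, a minimal sketch of calling the new entry point from Python rather than the CLI. The import path assumes the repo root is on `sys.path`; note that despite the name, `additional_patterns` is matched against exact file basenames (case-insensitive), so a concrete filename works but a shell-style glob would not:

```python
# Sketch: programmatic use of the new downloader. The import path and the
# extra "generation_config.json" entry are illustrative, not from this commit.
from scripts.download_tokenizer import download_hf_tokenizer_files

download_hf_tokenizer_files(
    repo_id="deepseek-ai/DeepSeek-V3",
    local_dir="assets/tokenizer/",
    hf_token=None,  # only needed for gated/private repos such as Llama
    additional_patterns=["generation_config.json"],  # exact basename match
)
```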
