From 42724d37a35fea4d116ea144b87204f224439e34 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 18 Jun 2025 09:07:58 -0700 Subject: [PATCH] Support different tokenizers --- .ci/docker/requirements.txt | 1 + .gitignore | 4 + README.md | 4 +- scripts/download_tokenizer.py | 156 +++++++++-- tests/unit_tests/test_tokenizer.py | 410 +++++++++++++++++++++++++++++ torchtitan/components/tokenizer.py | 398 ++++++++++++++++++++++++++++ 6 files changed, 945 insertions(+), 28 deletions(-) create mode 100644 tests/unit_tests/test_tokenizer.py diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index eec7e7ab5..11eae863f 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -8,3 +8,4 @@ tabulate wandb fsspec tyro +tokenizers >= 0.15.0 diff --git a/.gitignore b/.gitignore index e39990d72..65697a14b 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,11 @@ out wandb torchtitan/datasets/**/*.model + +# tokenizer models assets/**/*.model +assets/**/*.json +assets/**/*.txt torchtitan/experiments/flux/assets/* # temp files diff --git a/README.md b/README.md index c826da4b8..a194bcb7d 100644 --- a/README.md +++ b/README.md @@ -103,8 +103,8 @@ Once you have confirmed access, you can run the following command to download th ```bash # Get your HF token from https://huggingface.co/settings/tokens -# Llama 3.1 tokenizer.model -python scripts/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3.1-8B --tokenizer_path "original" --hf_token=... +# Llama 3.1 tokenizer +python scripts/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3.1-8B --hf_token=... ``` ### Start a training run diff --git a/scripts/download_tokenizer.py b/scripts/download_tokenizer.py index 0a9d1af70..a28dc4992 100644 --- a/scripts/download_tokenizer.py +++ b/scripts/download_tokenizer.py @@ -9,57 +9,161 @@ from requests.exceptions import HTTPError -def hf_download( - repo_id: str, tokenizer_path: str, local_dir: str, hf_token: Optional[str] = None +def download_hf_tokenizer_files( + repo_id: str, + local_dir: str, + hf_token: Optional[str] = None, + additional_patterns: Optional[list] = None, ) -> None: - from huggingface_hub import hf_hub_download + """ + Download relevant tokenizer files from HuggingFace Hub repository. - tokenizer_path = ( - f"{tokenizer_path}/tokenizer.model" if tokenizer_path else "tokenizer.model" - ) + This function recursively searches through the HuggingFace Hub repository + and downloads all tokenizer-related files to enable tokenizer + loading with the build_hf_tokenizer() function. - try: - hf_hub_download( - repo_id=repo_id, - filename=tokenizer_path, - local_dir=local_dir, - local_dir_use_symlinks=False, - token=hf_token, + Files downloaded: + - tokenizer.json - Modern HuggingFace tokenizers (complete definition) + - tokenizer_config.json - Tokenizer configuration and metadata + - tokenizer.model - SentencePiece model files (Llama, T5, etc.) + - vocab.txt - Plain text vocabulary files + - vocab.json - JSON vocabulary files + - merges.txt - BPE merge rules (GPT-2, RoBERTa style) + - special_tokens_map.json - Special token mappings + + Args: + repo_id (str): HuggingFace repository ID (e.g., "meta-llama/Meta-Llama-3.1-8B") + local_dir (str): Local directory to save tokenizer files. A subdirectory + named after the model will be created automatically. + hf_token (Optional[str]): HuggingFace API token for accessing private repositories. + Required for gated models like Llama. 
+ additional_patterns (Optional[list]): Additional file patterns to search for and download + from the HuggingFace Hub repository. + """ + import os + + from huggingface_hub import hf_hub_download, list_repo_files + + # Extract model name from repo_id (part after "/") + if "/" not in repo_id: + raise ValueError( + f"Invalid repo_id format: '{repo_id}'. Expected format: 'organization/model-name'" ) + model_name = repo_id.split("/")[-1].strip() + model_dir = os.path.join(local_dir, model_name) + + # Tokenizer file patterns to match (case-insensitive) + tokenizer_patterns = [ + "tokenizer.json", + "tokenizer_config.json", + "tokenizer.model", + "vocab.txt", + "vocab.json", + "merges.txt", + "special_tokens_map.json", + ] + + # Add additional files if provided + if additional_patterns: + tokenizer_patterns.extend(additional_patterns) + + def is_tokenizer_file(filename: str) -> bool: + """Check if a file is a tokenizer-related file.""" + filename_lower = filename.lower() + basename = os.path.basename(filename_lower) + + # Check exact matches + if basename in [pattern.lower() for pattern in tokenizer_patterns]: + return True + + return False + + try: + # Get list of available files in the repo + print(f"Scanning repository {repo_id} for tokenizer files...") + available_files = list_repo_files(repo_id=repo_id, token=hf_token) + + # Filter for tokenizer files + tokenizer_files_found = [f for f in available_files if is_tokenizer_file(f)] + + if not tokenizer_files_found: + print(f"Warning: No tokenizer files found in {repo_id}") + print(f"Available files: {available_files[:10]}...") + return + + print(f"Found {len(tokenizer_files_found)} tokenizer files:") + for f in tokenizer_files_found: + print(f" - {f}") + + downloaded_files = [] + for filename in tokenizer_files_found: + try: + hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=model_dir, + token=hf_token, + ) + file_path = os.path.join(model_dir, filename) + print(f"Successfully downloaded {filename} to {file_path}") + downloaded_files.append(filename) + except HTTPError as e: + if e.response.status_code == 404: + print(f"File {filename} not found, skipping...") + continue + else: + raise e + + if downloaded_files: + print( + f"\nSuccessfully downloaded {len(downloaded_files)} tokenizer files to: {model_dir}" + ) + else: + print(f"Warning: No tokenizer files could be downloaded from {repo_id}") + except HTTPError as e: if e.response.status_code == 401: print( "You need to pass a valid `--hf_token=...` to download private checkpoints." ) - else: - raise e + raise e if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.") + parser = argparse.ArgumentParser( + description="Download tokenizer files from HuggingFace Hub. " + "Automatically detects and downloads common tokenizer files (tokenizer.json, " + "tokenizer_config.json, tokenizer.model, ...) that work with Tokenizer." + ) parser.add_argument( "--repo_id", type=str, - default="meta-llama/Meta-Llama-3.1-8B", - help="Repository ID to download from. 
default to Llama-3.1-8B", + required=True, + help="Repository ID to download from (e.g., 'meta-llama/Meta-Llama-3.1-8B', 'deepseek-ai/DeepSeek-V3')", ) parser.add_argument( - "--tokenizer_path", + "--hf_token", type=str, - default="original", - help="the tokenizer.model path relative to repo_id", - ) - parser.add_argument( - "--hf_token", type=str, default=None, help="HuggingFace API token" + default=None, + help="HuggingFace API token (required for private repos)", ) parser.add_argument( "--local_dir", type=str, default="assets/tokenizer/", - help="local directory to save the tokenizer.model", + help="Local directory to save tokenizer files (default: assets/tokenizer/)", + ) + parser.add_argument( + "--additional_patterns", + type=str, + nargs="*", + default=None, + help="Additional file patterns to search for and download from the HuggingFace Hub repository", ) args = parser.parse_args() - hf_download(args.repo_id, args.tokenizer_path, args.local_dir, args.hf_token) + download_hf_tokenizer_files( + args.repo_id, args.local_dir, args.hf_token, args.additional_patterns + ) diff --git a/tests/unit_tests/test_tokenizer.py b/tests/unit_tests/test_tokenizer.py new file mode 100644 index 000000000..8efd48167 --- /dev/null +++ b/tests/unit_tests/test_tokenizer.py @@ -0,0 +1,410 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +import tempfile +import unittest + +from requests.exceptions import HTTPError + +from scripts.download_tokenizer import download_hf_tokenizer_files + +from tokenizers import Tokenizer +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + +from torchtitan.components.tokenizer import build_hf_tokenizer + + +class TestTokenizerIntegration(unittest.TestCase): + """Test integration between download_tokenizer and load_tokenizer functions.""" + + def setUp(self): + """Create a temporary directory for test files.""" + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + """Clean up temporary directory.""" + shutil.rmtree(self.temp_dir) + + def _compare_tokenizers(self, our_tokenizer, reference_tokenizer, test_repo_id): + """ + Comprehensive comparison between our tokenizer and a reference tokenizer. + Supports both tokenizers library and transformers library tokenizers. 
+ + Args: + our_tokenizer: Our HuggingFaceTokenizer instance or underlying tokenizer + reference_tokenizer: Reference tokenizer (tokenizers.Tokenizer or transformers tokenizer) + test_repo_id: Repository ID for context in error messages + """ + # Detect tokenizer type and create adapter functions + is_transformers = hasattr(reference_tokenizer, "vocab_size") and not hasattr( + reference_tokenizer, "get_vocab_size" + ) + + if is_transformers: + # Transformers tokenizer API + def get_vocab_size(tokenizer): + return len(tokenizer.get_vocab()) + + def get_vocab(tokenizer): + return tokenizer.get_vocab() + + def encode_text(tokenizer, text): + return tokenizer.encode(text) + + def decode_tokens(tokenizer, tokens): + return tokenizer.decode(tokens) + + def get_added_tokens_func(tokenizer): + # Transformers doesn't have get_added_tokens_decoder, so we'll skip this comparison + return {} + + tokenizer_type = "transformers" + else: + # Tokenizers library API + def get_vocab_size(tokenizer): + return len(tokenizer.get_vocab()) + + def get_vocab(tokenizer): + return tokenizer.get_vocab() + + def encode_text(tokenizer, text): + return tokenizer.encode(text).ids + + def decode_tokens(tokenizer, tokens): + return tokenizer.decode(tokens) + + def get_added_tokens_func(tokenizer): + return tokenizer.get_added_tokens_decoder() + + tokenizer_type = "tokenizers" + + # 1. Compare vocabulary sizes + self.assertEqual( + our_tokenizer.get_vocab_size(), + get_vocab_size(reference_tokenizer), + f"Vocabulary sizes should match for {test_repo_id} ({tokenizer_type})", + ) + + # 2. Compare vocabularies with more comprehensive sampling + our_vocab = our_tokenizer.get_vocab() + reference_vocab = get_vocab(reference_tokenizer) + + # Test common tokens + common_test_tokens = [ + "hello", + "world", + "the", + "and", + "is", + "a", + "to", + "of", + "in", + "for", + ] + for token in common_test_tokens: + if token in our_vocab and token in reference_vocab: + self.assertEqual( + our_vocab[token], + reference_vocab[token], + f"Token '{token}' should have the same ID in both tokenizers for {test_repo_id} ({tokenizer_type})", + ) + + # Test a random sample of tokens (more comprehensive than just common words) + import random + + vocab_keys = list(our_vocab.keys()) + if len(vocab_keys) > 50: + # Sample 50 random tokens for comparison + sample_tokens = random.sample(vocab_keys, 50) + for token in sample_tokens: + if token in reference_vocab: + self.assertEqual( + our_vocab[token], + reference_vocab[token], + f"Random sampled token '{token}' should have the same ID in \ +both tokenizers for {test_repo_id} ({tokenizer_type})", + ) + + # 3. 
Compare special tokens (only for tokenizers library, not transformers) + if not is_transformers: + our_added_tokens = our_tokenizer.get_added_tokens_decoder() + reference_added_tokens = get_added_tokens_func(reference_tokenizer) + + self.assertEqual( + len(our_added_tokens), + len(reference_added_tokens), + f"Number of added special tokens should match for {test_repo_id} ({tokenizer_type})", + ) + + # Compare each added token + for token_id, our_token in our_added_tokens.items(): + if token_id in reference_added_tokens: + reference_token = reference_added_tokens[token_id] + self.assertEqual( + our_token.content, + reference_token.content, + f"Special token content should match for ID {token_id} in {test_repo_id} ({tokenizer_type})", + ) + # Compare token properties if they exist + if hasattr(our_token, "special") and hasattr( + reference_token, "special" + ): + self.assertEqual( + our_token.special, + reference_token.special, + f"Special token 'special' property should match \ +for token '{our_token.content}' in {test_repo_id} ({tokenizer_type})", + ) + + # 4. Functional testing - encode/decode comparison + test_texts = [ + "Hello world!", + "This is a test.", + "The quick brown fox jumps over the lazy dog.", + "Special characters: @#$%^&*()", + "Numbers: 123456789", + "Mixed: Hello123 World!@#", + "", # Empty string + " ", # Single space + " ", # Multiple spaces + ] + + for text in test_texts: + # Compare encoding - handle different tokenizer types + if hasattr(our_tokenizer, "tokenizer"): + # Our wrapper tokenizer - returns list directly + our_tokens = our_tokenizer.encode(text) + else: + # Underlying HF tokenizer - returns object with .ids + our_encoded = our_tokenizer.encode(text) + our_tokens = ( + our_encoded.ids if hasattr(our_encoded, "ids") else our_encoded + ) + + reference_tokens = encode_text(reference_tokenizer, text) + + self.assertEqual( + our_tokens, + reference_tokens, + f"Encoded tokens should match for text '{text}' in {test_repo_id} ({tokenizer_type})", + ) + + # Compare decoding: + # for transformers-Tokenizers, skip_special_tokens=False by default + # for tokenizers library, skip_special_tokens=True by default + skip_special_tokens = not is_transformers + our_decoded = our_tokenizer.decode( + our_tokens, skip_special_tokens=skip_special_tokens + ) + reference_decoded = decode_tokens(reference_tokenizer, reference_tokens) + + self.assertEqual( + our_decoded, + reference_decoded, + f"Decoded text should match for '{text}' in {test_repo_id} ({tokenizer_type})", + ) + + # 5. Edge case testing + edge_cases = [ + "🚀🌟✨", # Emojis + "café naïve résumé", # Accented characters + "こんにちは世界", # Non-Latin scripts (Japanese) + "Здравствуй мир", # Cyrillic + "\n\t\r", # Whitespace characters + "a" * 1000, # Very long repeated character + ] + + for text in edge_cases: + # Handle different tokenizer types for edge cases too + if hasattr(our_tokenizer, "tokenizer"): + our_tokens = our_tokenizer.encode(text) + else: + our_encoded = our_tokenizer.encode(text) + our_tokens = ( + our_encoded.ids if hasattr(our_encoded, "ids") else our_encoded + ) + + reference_tokens = encode_text(reference_tokenizer, text) + + self.assertEqual( + our_tokens, + reference_tokens, + f"Edge case tokens should match for text '{text[:50]}...' 
in {test_repo_id} ({tokenizer_type})", + ) + + @parametrize( + "test_repo_id", + [ + "meta-llama/Meta-Llama-3.1-8B", + "deepseek-ai/DeepSeek-V3", + # "black-forest-labs/FLUX.1-dev", TODO: load the actual tokenizer + "Qwen/Qwen2-7B", + ], + ) + def test_download_and_build_tokenizer(self, test_repo_id): + """ + Test downloading tokenizer files and loading them, comparing with official APIs. + + This test: + 1. Downloads tokenizer files using download_hf_tokenizer_files + 2. Loads tokenizer using our load_tokenizer function + 3. Compares behavior with official Tokenizer library + 4. Compares with transformers AutoTokenizer (if available) + """ + # Step 1: Download tokenizer files + try: + download_hf_tokenizer_files( + repo_id=test_repo_id, + local_dir=self.temp_dir, + ) + except HTTPError as e: + if test_repo_id == "meta-llama/Meta-Llama-3.1-8B": + self.skipTest( + f"Could not download tokenizer files for Meta-Llama-3.1-8B: {e}" + ) + else: + raise e + + # Step 2: Load tokenizer using our function + model_name = test_repo_id.split("/")[-1] + tokenizer_dir = "tokenizer" if model_name == "FLUX.1-dev" else "." + tokenizer_path = os.path.join(self.temp_dir, model_name, tokenizer_dir) + our_tokenizer = build_hf_tokenizer(tokenizer_path) + + # Step 3: Load tokenizer using official Tokenizer library (if available) + official_tokenizer = None + try: + official_tokenizer = Tokenizer.from_pretrained(test_repo_id) + except Exception as e: + print(f"Warning: Could not load official tokenizer for {test_repo_id}: {e}") + + # Step 4: Load tokenizer using transformers AutoTokenizer (if available) + transformers_tokenizer = None + try: + from transformers import AutoTokenizer + + transformers_tokenizer = AutoTokenizer.from_pretrained(test_repo_id) + except Exception as e: + print(f"Warning: Could not load AutoTokenizer for {test_repo_id}: {e}") + + # Step 5: Compare underlying tokenizer attributes (only if official tokenizer is available) + if official_tokenizer: + self._compare_tokenizers( + our_tokenizer.tokenizer, official_tokenizer, test_repo_id + ) + + # Step 6: Compare with transformers tokenizer if available + if transformers_tokenizer: + self._compare_tokenizers( + our_tokenizer, transformers_tokenizer, test_repo_id + ) + + def test_backward_comptability(self): + from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer + + # The existing tokenizer lives under assets/original/tokenizer.model + # This test ensures that the new tokenizer can load the old tokenizer + # and produce the same results + + # Get the base project directory (two levels up from test file) + base_project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + old_tokenizer_path = os.path.join( + base_project_dir, "assets", "tokenizer", "original", "tokenizer.model" + ) + + # Skip test if the old tokenizer path cannot be found + if not os.path.exists(old_tokenizer_path): + self.skipTest(f"Old tokenizer file not found at {old_tokenizer_path}") + + print(old_tokenizer_path) + old_tokenizer = TikTokenizer(old_tokenizer_path) + + # Download and load a new tokenizer for comparison (using Meta-Llama-3.1-8B) + test_repo_id = "meta-llama/Meta-Llama-3.1-8B" + try: + download_hf_tokenizer_files( + repo_id=test_repo_id, + local_dir=self.temp_dir, + ) + + # Load the new tokenizer + model_name = test_repo_id.split("/")[-1] + new_tokenizer_path = os.path.join(self.temp_dir, model_name) + new_tokenizer = build_hf_tokenizer(new_tokenizer_path) + + # Compare encoding and decoding functionality only (TikTokenizer doesn't support 
vocab operations) + test_texts = [ + "Hello world!", + "This is a test.", + "The quick brown fox jumps over the lazy dog.", + "Special characters: @#$%^&*()", + "Numbers: 123456789", + "Mixed: Hello123 World!@#", + "", # Empty string + " ", # Single space + " ", # Multiple spaces + ] + + for text in test_texts: + # Encode with both tokenizers + # TikTokenizer requires bos and eos parameters + old_tokens = old_tokenizer.encode(text, bos=True, eos=False) + # HuggingFaceTokenizer has optional add_bos and add_eos parameters + new_tokens = new_tokenizer.encode(text) + + self.assertEqual( + old_tokens, + new_tokens, + f"Encoded tokens should match for text '{text}' in backward compatibility test", + ) + + # Test decoding + old_decoded = old_tokenizer.decode(old_tokens) + new_decoded = new_tokenizer.decode( + new_tokens, skip_special_tokens=False + ) + + self.assertEqual( + old_decoded, + new_decoded, + f"Decoded text should match for '{text}' in backward compatibility test", + ) + + # Test edge cases + edge_cases = [ + "🚀🌟✨", # Emojis + "café naïve résumé", # Accented characters + "こんにちは世界", # Non-Latin scripts (Japanese) + "Здравствуй мир", # Cyrillic + "\n\t\r", # Whitespace characters + "a" + * 100, # Long repeated character (reduced from 1000 to avoid tiktoken limits) + ] + + for text in edge_cases: + old_tokens = old_tokenizer.encode(text, bos=True, eos=False) + new_tokens = new_tokenizer.encode(text) + + self.assertEqual( + old_tokens, + new_tokens, + f"Edge case tokens should match for text '{text[:50]}...' in backward compatibility test", + ) + + except HTTPError as e: + self.skipTest(f"Could not download new tokenizer for comparison: {e}") + + +instantiate_parametrized_tests(TestTokenizerIntegration) + +if __name__ == "__main__": + unittest.main() diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index 47e8179d2..def7594ae 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -5,7 +5,13 @@ # LICENSE file in the root directory of this source tree. +import json +import os from abc import ABC, abstractmethod +from typing import Any, Optional + +from tokenizers import AddedToken, Tokenizer as HfTokenizer +from typing_extensions import override class Tokenizer(ABC): @@ -25,3 +31,395 @@ def decode(self, *args, **kwargs) -> str: @property def n_words(self) -> int: return self._n_words + + +class HuggingFaceTokenizer(Tokenizer): + """ + A tokenizer wrapper that handles BOS/EOS token inference and encoding. + + This class loads tokenizer files and automatically infers BOS/EOS tokens from + a configuration file (tokenizer_config.json). It provides an encode method that adds + BOS/EOS tokens based on whether the underlying tokenizer adds them automatically. 
+ + Args: + tokenizer_path (str): Path to directory containing tokenizer files + """ + + def __init__( + self, + tokenizer_path: str, + ): + self.tokenizer_path = tokenizer_path + + # Initialize BOS/EOS token attributes (frequently used) + self.bos_id = None + self.eos_id = None + self.bos_token = None + self.eos_token = None + + # Load the underlying tokenizer + self.tokenizer = self._load_tokenizer_from_path(tokenizer_path) + + # Load configuration files + self.config = self._load_config( + os.path.join(tokenizer_path, "tokenizer_config.json") + ) + + # Infer special tokens and adding BOS/EOS behavior + self._infer_special_tokens() + self._infer_should_add_bos_eos() + + def _load_config(self, config_path: str) -> Optional[dict]: + """Load configuration from JSON file if it exists.""" + if os.path.exists(config_path): + with open(config_path, "r") as f: + return json.load(f) + return None + + def _load_tokenizer_from_path(self, tokenizer_path: str) -> HfTokenizer: + """Load tokenizer from various file formats.""" + if not os.path.exists(tokenizer_path): + raise FileNotFoundError(f"Tokenizer path '{tokenizer_path}' does not exist") + + # Define paths for different tokenizer file types + tokenizer_json_path = os.path.join(tokenizer_path, "tokenizer.json") + vocab_txt_path = os.path.join(tokenizer_path, "vocab.txt") + vocab_json_path = os.path.join(tokenizer_path, "vocab.json") + merges_txt_path = os.path.join(tokenizer_path, "merges.txt") + + try: + # Strategy 1: Load from tokenizer.json (preferred for modern tokenizers) + if os.path.exists(tokenizer_json_path): + print("Loading tokenizer from tokenizer.json") + return HfTokenizer.from_file(tokenizer_json_path) + # Strategy 2: Load from vocab files (with or without merges.txt) + elif os.path.exists(vocab_json_path) or os.path.exists(vocab_txt_path): + # Load vocabulary + if os.path.exists(vocab_json_path): + print("Loading vocabulary from vocab.json") + with open(vocab_json_path, "r") as f: + vocab = json.load(f) + vocab_source = "vocab.json" + else: + print("Loading vocabulary from vocab.txt") + vocab = {} + with open(vocab_txt_path, "r") as f: + for i, line in enumerate(f): + token = line.strip() + if token: + vocab[token] = i + vocab_source = "vocab.txt" + + # Strategy 2a: Use BPE if merges.txt exists + if os.path.exists(merges_txt_path): + print(f"Loading BPE tokenizer from {vocab_source} + merges.txt") + from tokenizers import decoders, pre_tokenizers, processors + from tokenizers.models import BPE + + # Load merges from file and convert to tuples + merges = [] + with open(merges_txt_path, "r") as f: + for line in f: + line = line.strip() + if line and not line.startswith( + "#" + ): # Skip comments and empty lines + parts = line.split() + if len(parts) >= 2: + merges.append((parts[0], parts[1])) + + # Create BPE model + bpe_model = BPE(vocab=vocab, merges=merges) + tokenizer = HfTokenizer(bpe_model) + + # Configure GPT-2 style components for proper space handling + tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel( + add_prefix_space=False + ) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) + + return tokenizer + + # Strategy 2b: Use WordLevel if no merges.txt + else: + print(f"Loading WordLevel tokenizer from {vocab_source}") + from tokenizers.models import WordLevel + + word_level_model = WordLevel(vocab=vocab, unk_token="[UNK]") + return HfTokenizer(word_level_model) + + else: + # List available files for debugging + available_files = [ + f + for f in 
os.listdir(tokenizer_path) + if os.path.isfile(os.path.join(tokenizer_path, f)) + ] + raise FileNotFoundError( + f"No supported tokenizer files found in '{tokenizer_path}'. " + f"Available files: {available_files}. " + "Looking for: tokenizer.json, tokenizer.model, vocab.txt+merges.txt, or vocab.json+merges.txt" + ) + + except Exception as e: + if isinstance(e, FileNotFoundError): + raise e + raise Exception( + f"Failed to load tokenizer from '{tokenizer_path}': {e}" + ) from e + + def _get_token_from_config(self, config: dict[str, Any], key: str) -> Optional[str]: + """ + Parse special tokens from config that can be either strings or dicts. + HF tokens are stored as either {'bos_token': ''} or {'bos_token': {'content': '', ...}}. + """ + token = config.get(key) + if isinstance(token, dict): + if "content" not in token: + raise ValueError(f"Could not parse {key} from config") + token = token["content"] + elif token is not None and not isinstance(token, str): + raise ValueError( + f"Could not parse {key} from config - expected string or dict" + ) + return token + + def _process_special_token( + self, token_str: str, token_config: dict, token_id: Optional[int] = None + ) -> AddedToken: + """ + Process a special token and update BOS/EOS attributes if applicable. + + Args: + token_str: The token string content + token_config: Token configuration dictionary + token_id: Optional explicit token ID (for added_tokens_decoder) + + Returns: + AddedToken object to be added to the tokenizer + """ + # Get reference BOS/EOS tokens from config for comparison + config_bos_token = ( + self._get_token_from_config(self.config, "bos_token") + if self.config + else None + ) + config_eos_token = ( + self._get_token_from_config(self.config, "eos_token") + if self.config + else None + ) + + # Store BOS/EOS tokens as class attributes if they match + if token_str == config_bos_token: + self.bos_token = token_str + self.bos_id = ( + token_id + if token_id is not None + else self.tokenizer.token_to_id(token_str) + ) + elif token_str == config_eos_token: + self.eos_token = token_str + self.eos_id = ( + token_id + if token_id is not None + else self.tokenizer.token_to_id(token_str) + ) + + # Create AddedToken object based on config format + if isinstance(token_config, dict): + if token_config.get("__type") == "AddedToken" or "content" in token_config: + # Handle both AddedToken format and added_tokens_decoder format + return AddedToken( + content=token_str, + single_word=token_config.get("single_word", False), + lstrip=token_config.get("lstrip", False), + rstrip=token_config.get("rstrip", False), + normalized=token_config.get("normalized", True), + special=token_config.get("special", True), + ) + + # Fallback to simple special token + return AddedToken(content=token_str, special=True) + + def _infer_special_tokens(self): + """ + Read special tokens from config and add them to the underlying tokenizer. + Store BOS/EOS tokens as class attributes since they are frequently used. + + This method handles multiple token configuration formats: + 1. Standard top-level keys (bos_token, eos_token, etc.) + 2. 
added_tokens_decoder dictionary (used by models like Llama 3.1) + """ + standard_keys = [ + "bos_token", + "eos_token", + "pad_token", + "unk_token", + "sep_token", + "cls_token", + "mask_token", + ] + + # List to collect AddedToken objects for updating the underlying tokenizer + added_tokens_to_add = [] + + if not self.config: + return + + # Process standard top-level token keys + for key in standard_keys: + token_config = self.config.get(key) + if token_config is not None: + token_str = self._get_token_from_config(self.config, key) + if token_str is not None: + added_token = self._process_special_token(token_str, token_config) + added_tokens_to_add.append(added_token) + + # Process added_tokens_decoder (comprehensive special token definitions) + added_tokens_decoder = self.config.get("added_tokens_decoder", {}) + for token_id_str, token_config in added_tokens_decoder.items(): + if isinstance(token_config, dict) and "content" in token_config: + token_str = token_config["content"] + token_id = int(token_id_str) + added_token = self._process_special_token( + token_str, token_config, token_id + ) + added_tokens_to_add.append(added_token) + + # Update the underlying tokenizer with special tokens + if added_tokens_to_add: + self.tokenizer.add_special_tokens(added_tokens_to_add) + + # Update BOS/EOS token IDs after adding to tokenizer (in case they changed) + if self.bos_token: + self.bos_id = self.tokenizer.token_to_id(self.bos_token) + if self.eos_token: + self.eos_id = self.tokenizer.token_to_id(self.eos_token) + + def _infer_should_add_bos_eos(self): + """ + Determine if we should add BOS/EOS tokens based on config settings. + If config explicitly specifies add_bos_token/add_eos_token, follow that. + Otherwise, determine if the underlying tokenizer automatically adds them. + """ + self.default_add_bos = False + self.default_add_eos = False + self.hf_adds_bos = False + self.hf_adds_eos = False + + # First, determine if underlying tokenizer auto-adds BOS/EOS tokens empirically + encoded_empty_str = self.tokenizer.encode("").ids + if self.bos_id is not None and self.bos_id in encoded_empty_str: + self.hf_adds_bos = True + if self.eos_id is not None and self.eos_id in encoded_empty_str: + self.hf_adds_eos = True + + # Check tokenizer_config.json for explicit settings - these override empirical detection + if self.config: + config_add_bos = self.config.get("add_bos_token") + config_add_eos = self.config.get("add_eos_token") + if config_add_bos is not None: + self.default_add_bos = bool(config_add_bos) + if config_add_eos is not None: + self.default_add_eos = bool(config_add_eos) + + def encode(self, *args, **kwargs) -> list[int]: + """ + Encode text into token IDs with BOS/EOS handling. 
+ + Args: + text (str): The text to encode + add_bos (bool): Whether to add BOS token (if not already added by tokenizer) + add_eos (bool): Whether to add EOS token (if not already added by tokenizer) + + Returns: + list[int]: List of token IDs + """ + # Extract arguments + if len(args) >= 1: + text = args[0] + else: + text = kwargs.get("text", "") + + add_bos = kwargs.get("add_bos", self.default_add_bos) + add_eos = kwargs.get("add_eos", self.default_add_eos) + + # Get base token IDs from the underlying tokenizer + token_ids = self.tokenizer.encode(text).ids + + # Add BOS token if requested and not already added by tokenizer + if not self.hf_adds_bos and add_bos: + if self.bos_id is not None: + token_ids.insert(0, self.bos_id) + + # Add EOS token if requested and not already added by tokenizer + if not self.hf_adds_eos and add_eos: + if self.eos_id is not None: + token_ids.append(self.eos_id) + + return token_ids + + @override + def decode(self, *args, **kwargs) -> str: + """ + Decode token IDs back to text. + + Args: + token_ids (list[int]): List of token IDs to decode + **kwargs: Additional arguments passed to the underlying tokenizer's decode method + (e.g., skip_special_tokens) + + Returns: + str: Decoded text + """ + # Extract token_ids from arguments + if len(args) >= 1: + token_ids = args[0] + # Pass through remaining kwargs + return self.tokenizer.decode(token_ids, **kwargs) + else: + token_ids = kwargs.pop("token_ids", []) + # Pass through remaining kwargs after removing token_ids + return self.tokenizer.decode(token_ids, **kwargs) + + @property + def vocab_size(self) -> int: + """Get the vocabulary size.""" + return len(self.tokenizer.get_vocab()) + + def get_vocab_size(self) -> int: + """Get the vocabulary size.""" + return len(self.tokenizer.get_vocab()) + + def get_vocab(self) -> dict[str, int]: + """Get the vocabulary as a dictionary.""" + return self.tokenizer.get_vocab() + + def token_to_id(self, token: str) -> Optional[int]: + """Convert token to ID.""" + return self.tokenizer.token_to_id(token) + + def id_to_token(self, token_id: int) -> Optional[str]: + """Convert ID to token.""" + return self.tokenizer.id_to_token(token_id) + + +def build_hf_tokenizer(tokenizer_path: str) -> HuggingFaceTokenizer: + """ + Builds a HuggingFaceTokenizer from the specified path. + + This function creates a HuggingFaceTokenizer instance that handles BOS/EOS token + inference and intelligent encoding. The tokenizer automatically detects and loads + from various file formats and infers special token behavior. + + Args: + tokenizer_path (str): Path to the directory containing tokenizer files. + Should contain one or more of the supported file types. + + Returns: + tokenizer (HuggingFaceTokenizer): Loaded tokenizer instance with intelligent BOS/EOS handling + """ + tokenizer = HuggingFaceTokenizer(tokenizer_path) + return tokenizer
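---

Reviewer note: a minimal end-to-end usage sketch of the two pieces this patch adds (the download helper and `build_hf_tokenizer`). The repo id, the `HF_TOKEN` environment variable, and the output directory are illustrative placeholders, not part of the patch; the download lands files under `<local_dir>/<model-name>/`, which is the path handed to `build_hf_tokenizer`.

```python
# Illustrative only: exercises download_hf_tokenizer_files + build_hf_tokenizer
# from this patch. Repo id, token source, and directories are placeholders.
import os

from scripts.download_tokenizer import download_hf_tokenizer_files
from torchtitan.components.tokenizer import build_hf_tokenizer

repo_id = "meta-llama/Meta-Llama-3.1-8B"  # any HF repo that ships tokenizer files
local_dir = "assets/tokenizer/"           # files are saved to <local_dir>/<model-name>/

# Downloads whichever of tokenizer.json, tokenizer_config.json, tokenizer.model,
# vocab.*/merges.txt, special_tokens_map.json exist in the repo.
download_hf_tokenizer_files(
    repo_id=repo_id,
    local_dir=local_dir,
    hf_token=os.environ.get("HF_TOKEN"),  # required for gated repos like Llama
)

tokenizer = build_hf_tokenizer(os.path.join(local_dir, repo_id.split("/")[-1]))

ids = tokenizer.encode("Hello world!", add_bos=True, add_eos=False)
print(ids)
print(tokenizer.decode(ids, skip_special_tokens=False))
print(tokenizer.get_vocab_size())
```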
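The part of `HuggingFaceTokenizer` most worth a careful look is the guard against double-inserting BOS/EOS: `_infer_should_add_bos_eos` encodes the empty string once to learn whether the underlying `tokenizers` backend already emits those ids (via its post-processor), and `encode` only prepends/appends when it does not. Below is a sketch of that check in isolation; the tokenizer file path and the Llama-3-style special-token names are assumptions for illustration, not values hard-coded by the patch.

```python
# Sketch of the empirical BOS/EOS detection mirrored by HuggingFaceTokenizer;
# the file path and the special-token names are hypothetical stand-ins.
from tokenizers import Tokenizer

tok = Tokenizer.from_file("assets/tokenizer/Meta-Llama-3.1-8B/tokenizer.json")
bos_id = tok.token_to_id("<|begin_of_text|>")
eos_id = tok.token_to_id("<|end_of_text|>")

empty_ids = tok.encode("").ids  # what the backend's post-processor emits for ""
hf_adds_bos = bos_id is not None and bos_id in empty_ids
hf_adds_eos = eos_id is not None and eos_id in empty_ids

ids = tok.encode("Hello world!").ids
if not hf_adds_bos and bos_id is not None:  # add BOS only if the backend did not
    ids.insert(0, bos_id)
if not hf_adds_eos and eos_id is not None:  # likewise for EOS
    ids.append(eos_id)
```

This is also why the backward-compatibility test can compare `TikTokenizer(...).encode(text, bos=True, eos=False)` against the new `encode(text)` directly: for the Llama 3.1 assets the BOS comes from the backend in one case and from the explicit flag in the other, but the resulting ids should agree.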