# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import tempfile
import unittest

from scripts.download_tokenizer import download_hf_tokenizer_files

from tokenizers import Tokenizer

from torchtitan.components.tokenizer import build_hf_tokenizer


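# NOTE: this test downloads tokenizer files from the Hugging Face Hub
# (via download_hf_tokenizer_files and Tokenizer.from_pretrained), so it
# requires network access to run.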
class TestTokenizerIntegration(unittest.TestCase):
    """Test integration between download_hf_tokenizer_files and build_hf_tokenizer."""

    def setUp(self):
        """Create a temporary directory for test files."""
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Clean up the temporary directory."""
        shutil.rmtree(self.temp_dir)

    def test_download_and_load_tokenizer_integration(self):
        """
        Test downloading tokenizer files and loading them, comparing with official APIs.

        This test:
        1. Downloads tokenizer files using download_hf_tokenizer_files
        2. Loads the tokenizer using our build_hf_tokenizer function
        3. Compares behavior with the official tokenizers library
        4. Compares with transformers AutoTokenizer (if available)
        """
        # Use a publicly accessible model for testing
        test_repo_id = "deepseek-ai/DeepSeek-V3"

        # Step 1: Download tokenizer files
        download_hf_tokenizer_files(
            repo_id=test_repo_id,
            local_dir=self.temp_dir,
            hf_token=None,  # Public model, no token needed
        )

        # Step 2: Load tokenizer using our function
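        # download_hf_tokenizer_files is expected to place the files under
        # <local_dir>/<model_name>, so rebuild that path from the repo id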
        model_name = test_repo_id.split("/")[-1]
        tokenizer_path = os.path.join(self.temp_dir, model_name)
        our_tokenizer = build_hf_tokenizer(tokenizer_path)

        # Step 3: Load tokenizer using the official tokenizers library
        official_tokenizer = Tokenizer.from_pretrained(test_repo_id)

        # Step 4: Load tokenizer using transformers AutoTokenizer (if available)
        transformers_tokenizer = None
        try:
            from transformers import AutoTokenizer

            transformers_tokenizer = AutoTokenizer.from_pretrained(test_repo_id)
        except Exception:
            pass  # Skip transformers comparison if not available

        # Step 5: Compare underlying tokenizer attributes
        # Test that our_tokenizer.tokenizer has the same attributes as official_tokenizer

        # Get the underlying tokenizer from our wrapper
        our_underlying_tokenizer = our_tokenizer.tokenizer

        # Compare key attributes that should be identical
        # Vocabulary size
        self.assertEqual(
            our_underlying_tokenizer.get_vocab_size(),
            official_tokenizer.get_vocab_size(),
            "Vocabulary sizes should match",
        )

        # Compare vocabularies (this might be large, so we'll sample some tokens)
        our_vocab = our_underlying_tokenizer.get_vocab()
        official_vocab = official_tokenizer.get_vocab()

        # Test a few common tokens to ensure vocabularies match
        common_test_tokens = ["hello", "world", "the", "and", "is", "a"]
        for token in common_test_tokens:
            if token in our_vocab and token in official_vocab:
                self.assertEqual(
                    our_vocab[token],
                    official_vocab[token],
                    f"Token '{token}' should have the same ID in both tokenizers",
                )

        # Compare special tokens if they exist
        # Get added tokens from both tokenizers
        our_added_tokens = our_underlying_tokenizer.get_added_tokens_decoder()
        official_added_tokens = official_tokenizer.get_added_tokens_decoder()

        # Compare the number of added tokens
        self.assertEqual(
            len(our_added_tokens),
            len(official_added_tokens),
            "Number of added special tokens should match",
        )

        # Compare each added token
        for token_id, our_token in our_added_tokens.items():
            if token_id in official_added_tokens:
                official_token = official_added_tokens[token_id]
                self.assertEqual(
                    our_token.content,
                    official_token.content,
                    f"Special token content should match for ID {token_id}",
                )
                # Compare token properties if they exist
                if hasattr(our_token, "special") and hasattr(official_token, "special"):
                    self.assertEqual(
                        our_token.special,
                        official_token.special,
                        f"Special token 'special' property should match for token '{our_token.content}'",
                    )

        # Step 6: Compare with transformers tokenizer if available
        if transformers_tokenizer:
            # Test text encoding/decoding with transformers tokenizer
            text = "Hello world! This is a test."

            # Get tokens from our tokenizer (using the wrapper's encode method)
            our_tokens = our_tokenizer.encode(text)
            our_decoded_text = our_tokenizer.decode(our_tokens)

            # Verify our tokenizer produces expected output
            self.assertIsInstance(our_tokens, list)
            self.assertEqual(our_decoded_text, text)

            # Get tokens from transformers tokenizer
            transformers_tokens = transformers_tokenizer.encode(text)
            transformers_decoded = transformers_tokenizer.decode(transformers_tokens)

            # Compare our tokens with transformers tokens
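            # Note: exact equality assumes both tokenizers apply the same
            # special-token handling (e.g., BOS/EOS insertion) when encoding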
            self.assertEqual(
                our_tokens,
                transformers_tokens,
                f"Tokens should match between our tokenizer and transformers tokenizer for input: '{text}'",
            )


if __name__ == "__main__":
    unittest.main()