Skip to content

Commit 20da1ca

Browse files
authored
Fix SmolLM3 support and add unit test to cover it (#102)
* Fix SmolLM3 support and add unit test to cover it
* Add test for Llama 3.2
* Lint
1 parent d5d8680 commit 20da1ca

File tree

4 files changed

+62
-7
lines changed

4 files changed

+62
-7
lines changed

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
3131
# Install tokenizers
3232
pip install . -v
33-
pip install pytest blobfile
33+
pip install pytest blobfile transformers>=4.53.1
3434
3535
# Run tests
3636
pytest

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
3838
# Install tokenizers
3939
${CONDA_RUN} pip install . -v
40-
${CONDA_RUN} pip install pytest blobfile
40+
${CONDA_RUN} pip install pytest blobfile transformers>=4.53.1
4141
4242
# Run tests
4343
${CONDA_RUN} pytest

src/hf_tokenizer.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,19 @@ Error HFTokenizer::load(const std::string& path) {
107107
// Set up the normalizer (optional)
108108
try {
109109
TK_LOG(Info, "Setting up normalizer...");
110-
_normalizer =
111-
NormalizerConfig().parse_json(parsed_json.at("normalizer")).create();
112-
TK_LOG(Info, "Normalizer set up");
110+
const auto& normalizer_json = parsed_json.at("normalizer");
111+
if (!normalizer_json.is_null()) {
112+
_normalizer = NormalizerConfig().parse_json(normalizer_json).create();
113+
TK_LOG(Info, "Normalizer set up");
114+
} else {
115+
TK_LOG(Info, "Normalizer field is null, skipping");
116+
}
113117
} catch (const json::out_of_range& e) {
114-
// No normalizer specified, this is optional
115-
TK_LOG(Info, "No normalizer specified");
118+
// No "normalizer" field found (the JSON key is lowercase)
119+
TK_LOG(
120+
Info,
121+
"No 'Normalizer' field found in json, out of range error: %s",
122+
e.what());
116123
}
117124

118125
// Set up the pre-tokenizer

test/test_hf_tokenizer.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# @lint-ignore-every LICENSELINT
7+
8+
"""
9+
Test script for hf tokenizers.
10+
"""
11+
12+
import unittest
13+
from pytorch_tokenizers import CppHFTokenizer
14+
from transformers import AutoTokenizer
15+
from tempfile import TemporaryDirectory
16+
17+
PROMPT = "What is the capital of France?"
18+
19+
class TestHfTokenizer(unittest.TestCase):
    """Verify that the C++ HF tokenizer (CppHFTokenizer) produces exactly the
    same token ids as the reference Hugging Face `transformers` tokenizer
    when both encode the same prompt.
    """

    def setUp(self) -> None:
        # One temporary directory per test: save_pretrained() writes the
        # tokenizer.json artifact there, which the C++ tokenizer then loads.
        self.temp_dir = TemporaryDirectory()
        # Remove the directory even if the test fails or errors out.
        self.addCleanup(self.temp_dir.cleanup)
        super().setUp()

    def _assert_cpp_matches_hf(self, model_id: str, **encode_kwargs) -> None:
        """Encode PROMPT with both tokenizers for `model_id` and compare ids.

        Args:
            model_id: Hugging Face repo id to pull the tokenizer from.
            **encode_kwargs: Extra keyword arguments forwarded verbatim to
                CppHFTokenizer.encode (e.g. bos=1 to prepend one BOS token).
        """
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # save_pretrained() returns the list of files it wrote; the last one
        # is the tokenizer.json that the C++ tokenizer consumes.
        tokenizer_path = tokenizer.save_pretrained(self.temp_dir.name)[-1]

        cpp_tokenizer = CppHFTokenizer()
        cpp_tokenizer.load(tokenizer_path)

        tokens = tokenizer.encode(PROMPT)
        cpp_tokens = cpp_tokenizer.encode(PROMPT, **encode_kwargs)
        self.assertEqual(tokens, cpp_tokens)

    def test_smolLM3(self) -> None:
        # SmolLM3's tokenizer.json appears to carry a null "normalizer"
        # field; this covers the C++ fix that skips a null normalizer
        # instead of failing to parse it.
        self._assert_cpp_matches_hf("HuggingFaceTB/SmolLM3-3B")

    def test_llama3_2_1b(self) -> None:
        # The HF encode path for Llama 3.2 prepends a BOS token, so the C++
        # tokenizer is asked for one BOS token (bos=1) to match.
        self._assert_cpp_matches_hf("unsloth/Llama-3.2-1B-Instruct", bos=1)

0 commit comments

Comments (0)