Commit 9062cac

1 parent 8a83416 commit 9062cac

File tree

2 files changed: +70 −1 lines changed

Sources/Tokenizers/Tokenizer.swift

Lines changed: 50 additions & 1 deletion
@@ -17,6 +17,7 @@ enum TokenizerError: Error {
     case malformedVocab
     case chatTemplate(String)
     case tooLong(String)
+    case mismatchedConfig(String)
 }
 
 public protocol TokenizingModel {
@@ -530,6 +531,49 @@ class T5Tokenizer : UnigramTokenizer {}
 
 let sentencePieceUnderline = "▁"
 
+// Hack for Llama tokenizers, see https://github.com/huggingface/transformers/blob/bcb841f0073fcd7a4fb88ea8064313c17dcab04a/src/transformers/models/llama/tokenization_llama_fast.py#L181
+// Return updated config, or nil
+func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?) throws -> Config? {
+
+    // If it's already a Template processor (instead of a ByteLevel one), assume it's correct
+    let postProcessor = PostProcessorFactory.fromConfig(config: processorConfig)
+    guard !(postProcessor is TemplateProcessing) else { return nil }
+
+    let addBosToken = tokenizerConfig.addBosToken?.boolValue ?? false
+    let bosToken = addedTokenAsString(tokenizerConfig.bosToken)
+    if addBosToken && bosToken == nil {
+        throw TokenizerError.mismatchedConfig("add_bos_token is True but bos_token is nil")
+    }
+
+    let addEosToken = tokenizerConfig.addEosToken?.boolValue ?? false
+    let eosToken = addedTokenAsString(tokenizerConfig.eosToken)
+    if addEosToken && eosToken == nil {
+        throw TokenizerError.mismatchedConfig("add_eos_token is True but eos_token is nil")
+    }
+
+    // alt implementation
+    var single: [[String : Any]] = []
+    if addBosToken {
+        single = single + [["SpecialToken": ["id": bosToken!, "type_id": 0]]]
+    }
+    single = single + [["Sequence": ["id": "A", "type_id": 0]]]
+    if addEosToken {
+        single = single + [["SpecialToken": ["id": eosToken!, "type_id": 0]]]
+    }
+
+    var pair: [[String : Any]] = single
+    if addBosToken {
+        pair = pair + [["SpecialToken": ["id": bosToken!, "type_id": 1]]]
+    }
+    pair = pair + [["Sequence": ["id": "B", "type_id": 1]]]
+    if addEosToken {
+        pair = pair + [["SpecialToken": ["id": eosToken!, "type_id": 1]]]
+    }
+
+    let postProcessorConfig = Config(["type": PostProcessorType.TemplateProcessing.rawValue, "single": single, "pair": pair])
+    return postProcessorConfig
+}
+
 // See https://github.com/xenova/transformers.js/blob/1a9964fb09b8f54fcbeac46dc6aae8d76795809d/src/tokenizers.js#L3203 for these exceptions
 class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
     let isLegacy: Bool
@@ -541,8 +585,13 @@ class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
             configDictionary.removeValue(forKey: "normalizer")
             configDictionary["pre_tokenizer"] = ["type": "Metaspace", "replacement": sentencePieceUnderline, "add_prefix_space": true, "prepend_scheme": "first"]
         }
-        let updatedData = Config(configDictionary)
 
+        if let postProcessorConfig = try maybeUpdatePostProcessor(tokenizerConfig: tokenizerConfig, processorConfig: tokenizerData.postProcessor) {
+            configDictionary["post_processor"] = postProcessorConfig.dictionary
+        }
+
+        let updatedData = Config(configDictionary)
         try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData)
     }
 }
+
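
For illustration, here is a minimal standalone sketch (not part of the commit) of the template configuration that maybeUpdatePostProcessor builds when add_bos_token is true and add_eos_token is false, using a hypothetical "<s>" bos token:

// Standalone sketch: the "single" and "pair" templates assembled above,
// assuming a hypothetical "<s>" bos token and no eos token.
let bosToken = "<s>"

// "single" covers one input sequence: bos token, then sequence A, all with type_id 0.
let single: [[String: Any]] = [
    ["SpecialToken": ["id": bosToken, "type_id": 0]],
    ["Sequence": ["id": "A", "type_id": 0]],
]

// "pair" covers two input sequences: the single template, then bos and sequence B with type_id 1.
let pair: [[String: Any]] = single + [
    ["SpecialToken": ["id": bosToken, "type_id": 1]],
    ["Sequence": ["id": "B", "type_id": 1]],
]

// Mirrors the shape of the dictionary passed to Config(...) in the diff above.
let templateConfig: [String: Any] = [
    "type": "TemplateProcessing",
    "single": single,
    "pair": pair,
]
print(templateConfig)

The type_id values follow the Hugging Face tokenizers TemplateProcessing convention: 0 marks the first sequence, 1 the second.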

Tests/TokenizersTests/TokenizerTests.swift

Lines changed: 20 additions & 0 deletions
@@ -120,6 +120,26 @@ class PhiSimpleTests: XCTestCase {
     }
 }
 
+class LlamaPostProcessorOverrideTests: XCTestCase {
+    /// Deepseek needs a post-processor override to add a bos token as in the reference implementation
+    func testDeepSeek() async throws {
+        guard let tokenizer = try await AutoTokenizer.from(pretrained: "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B") as? PreTrainedTokenizer else {
+            XCTFail()
+            return
+        }
+        XCTAssertEqual(tokenizer.encode(text: "Who are you?"), [151646, 15191, 525, 498, 30])
+    }
+
+    /// Some Llama tokenizers already use a bos-prepending Template post-processor
+    func testLlama() async throws {
+        guard let tokenizer = try await AutoTokenizer.from(pretrained: "coreml-projects/Llama-2-7b-chat-coreml") as? PreTrainedTokenizer else {
+            XCTFail()
+            return
+        }
+        XCTAssertEqual(tokenizer.encode(text: "Who are you?"), [1, 11644, 526, 366, 29973])
+    }
+}
+
 class BertDiacriticsTests: XCTestCase {
     func testBertCased() async throws {
         guard let tokenizer = try await AutoTokenizer.from(pretrained: "distilbert/distilbert-base-multilingual-cased") as? PreTrainedTokenizer else {
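
As a quick manual check outside of XCTest, a sketch along the lines of testDeepSeek above (the model id and the expected bos id 151646 are taken from that test; an async context is assumed):

import Tokenizers

// Sketch: after the post-processor override, encoding should start with the bos id.
func checkBosIsPrepended() async throws {
    guard let tokenizer = try await AutoTokenizer.from(pretrained: "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B") as? PreTrainedTokenizer else {
        fatalError("expected a PreTrainedTokenizer")
    }
    let ids = tokenizer.encode(text: "Who are you?")
    // 151646 is the bos id asserted in testDeepSeek.
    assert(ids.first == 151646, "bos token should be prepended by the Template post-processor")
}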
