Skip to content

Commit be855fa

Browse files
Prefer chat_template.json for chat template (#184)
* Prefer chat_template.json for chat template
* Refinements
1 parent 1d26ce0 commit be855fa

File tree

2 files changed

+58
-9
lines changed

2 files changed

+58
-9
lines changed

Sources/Hub/Hub.swift

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -177,28 +177,39 @@ public class LanguageModelConfigurationFromHub {
177177
modelName: String,
178178
hubApi: HubApi = .shared
179179
) async throws -> Configurations {
180-
let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"]
180+
let filesToDownload = ["config.json", "tokenizer_config.json", "chat_template.json", "tokenizer.json"]
181181
let repo = Hub.Repo(id: modelName)
182182
let downloadedModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload)
183183

184184
return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi)
185185
}
186-
186+
187187
func loadConfig(
188188
modelFolder: URL,
189189
hubApi: HubApi = .shared
190190
) async throws -> Configurations {
191-
// Note tokenizerConfig may be nil (does not exist in all models)
191+
// Load required configurations
192192
let modelConfig = try hubApi.configuration(fileURL: modelFolder.appending(path: "config.json"))
193-
let tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json"))
194-
let tokenizerVocab = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json"))
195-
196-
let configs = Configurations(
193+
let tokenizerData = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json"))
194+
// Load tokenizer config
195+
var tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json"))
196+
// Check for chat template and merge if available
197+
if let chatTemplateConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "chat_template.json")),
198+
let chatTemplate = chatTemplateConfig.chatTemplate?.stringValue {
199+
// The value of chat_template could also be an array of strings, but we're not handling that case here, since it's discouraged.
200+
// Create or update tokenizer config with chat template
201+
if var configDict = tokenizerConfig?.dictionary {
202+
configDict["chat_template"] = chatTemplate
203+
tokenizerConfig = Config(configDict)
204+
} else {
205+
tokenizerConfig = Config(["chat_template": chatTemplate])
206+
}
207+
}
208+
return Configurations(
197209
modelConfig: modelConfig,
198210
tokenizerConfig: tokenizerConfig,
199-
tokenizerData: tokenizerVocab
211+
tokenizerData: tokenizerData
200212
)
201-
return configs
202213
}
203214

204215
static func fallbackTokenizerConfig(for modelType: String) -> Config? {

Tests/TokenizersTests/ChatTemplateTests.swift

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,44 @@ What is the weather in Paris today?<|im_end|>
178178
XCTAssertTrue(tokenizer.hasChatTemplate)
179179
}
180180

181+
// Test for vision models with a vision chat template in chat_template.json
182+
func testChatTemplateFromChatTemplateJson() async throws {
183+
let visionMessages = [
184+
[
185+
"role": "user",
186+
"content": [
187+
[
188+
"type": "text",
189+
"text": "What's in this image?",
190+
] as [String: String],
191+
[
192+
"type": "image",
193+
"image_url": "example.jpg",
194+
] as [String: String],
195+
] as [[String: String]],
196+
] as [String: Any]
197+
] as [[String: Any]]
198+
// Qwen 2 VL does not have a chat_template.json file. The chat template is in tokenizer_config.json.
199+
let qwen2VLTokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2-VL-7B-Instruct-4bit")
200+
// Qwen 2.5 VL has a chat_template.json file with a different chat template than the one in tokenizer_config.json.
201+
let qwen2_5VLTokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2.5-VL-7B-Instruct-4bit")
202+
let qwen2VLEncoded = try qwen2VLTokenizer.applyChatTemplate(messages: visionMessages)
203+
let qwen2VLDecoded = qwen2VLTokenizer.decode(tokens: qwen2VLEncoded)
204+
let qwen2_5VLEncoded = try qwen2_5VLTokenizer.applyChatTemplate(messages: visionMessages)
205+
let qwen2_5VLDecoded = qwen2_5VLTokenizer.decode(tokens: qwen2_5VLEncoded)
206+
let expectedOutput = """
207+
<|im_start|>system
208+
You are a helpful assistant.<|im_end|>
209+
<|im_start|>user
210+
What's in this image?<|vision_start|><|image_pad|><|vision_end|><|im_end|>
211+
<|im_start|>assistant
212+
213+
"""
214+
XCTAssertEqual(qwen2VLEncoded, qwen2_5VLEncoded, "Encoded sequences should be equal")
215+
XCTAssertEqual(qwen2VLDecoded, qwen2_5VLDecoded, "Decoded sequences should be equal")
216+
XCTAssertEqual(qwen2_5VLDecoded, expectedOutput, "Decoded sequence should match expected output")
217+
}
218+
181219
func testApplyTemplateError() async throws {
182220
let tokenizer = try await AutoTokenizer.from(pretrained: "google-bert/bert-base-uncased")
183221
XCTAssertFalse(tokenizer.hasChatTemplate)

0 commit comments

Comments
 (0)