diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 4116dcb..b303736 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -177,28 +177,39 @@ public class LanguageModelConfigurationFromHub { modelName: String, hubApi: HubApi = .shared ) async throws -> Configurations { - let filesToDownload = ["config.json", "tokenizer_config.json", "tokenizer.json"] + let filesToDownload = ["config.json", "tokenizer_config.json", "chat_template.json", "tokenizer.json"] let repo = Hub.Repo(id: modelName) let downloadedModelFolder = try await hubApi.snapshot(from: repo, matching: filesToDownload) return try await loadConfig(modelFolder: downloadedModelFolder, hubApi: hubApi) } - + func loadConfig( modelFolder: URL, hubApi: HubApi = .shared ) async throws -> Configurations { - // Note tokenizerConfig may be nil (does not exist in all models) + // Load required configurations let modelConfig = try hubApi.configuration(fileURL: modelFolder.appending(path: "config.json")) - let tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json")) - let tokenizerVocab = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json")) - - let configs = Configurations( + let tokenizerData = try hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json")) + // Load tokenizer config + var tokenizerConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json")) + // Check for chat template and merge if available + if let chatTemplateConfig = try? hubApi.configuration(fileURL: modelFolder.appending(path: "chat_template.json")), + let chatTemplate = chatTemplateConfig.chatTemplate?.stringValue { + // The value of chat_template could also be an array of strings, but we're not handling that case here, since it's discouraged. + // Create or update tokenizer config with chat template + if var configDict = tokenizerConfig?.dictionary { + configDict["chat_template"] = chatTemplate + tokenizerConfig = Config(configDict) + } else { + tokenizerConfig = Config(["chat_template": chatTemplate]) + } + } + return Configurations( modelConfig: modelConfig, tokenizerConfig: tokenizerConfig, - tokenizerData: tokenizerVocab + tokenizerData: tokenizerData ) - return configs } static func fallbackTokenizerConfig(for modelType: String) -> Config? { diff --git a/Tests/TokenizersTests/ChatTemplateTests.swift b/Tests/TokenizersTests/ChatTemplateTests.swift index 13d897d..88e1843 100644 --- a/Tests/TokenizersTests/ChatTemplateTests.swift +++ b/Tests/TokenizersTests/ChatTemplateTests.swift @@ -178,6 +178,44 @@ What is the weather in Paris today?<|im_end|> XCTAssertTrue(tokenizer.hasChatTemplate) } + // Test for vision models with a vision chat template in chat_template.json + func testChatTemplateFromChatTemplateJson() async throws { + let visionMessages = [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's in this image?", + ] as [String: String], + [ + "type": "image", + "image_url": "example.jpg", + ] as [String: String], + ] as [[String: String]], + ] as [String: Any] + ] as [[String: Any]] + // Qwen 2 VL does not have a chat_template.json file. The chat template is in tokenizer_config.json. + let qwen2VLTokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2-VL-7B-Instruct-4bit") + // Qwen 2.5 VL has a chat_template.json file with a different chat template than the one in tokenizer_config.json. + let qwen2_5VLTokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2.5-VL-7B-Instruct-4bit") + let qwen2VLEncoded = try qwen2VLTokenizer.applyChatTemplate(messages: visionMessages) + let qwen2VLDecoded = qwen2VLTokenizer.decode(tokens: qwen2VLEncoded) + let qwen2_5VLEncoded = try qwen2_5VLTokenizer.applyChatTemplate(messages: visionMessages) + let qwen2_5VLDecoded = qwen2_5VLTokenizer.decode(tokens: qwen2_5VLEncoded) + let expectedOutput = """ +<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +What's in this image?<|vision_start|><|image_pad|><|vision_end|><|im_end|> +<|im_start|>assistant + +""" + XCTAssertEqual(qwen2VLEncoded, qwen2_5VLEncoded, "Encoded sequences should be equal") + XCTAssertEqual(qwen2VLDecoded, qwen2_5VLDecoded, "Decoded sequences should be equal") + XCTAssertEqual(qwen2_5VLDecoded, expectedOutput, "Decoded sequence should match expected output") + } + func testApplyTemplateError() async throws { let tokenizer = try await AutoTokenizer.from(pretrained: "google-bert/bert-base-uncased") XCTAssertFalse(tokenizer.hasChatTemplate)