Skip to content

Commit 4d25d20

Browse files
Support multiple chat templates per model (#134)
* Improve chat template parsing * Clean up * Improve chat template selection * Add tests for chat templates * Update Sources/Tokenizers/Tokenizer.swift Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Improve template selection * More elegant solution for chatTemplate argument * Update Sources/Tokenizers/Tokenizer.swift Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Add overload with `chatTemplate` argument of type `String` --------- Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent a7a61a2 commit 4d25d20

File tree

2 files changed

+158
-17
lines changed

2 files changed

+158
-17
lines changed

Sources/Tokenizers/Tokenizer.swift

Lines changed: 85 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ import Hub
99
import Foundation
1010
import Jinja
1111

12-
enum TokenizerError : Error {
12+
enum TokenizerError: Error {
1313
case missingConfig
1414
case missingTokenizerClassInConfig
1515
case unsupportedTokenizer(String)
1616
case missingVocab
1717
case malformedVocab
18-
18+
case chatTemplate(String)
1919
case tooLong(String)
2020
}
2121

@@ -94,6 +94,13 @@ struct TokenizerModel {
9494
}
9595
}
9696

97+
public enum ChatTemplateArgument {
98+
/// A Jinja template to use for the conversation. Normally it is not necessary to provide a template, since it will be read from the tokenizer config.
99+
case literal(String)
100+
/// For models whose tokenizer config includes multiple chat templates, the template can be specified by name. Normally this is not necessary.
101+
case name(String)
102+
}
103+
97104
public protocol Tokenizer {
98105
func tokenize(text: String) -> [String]
99106

@@ -117,15 +124,24 @@ public protocol Tokenizer {
117124
var eosTokenId: Int? { get }
118125
var unknownToken: String? { get }
119126
var unknownTokenId: Int? { get }
120-
127+
128+
/// The appropriate chat template is selected from the tokenizer config
121129
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
122-
130+
131+
/// The chat template is provided as a string literal or specified by name
132+
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
133+
134+
/// The chat template is provided as a string literal
135+
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]
136+
123137
func applyChatTemplate(
124138
messages: [[String: String]],
125-
chatTemplate: String?,
139+
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
140+
chatTemplate: ChatTemplateArgument?,
126141
addGenerationPrompt: Bool,
127142
truncation: Bool,
128-
maxLength: Int?
143+
maxLength: Int?,
144+
tools: [[String: Any]]?
129145
) throws -> [Int]
130146
}
131147

@@ -176,8 +192,6 @@ public class PreTrainedTokenizer: Tokenizer {
176192
private let tokenizerConfig: Config
177193

178194
private let cleanUpTokenizationSpaces: Bool
179-
180-
private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
181195

182196
required public init(tokenizerConfig: Config, tokenizerData: Config) throws {
183197
var addedTokens: [String : Int] = [:]
@@ -222,7 +236,7 @@ public class PreTrainedTokenizer: Tokenizer {
222236
self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder)
223237
self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true
224238
self.tokenizerConfig = tokenizerConfig
225-
239+
226240
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
227241
}
228242

@@ -316,22 +330,76 @@ public class PreTrainedTokenizer: Tokenizer {
316330
public func convertIdToToken(_ id: Int) -> String? {
317331
model.convertIdToToken(id)
318332
}
319-
333+
320334
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
321-
try applyChatTemplate(messages: messages, chatTemplate: nil, addGenerationPrompt: true, maxLength: nil)
335+
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
336+
}
337+
338+
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
339+
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
322340
}
323-
341+
342+
public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
343+
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
344+
}
345+
324346
public func applyChatTemplate(
325347
messages: [[String: String]],
326-
chatTemplate: String?,
348+
chatTemplate: ChatTemplateArgument? = nil,
327349
addGenerationPrompt: Bool = false,
328350
truncation: Bool = false,
329-
maxLength: Int?
351+
maxLength: Int? = nil,
352+
/// A list of tools (callable functions) that will be accessible to the model. If the template does not
353+
/// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
354+
/// giving the name, description and argument types for the tool. See the
355+
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
356+
/// for more information.
357+
/// Note: tool calling is not supported yet, it will be available in a future update.
358+
tools: [[String: Any]]? = nil
330359
) throws -> [Int] {
331-
let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
360+
var selectedChatTemplate: String?
361+
if let chatTemplate, case .literal(let template) = chatTemplate {
362+
// Use chat template from argument
363+
selectedChatTemplate = template
364+
} else if let valueFromConfig = tokenizerConfig.chatTemplate {
365+
if let arrayValue = valueFromConfig.arrayValue {
366+
// If the config specifies a list of chat templates, convert them to a dictionary
367+
let templateDict = Dictionary<String, String>(uniqueKeysWithValues: arrayValue.compactMap { item in
368+
guard let name = item.name?.stringValue, let template = item.template?.stringValue else {
369+
return nil
370+
}
371+
return (name, template)
372+
})
373+
if let chatTemplate, case .name(let name) = chatTemplate {
374+
// Select chat template from config by name
375+
if let matchingDictEntry = templateDict[name] {
376+
selectedChatTemplate = matchingDictEntry
377+
} else {
378+
throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config")
379+
}
380+
} else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] {
381+
// Use tool use chat template from config
382+
selectedChatTemplate = toolUseTemplate
383+
} else if let defaultChatTemplate = templateDict["default"] {
384+
// Use default chat template from config
385+
selectedChatTemplate = defaultChatTemplate
386+
}
387+
} else if let stringValue = valueFromConfig.stringValue {
388+
// Use chat template from config
389+
selectedChatTemplate = stringValue
390+
}
391+
}
392+
393+
guard let selectedChatTemplate else {
394+
throw TokenizerError.chatTemplate("No chat template was specified")
395+
}
396+
397+
let template = try Template(selectedChatTemplate)
332398
var context: [String: Any] = [
333399
"messages": messages,
334400
"add_generation_prompt": addGenerationPrompt
401+
// TODO: Add `tools` entry when support is added in Jinja
402+
// "tools": tools
335403
]
336404

337405
// TODO: maybe keep NSString here
@@ -397,15 +465,15 @@ extension AutoTokenizer {
397465

398466
return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
399467
}
400-
468+
401469
public static func from(
402470
modelFolder: URL,
403471
hubApi: HubApi = .shared
404472
) async throws -> Tokenizer {
405473
let config = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)
406474
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
407475
let tokenizerData = try await config.tokenizerData
408-
476+
409477
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
410478
}
411479
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//
2+
// ChatTemplateTests.swift
3+
// swift-transformers
4+
//
5+
// Created by Anthony DePasquale on 2/10/24.
6+
//
7+
8+
import XCTest
9+
import Tokenizers
10+
11+
class ChatTemplateTests: XCTestCase {
12+
let messages = [[
13+
"role": "user",
14+
"content": "Describe the Swift programming language.",
15+
]]
16+
17+
func testTemplateFromConfig() async throws {
18+
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
19+
let encoded = try tokenizer.applyChatTemplate(messages: messages)
20+
let encodedTarget = [32010, 4002, 29581, 278, 14156, 8720, 4086, 29889, 32007, 32001]
21+
let decoded = tokenizer.decode(tokens: encoded)
22+
let decodedTarget = "<|user|>Describe the Swift programming language.<|end|><|assistant|>"
23+
XCTAssertEqual(encoded, encodedTarget)
24+
XCTAssertEqual(decoded, decodedTarget)
25+
}
26+
27+
func testDefaultTemplateFromArrayInConfig() async throws {
28+
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
29+
let encoded = try tokenizer.applyChatTemplate(messages: messages)
30+
let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4]
31+
let decoded = tokenizer.decode(tokens: encoded)
32+
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]"
33+
XCTAssertEqual(encoded, encodedTarget)
34+
XCTAssertEqual(decoded, decodedTarget)
35+
}
36+
37+
func testTemplateFromArgumentWithEnum() async throws {
38+
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
39+
// Purposely not using the correct template for this model to verify that the template from the config is not being used
40+
let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
41+
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: .literal(mistral7BDefaultTemplate))
42+
let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962]
43+
let decoded = tokenizer.decode(tokens: encoded)
44+
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]"
45+
XCTAssertEqual(encoded, encodedTarget)
46+
XCTAssertEqual(decoded, decodedTarget)
47+
}
48+
49+
func testTemplateFromArgumentWithString() async throws {
50+
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
51+
// Purposely not using the correct template for this model to verify that the template from the config is not being used
52+
let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
53+
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate)
54+
let encodedTarget = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962]
55+
let decoded = tokenizer.decode(tokens: encoded)
56+
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]"
57+
XCTAssertEqual(encoded, encodedTarget)
58+
XCTAssertEqual(decoded, decodedTarget)
59+
}
60+
61+
func testNamedTemplateFromArgument() async throws {
62+
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
63+
// Normally it is not necessary to specify the name `default`, but I'm not aware of models with lists of templates in the config that are not `default` or `tool_use`
64+
let encoded = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: .name("default"))
65+
let encodedTarget = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4]
66+
let decoded = tokenizer.decode(tokens: encoded)
67+
let decodedTarget = "<s> [INST] Describe the Swift programming language. [/INST]"
68+
XCTAssertEqual(encoded, encodedTarget)
69+
XCTAssertEqual(decoded, decodedTarget)
70+
}
71+
72+
// TODO: Add tests for tool use template
73+
}

0 commit comments

Comments
 (0)