Skip to content

Commit 7897a7e

Browse files
authored
Merge branch 'main' into main
2 parents 307a26c + ff81749 commit 7897a7e

File tree

5 files changed

+186
-20
lines changed

5 files changed

+186
-20
lines changed

Package.swift

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@ let package = Package(
1414
.executable(name: "hub-cli", targets: ["HubCLI"]),
1515
],
1616
dependencies: [
17-
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"),
18-
.package(url: "https://github.com/johnmai-dev/Jinja", from: "1.1.0")
17+
.package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")),
18+
.package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0"))
1919
],
2020
targets: [
2121
.executableTarget(

Sources/Models/LanguageModel.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ public extension LanguageModel {
190190
var tokenizer: Tokenizer {
191191
get async throws {
192192
guard _tokenizer == nil else { return _tokenizer! }
193-
guard let tokenizerConfig = try await tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" }
193+
guard let tokenizerConfig = try await tokenizerConfig else {
194+
throw TokenizerError.tokenizerConfigNotFound
195+
}
194196
let tokenizerData = try await tokenizerData
195197
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
196198
return _tokenizer!
@@ -212,4 +214,6 @@ extension LanguageModel: TextGenerationModel {
212214
}
213215
}
214216

215-
extension String: Error {}
217+
/// Errors thrown while resolving the tokenizer of a `LanguageModel`.
public enum TokenizerError: Error {
218+
/// Thrown by the `tokenizer` accessor when `tokenizerConfig` resolves to `nil`.
case tokenizerConfigNotFound
219+
}

Sources/Tokenizers/Tokenizer.swift

Lines changed: 88 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,9 @@ import Hub
99
import Foundation
1010
import Jinja
1111

12+
public typealias Message = [String: Any]
13+
public typealias ToolSpec = [String: Any]
14+
1215
enum TokenizerError: Error {
1316
case missingConfig
1417
case missingTokenizerClassInConfig
@@ -142,23 +145,57 @@ public protocol Tokenizer {
142145
var unknownTokenId: Int? { get }
143146

144147
/// The appropriate chat template is selected from the tokenizer config
145-
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
148+
func applyChatTemplate(messages: [Message]) throws -> [Int]
149+
150+
/// The appropriate chat template is selected from the tokenizer config
151+
func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int]
146152

147153
/// The chat template is provided as a string literal or specified by name
148-
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
154+
func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int]
149155

150156
/// The chat template is provided as a string literal
151-
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]
157+
func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int]
152158

153159
func applyChatTemplate(
154-
messages: [[String: String]],
160+
messages: [Message],
155161
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
156162
chatTemplate: ChatTemplateArgument?,
157163
addGenerationPrompt: Bool,
158164
truncation: Bool,
159165
maxLength: Int?,
160-
tools: [[String: Any]]?
166+
tools: [ToolSpec]?
161167
) throws -> [Int]
168+
169+
func applyChatTemplate(
170+
messages: [Message],
171+
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
172+
chatTemplate: ChatTemplateArgument?,
173+
addGenerationPrompt: Bool,
174+
truncation: Bool,
175+
maxLength: Int?,
176+
tools: [ToolSpec]?,
177+
additionalContext: [String: Any]?
178+
) throws -> [Int]
179+
}
180+
181+
/// Default implementation of the `additionalContext`-taking `applyChatTemplate` requirement.
extension Tokenizer {
182+
/// Call previous signature for backwards compatibility
///
/// When `additionalContext` is `nil`, every argument is forwarded unchanged to the
/// six-parameter `applyChatTemplate`. Any non-nil `additionalContext` is rejected with
/// `TokenizerError.chatTemplate("Not implemented")`.
///
/// - Throws: `TokenizerError.chatTemplate` for a non-nil `additionalContext`, or whatever
///   the forwarded call throws.
183+
func applyChatTemplate(
184+
messages: [Message],
185+
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
186+
chatTemplate: ChatTemplateArgument?,
187+
addGenerationPrompt: Bool,
188+
truncation: Bool,
189+
maxLength: Int?,
190+
tools: [ToolSpec]?,
191+
additionalContext: [String: Any]?
192+
) throws -> [Int] {
193+
// No extra context requested: the older overload can serve the call as-is.
if additionalContext == nil {
194+
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools)
195+
} else {
196+
// This default has nowhere to put the extra context, so it fails instead of silently dropping it.
throw TokenizerError.chatTemplate("Not implemented")
197+
}
198+
}
162199
}
163200

164201
public extension Tokenizer {
@@ -359,20 +396,46 @@ public class PreTrainedTokenizer: Tokenizer {
359396
model.convertIdToToken(id)
360397
}
361398

362-
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
399+
/// Convenience overload: forwards `messages` with `addGenerationPrompt: true`; every other
/// option is left to the callee's defaults.
public func applyChatTemplate(messages: [Message]) throws -> [Int] {
363400
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
364401
}
365402

366-
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
403+
/// Convenience overload for tool use: forwards `messages` and `tools` with
/// `addGenerationPrompt: true`; the remaining options are left to the callee's defaults.
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
404+
try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools)
405+
}
406+
407+
/// Convenience overload that forwards `messages`, `tools` and `additionalContext` with
/// `addGenerationPrompt: true`; the remaining options are left to the callee's defaults.
public func applyChatTemplate(messages: [Message], tools: [ToolSpec], additionalContext: [String: Any]) throws
408+
-> [Int]
409+
{
410+
try applyChatTemplate(
411+
messages: messages,
412+
addGenerationPrompt: true,
413+
tools: tools,
414+
additionalContext: additionalContext
415+
)
416+
}
417+
418+
/// Convenience overload: forwards `messages` and the given `chatTemplate` with
/// `addGenerationPrompt: true`.
public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
367419
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
368420
}
369421

370-
public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
422+
/// Convenience overload for a template given as a string: wraps `chatTemplate` in `.literal`
/// and forwards with `addGenerationPrompt: true`.
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
371423
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
372424
}
373425

374426
/// Overload without `additionalContext`: passes every argument through unchanged and supplies
/// `additionalContext: nil` to the overload that accepts it.
public func applyChatTemplate(
375-
messages: [[String: String]],
427+
messages: [Message],
428+
chatTemplate: ChatTemplateArgument? = nil,
429+
addGenerationPrompt: Bool = false,
430+
truncation: Bool = false,
431+
maxLength: Int? = nil,
432+
tools: [ToolSpec]? = nil
433+
) throws -> [Int] {
434+
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools, additionalContext: nil)
435+
}
436+
437+
public func applyChatTemplate(
438+
messages: [Message],
376439
chatTemplate: ChatTemplateArgument? = nil,
377440
addGenerationPrompt: Bool = false,
378441
truncation: Bool = false,
@@ -382,8 +445,8 @@ public class PreTrainedTokenizer: Tokenizer {
382445
/// giving the name, description and argument types for the tool. See the
383446
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
384447
/// for more information.
385-
/// Note: tool calling is not supported yet, it will be available in a future update.
386-
tools: [[String: Any]]? = nil
448+
tools: [ToolSpec]? = nil,
449+
additionalContext: [String: Any]? = nil
387450
) throws -> [Int] {
388451
var selectedChatTemplate: String?
389452
if let chatTemplate, case .literal(let template) = chatTemplate {
@@ -425,10 +488,21 @@ public class PreTrainedTokenizer: Tokenizer {
425488
let template = try Template(selectedChatTemplate)
426489
var context: [String: Any] = [
427490
"messages": messages,
428-
"add_generation_prompt": addGenerationPrompt
429-
// TODO: Add `tools` entry when support is added in Jinja
430-
// "tools": tools
491+
"add_generation_prompt": addGenerationPrompt,
431492
]
493+
if let tools {
494+
context["tools"] = tools
495+
}
496+
if let additionalContext {
497+
/*
498+
Additional keys and values to be added to the context provided to the prompt templating engine.
499+
For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
500+
The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
501+
*/
502+
for (key, value) in additionalContext {
503+
context[key] = value
504+
}
505+
}
432506

433507
// TODO: maybe keep NSString here
434508
for (key, value) in tokenizerConfig.dictionary as [String : Any] {

Tests/TokenizersTests/ChatTemplateTests.swift

Lines changed: 89 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,5 +80,93 @@ class ChatTemplateTests: XCTestCase {
8080
XCTAssertEqual(decoded, decodedTarget)
8181
}
8282

83-
// TODO: Add tests for tool use template
83+
/// Renders a one-message conversation plus one tool spec through the chat template of
/// `mlx-community/Qwen2.5-7B-Instruct-4bit`, decodes the tokens back to text, and checks
/// (a) the JSON between `<tools>` and `</tools>` against the tool spec and
/// (b) the fixed prompt text before and after that section.
// NOTE(review): `AutoTokenizer.from(pretrained:)` presumably fetches the tokenizer files by
// repo id, which would make this test depend on network access — confirm.
func testQwen2_5WithTools() async throws {
84+
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2.5-7B-Instruct-4bit")
85+
86+
// Single user turn; no system message is supplied.
let weatherQueryMessages: [[String: String]] = [
87+
[
88+
"role": "user",
89+
"content": "What is the weather in Paris today?",
90+
]
91+
]
92+
93+
// Tool description handed to the template as a nested dictionary.
let getCurrentWeatherToolSpec: [String: Any] = [
94+
"type": "function",
95+
"function": [
96+
"name": "get_current_weather",
97+
"description": "Get the current weather in a given location",
98+
"parameters": [
99+
"type": "object",
100+
"properties": [
101+
"location": [
102+
"type": "string",
103+
"description": "The city and state, e.g. San Francisco, CA"
104+
],
105+
"unit": [
106+
"type": "string",
107+
"enum": ["celsius", "fahrenheit"]
108+
]
109+
],
110+
"required": ["location"]
111+
]
112+
]
113+
]
114+
115+
let encoded = try tokenizer.applyChatTemplate(messages: weatherQueryMessages, tools: [getCurrentWeatherToolSpec])
116+
let decoded = tokenizer.decode(tokens: encoded)
117+
118+
/// Recursively compares `actual` against `expected`, walking the keys of `actual`:
/// nested dictionaries recurse, `[String]` values are compared as sets (order-insensitive),
/// and everything else is compared after an `as? String` cast.
// NOTE(review): keys present only in `expected` are never visited, so a key missing from
// `actual` goes unnoticed.
// NOTE(review): values that are not `String` cast to `nil` on both sides and therefore
// compare equal in the final branch.
func assertDictsAreEqual(_ actual: [String: Any], _ expected: [String: Any]) {
119+
for (key, value) in actual {
120+
if let nestedDict = value as? [String: Any], let nestedDict2 = expected[key] as? [String: Any] {
121+
assertDictsAreEqual(nestedDict, nestedDict2)
122+
} else if let arrayValue = value as? [String] {
123+
let expectedArrayValue = expected[key] as? [String]
124+
XCTAssertNotNil(expectedArrayValue)
125+
// NOTE(review): if the test continues after the failed assertion above (XCTest's default
// behaviour, as far as I know — confirm), this force unwrap traps on a nil expected value.
XCTAssertEqual(Set(arrayValue), Set(expectedArrayValue!))
126+
} else {
127+
XCTAssertEqual(value as? String, expected[key] as? String)
128+
}
129+
}
130+
}
131+
132+
// Cut out the text between "<tools>\n" and the next "\n</tools>" and parse it as JSON, so
// the comparison does not depend on how the template orders the serialized keys.
if let startRange = decoded.range(of: "<tools>\n"),
133+
let endRange = decoded.range(of: "\n</tools>", range: startRange.upperBound..<decoded.endIndex) {
134+
let toolsSection = String(decoded[startRange.upperBound..<endRange.lowerBound])
135+
if let toolsDict = try? JSONSerialization.jsonObject(with: toolsSection.data(using: .utf8)!) as? [String : Any] {
136+
assertDictsAreEqual(toolsDict, getCurrentWeatherToolSpec)
137+
} else {
138+
XCTFail("Failed to decode tools section")
139+
}
140+
} else {
141+
XCTFail("Failed to find tools section")
142+
}
143+
144+
// Fixed text expected before the serialized tool list.
let expectedPromptStart = """
145+
<|im_start|>system
146+
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
147+
148+
# Tools
149+
150+
You may call one or more functions to assist with the user query.
151+
152+
You are provided with function signatures within <tools></tools> XML tags:
153+
<tools>
154+
"""
155+
156+
// Fixed text expected after the serialized tool list, through the assistant generation prompt.
let expectedPromptEnd = """
157+
</tools>
158+
159+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
160+
<tool_call>
161+
{"name": <function-name>, "arguments": <args-json-object>}
162+
</tool_call><|im_end|>
163+
<|im_start|>user
164+
What is the weather in Paris today?<|im_end|>
165+
<|im_start|>assistant
166+
167+
"""
168+
169+
XCTAssertTrue(decoded.hasPrefix(expectedPromptStart), "Prompt should start with expected system message")
170+
XCTAssertTrue(decoded.hasSuffix(expectedPromptEnd), "Prompt should end with expected format")
171+
}
84172
}

Tests/TokenizersTests/TokenizerTests.swift

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -260,7 +260,7 @@ class TokenizerTester {
260260
do {
261261
guard let tokenizerConfig = try await configuration!.tokenizerConfig else {
262262
XCTFail("Cannot retrieve Tokenizer configuration")
263-
return nil
263+
return nil
264264
}
265265
let tokenizerData = try await configuration!.tokenizerData
266266
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)

0 commit comments

Comments
 (0)