Commit 983eaac

Support tool use and add example (#174)
* Update swift-transformers
* Support tool use, add example
1 parent 3864824 commit 983eaac

File tree

5 files changed: +96 −22 lines

- Applications/LLMEval/ContentView.swift
- Libraries/MLXLLM/LLMModelFactory.swift
- Libraries/MLXLMCommon/UserInput.swift
- Package.swift
- mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

Applications/LLMEval/ContentView.swift

Lines changed: 40 additions & 8 deletions
@@ -10,10 +10,10 @@ import SwiftUI
 import Tokenizers
 
 struct ContentView: View {
+    @Environment(DeviceStat.self) private var deviceStat
 
-    @State var prompt = ""
     @State var llm = LLMEvaluator()
-    @Environment(DeviceStat.self) private var deviceStat
+    @State var prompt = "What's the current weather in Paris?"
 
     enum displayStyle: String, CaseIterable, Identifiable {
         case plain, markdown
@@ -34,6 +34,10 @@ struct ContentView: View {
                 Text(llm.stat)
             }
             HStack {
+                Toggle(isOn: $llm.includeWeatherTool) {
+                    Text("Include \"get current weather\" tool")
+                }
+                .frame(maxWidth: 350, alignment: .leading)
                 Spacer()
                 if llm.running {
                     ProgressView()
@@ -127,7 +131,6 @@ struct ContentView: View {
         }
         .task {
             self.prompt = llm.modelConfiguration.defaultPrompt
-
             // pre-load the weights on launch to speed up the first generation
             _ = try? await llm.load()
         }
@@ -154,13 +157,15 @@ class LLMEvaluator {
 
     var running = false
 
+    var includeWeatherTool = false
+
     var output = ""
     var modelInfo = ""
     var stat = ""
 
-    /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
+    /// This controls which model loads. `qwen2_5_1_5b` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelRegistry.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.qwen2_5_1_5b
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
@@ -178,6 +183,29 @@ class LLMEvaluator {
 
     var loadState = LoadState.idle
 
+    let currentWeatherToolSpec: [String: any Sendable] =
+        [
+            "type": "function",
+            "function": [
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": [
+                    "type": "object",
+                    "properties": [
+                        "location": [
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        ] as [String: String],
+                        "unit": [
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                        ] as [String: any Sendable],
+                    ] as [String: [String: any Sendable]],
+                    "required": ["location"],
+                ] as [String: any Sendable],
+            ] as [String: any Sendable],
+        ] as [String: any Sendable]
+
     /// load and return the model -- can be called multiple times, subsequent calls will
     /// just return the loaded model
     func load() async throws -> ModelContainer {
@@ -222,18 +250,22 @@ class LLMEvaluator {
         MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
         let result = try await modelContainer.perform { context in
-            let input = try await context.processor.prepare(input: .init(prompt: prompt))
+            let input = try await context.processor.prepare(
+                input: .init(
+                    messages: [
+                        ["role": "system", "content": "You are a helpful assistant."],
+                        ["role": "user", "content": prompt],
+                    ], tools: includeWeatherTool ? [currentWeatherToolSpec] : nil))
             return try MLXLMCommon.generate(
                 input: input, parameters: generateParameters, context: context
             ) { tokens in
-                // update the output -- this will make the view show the text as it generates
+                // Show the text in the view as it generates
                 if tokens.count % displayEveryNTokens == 0 {
                     let text = context.tokenizer.decode(tokens: tokens)
                     Task { @MainActor in
                         self.output = text
                     }
                 }
-
                 if tokens.count >= maxTokens {
                     return .stop
                 } else {
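
For orientation, the sketch below lifts the new tool-use path out of the view code. It reuses only calls that appear in this diff (`UserInput(messages:tools:)`, `context.processor.prepare(input:)`, `MLXLMCommon.generate(input:parameters:context:)`, `GenerateParameters(temperature:)`); the `generateWithTool` wrapper, the token limit, and the `result.output` / `.more` names are assumptions for illustration, not part of the change.

    import MLXLMCommon

    /// Sketch: run one generation with a tool spec attached.
    /// `container` is an already-loaded ModelContainer; loading it is out of scope here.
    func generateWithTool(
        container: ModelContainer, prompt: String, toolSpec: [String: any Sendable]
    ) async throws -> String {
        let parameters = GenerateParameters(temperature: 0.6)

        return try await container.perform { context in
            // Same shape as LLMEvaluator above: chat messages plus an optional tool list.
            let input = try await context.processor.prepare(
                input: .init(
                    messages: [
                        ["role": "system", "content": "You are a helpful assistant."],
                        ["role": "user", "content": prompt],
                    ], tools: [toolSpec]))

            let result = try MLXLMCommon.generate(
                input: input, parameters: parameters, context: context
            ) { tokens in
                // `.stop` appears in the diff; `.more` is the assumed "keep going" case.
                tokens.count >= 240 ? .stop : .more
            }
            // Assumption: the generate result exposes the decoded text as `output`.
            return result.output
        }
    }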

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 14 additions & 1 deletion
@@ -149,6 +149,16 @@ public class ModelRegistry: @unchecked Sendable {
         defaultPrompt: "why is the sky blue?"
     )
 
+    static public let qwen2_5_7b = ModelConfiguration(
+        id: "mlx-community/Qwen2.5-7B-Instruct-4bit",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
+    static public let qwen2_5_1_5b = ModelConfiguration(
+        id: "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
     static public let openelm270m4bit = ModelConfiguration(
         id: "mlx-community/OpenELM-270M-Instruct",
         // https://huggingface.co/apple/OpenELM
@@ -193,6 +203,8 @@ public class ModelRegistry: @unchecked Sendable {
         phi3_5_4bit,
         phi4bit,
         qwen205b4bit,
+        qwen2_5_7b,
+        qwen2_5_1_5b,
         smolLM_135M_4bit,
     ]
 }
@@ -229,7 +241,8 @@ private struct LLMUserInputProcessor: UserInputProcessor {
     func prepare(input: UserInput) throws -> LMInput {
         do {
             let messages = input.prompt.asMessages()
-            let promptTokens = try tokenizer.applyChatTemplate(messages: messages)
+            let promptTokens = try tokenizer.applyChatTemplate(
+                messages: messages, tools: input.tools, additionalContext: input.additionalContext)
             return LMInput(tokens: MLXArray(promptTokens))
         } catch {
             // #150 -- it might be a TokenizerError.chatTemplate("No chat template was specified")
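
Selecting one of the new configurations is a one-line change wherever a `ModelConfiguration` is expected, as the LLMEval diff above does with `ModelRegistry.qwen2_5_1_5b`. The more substantive piece is `prepare`, which now lets tool specs and additional context reach the tokenizer's chat template. A minimal sketch of that call in isolation, assuming a swift-transformers `Tokenizer` is already in hand and that `ToolSpec` comes from the Tokenizers module (as the UserInput diff below suggests):

    import MLX
    import MLXLMCommon
    import Tokenizers

    /// Sketch: render a chat template with tools, mirroring LLMUserInputProcessor.prepare.
    func tokenizeWithTools(
        tokenizer: any Tokenizer, messages: [[String: String]], tools: [ToolSpec]?
    ) throws -> LMInput {
        // swift-transformers 0.1.17 lets the chat template see the tool list and any
        // additional context values (none passed here).
        let promptTokens = try tokenizer.applyChatTemplate(
            messages: messages, tools: tools, additionalContext: nil)
        return LMInput(tokens: MLXArray(promptTokens))
    }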

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 23 additions & 3 deletions
@@ -4,6 +4,7 @@ import AVFoundation
 import CoreImage
 import Foundation
 import MLX
+import Tokenizers
 
 /// Container for raw user input.
 ///
@@ -125,23 +126,42 @@ public struct UserInput: Sendable {
     public var prompt: Prompt
     public var images = [Image]()
     public var videos = [Video]()
+    public var tools: [ToolSpec]?
+    /// Additional values provided for the chat template rendering context
+    public var additionalContext: [String: Any]?
     public var processing: Processing = .init()
 
-    public init(prompt: String, images: [Image] = [Image](), videos: [Video] = [Video]()) {
+    public init(
+        prompt: String, images: [Image] = [Image](), videos: [Video] = [Video](),
+        tools: [ToolSpec]? = nil,
+        additionalContext: [String: Any]? = nil
+    ) {
         self.prompt = .text(prompt)
         self.images = images
         self.videos = videos
+        self.tools = tools
+        self.additionalContext = additionalContext
     }
 
-    public init(messages: [[String: String]], images: [Image] = [Image]()) {
+    public init(
+        messages: [[String: String]], images: [Image] = [Image](), tools: [ToolSpec]? = nil,
+        additionalContext: [String: Any]? = nil
+    ) {
         self.prompt = .messages(messages)
         self.images = images
+        self.tools = tools
+        self.additionalContext = additionalContext
     }
 
-    public init(prompt: Prompt, images: [Image] = [Image](), processing: Processing = .init()) {
+    public init(
+        prompt: Prompt, images: [Image] = [Image](), processing: Processing = .init(),
+        tools: [ToolSpec]? = nil, additionalContext: [String: Any]? = nil
+    ) {
         self.prompt = prompt
         self.images = images
         self.processing = processing
+        self.tools = tools
+        self.additionalContext = additionalContext
     }
 }
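
A short usage sketch for the extended initializers. The tool dictionary follows the same shape as LLMEval's `currentWeatherToolSpec` (abbreviated here), and the `additionalContext` key is a hypothetical template variable, not something defined by this change:

    import MLXLMCommon

    // Abbreviated tool spec in the same shape as LLMEval's currentWeatherToolSpec.
    let weatherTool: [String: any Sendable] = [
        "type": "function",
        "function": [
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
        ] as [String: String],
    ]

    // Plain-text prompt with a tool list attached.
    let textInput = UserInput(
        prompt: "What's the current weather in Paris?", tools: [weatherTool])

    // Chat-style messages, tools, and extra values for the chat template
    // ("time_of_day" is purely illustrative).
    let chatInput = UserInput(
        messages: [
            ["role": "system", "content": "You are a helpful assistant."],
            ["role": "user", "content": "What's the current weather in Paris?"],
        ],
        tools: [weatherTool],
        additionalContext: ["time_of_day": "morning"])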

Package.swift

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ let package = Package(
     dependencies: [
         .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.21.2")),
         .package(
-            url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.15")
+            url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.17")
         ),
         .package(
             url: "https://github.com/apple/swift-async-algorithms", .upToNextMinor(from: "1.0.0")),

mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

Lines changed: 18 additions & 9 deletions
Some generated files are not rendered by default.
