Skip to content

Commit fb3184c

Browse files
authored
Implement Structured Tool (#311)
* Add initial implementation of Tool * Add ToolTests * Refactor currentWeatherTool with structured Tool * Fix description of unit parameter * Refactor Tool implementation: split Tool and ToolParameter into separate files for improved organization and clarity * Add ToolCall and ToolCallProcessor for handling tool calls in generated text - Introduced ToolCall struct to represent tool function details. - Added ToolCallProcessor class to manage detection and processing of tool calls in generated text. - Updated generation logic to yield tool calls alongside text chunks. - Created JSONValue enum for type-safe representation of JSON values. - Added extension for Encodable to handle JSON encoding with snake case. - Enhanced tests to verify tool call detection functionality. * Add tool role to Chat enum and introduce tool content initializer * Add new tools for addition and time retrieval, enhance weather tool functionality, and update generation logic to handle tool results. * add missing cases
1 parent 43d2e98 commit fb3184c

File tree

15 files changed

+710
-30
lines changed

15 files changed

+710
-30
lines changed

Applications/LLMEval/ContentView.swift

Lines changed: 102 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ struct ContentView: View {
3434
}
3535
HStack {
3636
Toggle(isOn: $llm.includeWeatherTool) {
37-
Text("Include \"get current weather\" tool")
37+
Text("Include tools")
3838
}
3939
.frame(maxWidth: 350, alignment: .leading)
4040
Toggle(isOn: $llm.enableThinking) {
@@ -188,28 +188,50 @@ class LLMEvaluator {
188188

189189
var loadState = LoadState.idle
190190

191-
let currentWeatherToolSpec: [String: any Sendable] =
192-
[
193-
"type": "function",
194-
"function": [
195-
"name": "get_current_weather",
196-
"description": "Get the current weather in a given location",
197-
"parameters": [
198-
"type": "object",
199-
"properties": [
200-
"location": [
201-
"type": "string",
202-
"description": "The city and state, e.g. San Francisco, CA",
203-
] as [String: String],
204-
"unit": [
205-
"type": "string",
206-
"enum": ["celsius", "fahrenheit"],
207-
] as [String: any Sendable],
208-
] as [String: [String: any Sendable]],
209-
"required": ["location"],
210-
] as [String: any Sendable],
211-
] as [String: any Sendable],
212-
] as [String: any Sendable]
191+
/// Demo tool that fabricates a weather report for the requested location.
///
/// The schema advertises `unit` as an optional string restricted to
/// "celsius" / "fahrenheit" with a default of "celsius". The handler does not
/// look anything up: it returns a random temperature and a random condition.
let currentWeatherTool = Tool<WeatherInput, WeatherOutput>(
    name: "get_current_weather",
    description: "Get the current weather in a given location",
    parameters: [
        .required(
            "location", type: .string, description: "The city and state, e.g. San Francisco, CA"
        ),
        .optional(
            "unit",
            type: .string,
            description: "The unit of temperature",
            extraProperties: [
                "enum": ["celsius", "fahrenheit"],
                "default": "celsius",
            ]
        ),
    ]
) { input in
    // Honor the default declared in the schema: when the model omits `unit`
    // the value decodes as nil, which must be treated as Celsius rather than
    // falling through to the Fahrenheit range.
    let unit = input.unit ?? "celsius"
    let range = unit == "celsius" ? (min: -20.0, max: 40.0) : (min: 0, max: 100)
    let temperature = Double.random(in: range.min ... range.max)

    // The literal is non-empty, so `randomElement()` cannot return nil here.
    let conditions = ["Sunny", "Cloudy", "Rainy", "Snowy", "Windy", "Stormy"].randomElement()!

    return WeatherOutput(temperature: temperature, conditions: conditions)
}
216+
217+
/// Demo tool that adds the two integers supplied by the model.
let addTool = Tool<AddInput, AddOutput>(
    name: "add_two_numbers",
    description: "Add two numbers together",
    parameters: [
        .required("first", type: .int, description: "The first number to add"),
        .required("second", type: .int, description: "The second number to add"),
    ],
    handler: { operands in
        let sum = operands.first + operands.second
        return AddOutput(result: sum)
    }
)
227+
228+
/// Demo tool that reports the current date and time.
///
/// It takes no parameters; the result is `Date.now` rendered with the default
/// `formatted()` style.
let timeTool = Tool<EmptyInput, TimeOutput>(
    name: "get_time",
    description: "Get the current time",
    // No trailing comma after the last argument: that form is only accepted
    // by Swift 6.1 and later.
    parameters: []
) { _ in
    TimeOutput(time: Date.now.formatted())
}
213235

214236
/// load and return the model -- can be called multiple times, subsequent calls will
215237
/// just return the loaded model
@@ -243,15 +265,24 @@ class LLMEvaluator {
243265
}
244266
}
245267

246-
private func generate(prompt: String) async {
268+
private func generate(prompt: String, toolResult: String? = nil) async {
247269

248270
self.output = ""
249-
let chat: [Chat.Message] = [
271+
var chat: [Chat.Message] = [
250272
.system("You are a helpful assistant"),
251273
.user(prompt),
252274
]
275+
276+
if let toolResult {
277+
chat.append(.tool(toolResult))
278+
}
279+
253280
let userInput = UserInput(
254-
chat: chat, additionalContext: ["enable_thinking": enableThinking])
281+
chat: chat,
282+
tools: includeWeatherTool
283+
? [currentWeatherTool.schema, addTool.schema, timeTool.schema] : nil,
284+
additionalContext: ["enable_thinking": enableThinking]
285+
)
255286

256287
do {
257288
let modelContainer = try await load()
@@ -280,6 +311,10 @@ class LLMEvaluator {
280311
self.stat = "\(completion.tokensPerSecond) tokens/s"
281312
}
282313
}
314+
315+
if let toolCall = batch.compactMap({ $0.toolCall }).first {
316+
try await handleToolCall(toolCall, prompt: prompt)
317+
}
283318
}
284319
}
285320

@@ -303,4 +338,45 @@ class LLMEvaluator {
303338
generationTask?.cancel()
304339
running = false
305340
}
341+
342+
/// Runs the tool named by `toolCall` and starts a new generation that
/// includes the tool's result.
///
/// - Parameters:
///   - toolCall: The call emitted by the model.
///   - prompt: The original user prompt, passed back to `generate`.
/// - Throws: Whatever `ToolCall.execute(with:)` throws for the matched tool.
private func handleToolCall(_ toolCall: ToolCall, prompt: String) async throws {
    let requestedName = toolCall.function.name
    let toolOutput: String

    // Tools are checked in the same order as before: weather, add, time.
    if requestedName == currentWeatherTool.name {
        toolOutput = try await toolCall.execute(with: currentWeatherTool).toolResult
    } else if requestedName == addTool.name {
        toolOutput = try await toolCall.execute(with: addTool).toolResult
    } else if requestedName == timeTool.name {
        toolOutput = try await toolCall.execute(with: timeTool).toolResult
    } else {
        // Unknown tool: the model is still given a textual answer.
        toolOutput = "No tool match"
    }

    await generate(prompt: prompt, toolResult: toolOutput)
}
357+
}
358+
359+
/// Arguments decoded for the `get_current_weather` tool.
struct WeatherInput: Codable {
    /// The location to report on.
    let location: String
    /// The requested temperature unit; nil when the caller omitted it.
    let unit: String?
}

/// Result produced by the `get_current_weather` tool.
struct WeatherOutput: Codable {
    /// The reported temperature.
    let temperature: Double
    /// A short description of the conditions, e.g. "Sunny".
    let conditions: String
}

/// Arguments decoded for the `add_two_numbers` tool.
struct AddInput: Codable {
    /// The first addend.
    let first: Int
    /// The second addend.
    let second: Int
}

/// Result produced by the `add_two_numbers` tool.
struct AddOutput: Codable {
    /// The sum of the two inputs.
    let result: Int
}

/// Input type for tools that take no arguments.
struct EmptyInput: Codable {}

/// Result produced by the `get_time` tool.
struct TimeOutput: Codable {
    /// The formatted time string.
    let time: String
}

Applications/MLXChatExample/ViewModels/ChatViewModel.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class ChatViewModel {
8989
case .info(let info):
9090
// Update performance metrics
9191
generateCompletionInfo = info
92+
case .toolCall(let call):
93+
break
9294
}
9395
}
9496
}

Libraries/MLXLMCommon/Chat.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,15 @@ public enum Chat {
4242
Self(role: .user, content: content, images: images, videos: videos)
4343
}
4444

45+
/// Creates a message with the `tool` role carrying the given content.
///
/// - Parameter content: The text to attach to the message.
public static func tool(_ content: String) -> Self {
    .init(role: .tool, content: content)
}
48+
4549
/// The role attached to a chat message. Each raw value is the case name.
public enum Role: String {
    case user, assistant, system, tool
}
5055
}
5156
}

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ public func generate(
804804

805805
var tokenCount = 0
806806
var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)
807+
let toolCallProcessor = ToolCallProcessor()
807808

808809
for token in iterator {
809810

@@ -826,7 +827,16 @@ public func generate(
826827
detokenizer.append(token: token)
827828
if let chunk = detokenizer.next() {
828829
tokenCount += 1
829-
continuation.yield(.chunk(chunk))
830+
831+
// Process chunk through the tool call processor
832+
if let textToYield = toolCallProcessor.processChunk(chunk) {
833+
continuation.yield(.chunk(textToYield))
834+
}
835+
836+
// Check if we have a complete tool call
837+
if let toolCall = toolCallProcessor.toolCalls.popLast() {
838+
continuation.yield(.toolCall(toolCall))
839+
}
830840
}
831841
}
832842

@@ -909,14 +919,19 @@ public struct GenerateCompletionInfo: Sendable {
909919
public enum Generation: Sendable {
910920
/// A generated token represented as a String
911921
case chunk(String)
922+
912923
/// Completion information summarizing token counts and performance metrics.
913924
case info(GenerateCompletionInfo)
914925

926+
/// A tool call from the language model.
927+
case toolCall(ToolCall)
928+
915929
/// Generated text or nil
916930
public var chunk: String? {
917931
switch self {
918932
case .chunk(let string): string
919933
case .info: nil
934+
case .toolCall: nil
920935
}
921936
}
922937

@@ -925,6 +940,16 @@ public enum Generation: Sendable {
925940
switch self {
926941
case .chunk: nil
927942
case .info(let info): info
943+
case .toolCall: nil
944+
}
945+
}
946+
947+
/// Tool call or nil
948+
public var toolCall: ToolCall? {
949+
switch self {
950+
case .chunk: nil
951+
case .info: nil
952+
case .toolCall(let toolCall): toolCall
928953
}
929954
}
930955

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
import Foundation
4+
5+
// Extension on Encodable to render a value as JSON with snake-case keys
extension Encodable {
    /// The value encoded as a JSON string whose keys are converted to snake case.
    ///
    /// This is best-effort: if encoding fails, or the encoded bytes cannot be
    /// read as UTF-8, the empty JSON object `"{}"` is returned instead.
    public var toolResult: String {
        let snakeCaseEncoder = JSONEncoder()
        snakeCaseEncoder.keyEncodingStrategy = .convertToSnakeCase

        do {
            let encoded = try snakeCaseEncoder.encode(self)
            return String(data: encoded, encoding: .utf8) ?? "{}"
        } catch {
            return "{}"
        }
    }
}

Libraries/MLXLMCommon/Streamlined.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,8 @@ private class Generator {
101101
for await item in try MLXLMCommon.generate(
102102
input: input, cache: cache, parameters: generateParameters, context: context)
103103
{
104-
switch item {
105-
case .chunk(let chunk): continuation.yield(chunk)
106-
case .info: break
104+
if let chunk = item.chunk {
105+
continuation.yield(chunk)
107106
}
108107
}
109108

Libraries/MLXLMCommon/Tool/Tool.swift

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
import Foundation
4+
import Tokenizers
5+
6+
/// A tool that can describe its own interface.
public protocol ToolProtocol: Sendable {
    /// The JSON Schema describing the tool's interface.
    var schema: ToolSpec { get }
}
11+
12+
/// A tool that pairs a schema with a typed, asynchronous handler.
///
/// `Input` is the type the call arguments are decoded into and `Output` is the
/// type the handler returns.
public struct Tool<Input: Codable, Output: Codable>: ToolProtocol {
    /// The JSON Schema describing the tool's interface.
    public let schema: ToolSpec

    /// The handler for the tool.
    public let handler: (Input) async throws -> Output

    /// The tool's name, read from `function.name` in the schema.
    ///
    /// Returns an empty string when the schema has no such entry.
    public var name: String {
        guard
            let functionSpec = schema["function"] as? [String: Any],
            let declaredName = functionSpec["name"] as? String
        else {
            return ""
        }
        return declaredName
    }

    /// Creates a tool and builds its function schema from the given pieces.
    ///
    /// - Parameters:
    ///   - name: The function name stored in the schema.
    ///   - description: The function description stored in the schema.
    ///   - parameters: The parameters to list under `properties`; those marked
    ///     required are also listed under `required`.
    ///   - handler: The closure that implements the tool.
    public init(
        name: String,
        description: String,
        parameters: [ToolParameter],
        handler: @escaping (Input) async throws -> Output
    ) {
        // If two parameters share a name, the later one wins.
        let properties = [String: Any](
            parameters.map { ($0.name, $0.schema as Any) },
            uniquingKeysWith: { _, latest in latest }
        )
        let requiredParams = parameters.filter(\.isRequired).map(\.name)

        self.schema = [
            "type": "function",
            "function": [
                "name": name,
                "description": description,
                "parameters": [
                    "type": "object",
                    "properties": properties,
                    "required": requiredParams,
                ],
            ],
        ]

        self.handler = handler
    }

    /// Creates a tool from an already-built schema.
    ///
    /// - Parameters:
    ///   - schema: The schema to store unchanged.
    ///   - handler: The closure that implements the tool.
    public init(schema: ToolSpec, handler: @escaping (Input) async throws -> Output) {
        self.schema = schema
        self.handler = handler
    }
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
import Foundation
4+
5+
/// A request to call a named function with a set of arguments.
public struct ToolCall: Hashable, Codable, Sendable {
    /// Represents the function details for a tool call
    public struct Function: Hashable, Codable, Sendable {
        /// The name of the function
        public let name: String

        /// The arguments passed to the function
        public let arguments: [String: JSONValue]

        /// Creates the function details, converting each argument value with
        /// `JSONValue.from`.
        public init(name: String, arguments: [String: Any]) {
            self.name = name

            var converted = [String: JSONValue]()
            converted.reserveCapacity(arguments.count)
            for (key, rawValue) in arguments {
                converted[key] = JSONValue.from(rawValue)
            }
            self.arguments = converted
        }
    }

    /// The function to be called
    public let function: Function
}
23+
24+
extension ToolCall {
    /// Decodes this call's arguments into the tool's `Input` type and runs the
    /// tool's handler.
    ///
    /// - Parameter tool: The tool to run; its name must equal `function.name`.
    /// - Returns: The handler's output.
    /// - Throws: `ToolError.nameMismatch` when the names differ, plus any error
    ///   from JSON serialization, decoding, or the handler itself.
    public func execute<Input, Output>(with tool: Tool<Input, Output>) async throws -> Output {
        // Refuse to run a tool the call was not addressed to.
        if tool.name != function.name {
            throw ToolError.nameMismatch(toolName: tool.name, functionName: function.name)
        }

        // Round-trip the arguments through JSON so `Input` can be decoded
        // with its own Codable conformance.
        let rawArguments = function.arguments.mapValues(\.anyValue)
        let payload = try JSONSerialization.data(withJSONObject: rawArguments)
        let decodedInput = try JSONDecoder().decode(Input.self, from: payload)

        return try await tool.handler(decodedInput)
    }
}
42+
43+
// Errors raised while executing a tool call
public enum ToolError: Error, LocalizedError {
    /// The tool's name differs from the function name in the call.
    case nameMismatch(toolName: String, functionName: String)

    /// A human-readable description of the error.
    public var errorDescription: String? {
        switch self {
        case let .nameMismatch(expected, actual):
            return "Tool name mismatch: expected '\(expected)' but got '\(actual)'"
        }
    }
}

0 commit comments

Comments
 (0)