
Commit 39085b8

chat example (command line) (#277)
* chat example (command line)
1 parent 5f8f583 commit 39085b8

10 files changed: +324 −68 lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 1 addition & 1 deletion
```diff
@@ -316,7 +316,7 @@ public class LLMModelFactory: ModelFactory {
         let modelDirectory = try await downloadModel(
             hub: hub, configuration: configuration, progressHandler: progressHandler)

-        // load the generic config to unerstand which model and how to load the weights
+        // load the generic config to understand which model and how to load the weights
         let configurationURL = modelDirectory.appending(component: "config.json")
         let baseConfig = try JSONDecoder().decode(
             BaseConfiguration.self, from: Data(contentsOf: configurationURL))
```

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 3 additions & 2 deletions
```diff
@@ -700,6 +700,7 @@ public func generate(
 ///
 /// - Parameters:
 ///   - input: The input for the language model.
+///   - cache: optional ``KVCache``
 ///   - parameters: The configuration options for token generation.
 ///   - context: The model context, including the model itself and associated tokenizer.
 /// - Returns: An `AsyncStream` that emits `Generation` values, including generated tokens (`.token`)
@@ -729,10 +730,10 @@ public func generate(
 /// }
 /// ```
 public func generate(
-    input: LMInput, parameters: GenerateParameters, context: ModelContext
+    input: LMInput, cache: [KVCache]? = nil, parameters: GenerateParameters, context: ModelContext
 ) throws -> AsyncStream<Generation> {
     let iterator = try TokenIterator(
-        input: input, model: context.model, parameters: parameters)
+        input: input, model: context.model, cache: cache, parameters: parameters)
     return generate(
         input: input, context: context, iterator: iterator)
 }
```
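The new `cache` parameter makes multi-turn conversation cheap: callers can hold on to a `[KVCache]` and pass it back on every call so earlier turns are not re-processed. A minimal sketch of that pattern inside an async context (it assumes a loaded `ModelContext` named `context` and `GenerateParameters`' default initializer; `newCache(parameters:)` and the `.chunk` case are the same APIs the chat tool below relies on):

```swift
// Create the per-layer caches once and reuse them for the whole conversation.
let parameters = GenerateParameters()
let cache = context.model.newCache(parameters: parameters)
var chat: [Chat.Message] = [.system("You are a helpful assistant.")]

for turn in ["Hello!", "Tell me more."] {
    chat.append(.user(turn, images: [], videos: []))
    let input = try await context.processor.prepare(input: UserInput(chat: chat))

    // Passing the same cache means tokens from earlier turns stay resident
    // and only the new tokens are evaluated.
    var output = ""
    for await item in try generate(
        input: input, cache: cache, parameters: parameters, context: context
    ) {
        if case .chunk(let chunk) = item {
            output += chunk
        }
    }
    chat.append(.assistant(output))
}
```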

Libraries/MLXLMCommon/KVCache.swift

Lines changed: 4 additions & 1 deletion
```diff
@@ -39,7 +39,7 @@ public func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? {
 }

 /// See https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/base.py#L11
-public class KVCacheSimple: KVCache, Evaluatable {
+public class KVCacheSimple: KVCache, Evaluatable, CustomDebugStringConvertible {
     var keys: MLXArray?
     var values: MLXArray?

@@ -97,4 +97,7 @@ public class KVCacheSimple: KVCache, Evaluatable {
         )
     }

+    public var debugDescription: String {
+        "\(String(describing: Self.self)) \(Unmanaged.passUnretained(self).toOpaque()), offset: \(offset), step: \(step), keys: \(keys?.shape.description ?? "-"), values: \(values?.shape.description ?? "-")"
+    }
 }
```
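With the `CustomDebugStringConvertible` conformance a cache can be printed directly, which is handy for confirming that a chat session really is appending to the same cache across turns. A hypothetical inspection (the address and shapes in the comment are illustrative, derived from the format string above):

```swift
// Print the first layer's cache after a few rounds of generation.
print(cache[0])
// KVCacheSimple 0x0000600001d3c2a0, offset: 42, step: 256,
//   keys: [1, 8, 256, 128], values: [1, 8, 256, 128]
```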

Libraries/MLXLMCommon/UserInput.swift

Lines changed: 4 additions & 0 deletions
```diff
@@ -251,12 +251,14 @@ public struct UserInput: Sendable {
     /// - Parameters:
     ///   - chat: structured content
     ///   - tools: optional tool specifications
+    ///   - processing: optional processing to be applied to media
     ///   - additionalContext: optional context (model specific)
     /// ### See Also
     /// - ``Prompt-swift.enum/text(_:)``
     /// - ``init(chat:tools:additionalContext:)``
     public init(
         chat: [Chat.Message],
+        processing: Processing = .init(),
         tools: [ToolSpec]? = nil,
         additionalContext: [String: Any]? = nil
     ) {
@@ -269,6 +271,8 @@ public struct UserInput: Sendable {
         self.videos = chat.reduce(into: []) { result, message in
             result.append(contentsOf: message.videos)
         }
+
+        self.processing = processing
         self.tools = tools
         self.additionalContext = additionalContext
     }
```
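Threading `processing` through the chat initializer means media attached to chat messages can be transformed before the processor sees them. A small sketch, assuming `UserInput.Processing` exposes the `resize` option and using a hypothetical `imageURL`:

```swift
import CoreGraphics

// Attach an image to a chat turn and downscale it during preparation.
// `imageURL` is a placeholder for a real file or remote URL; `resize`
// is assumed to be a field of UserInput.Processing.
let input = UserInput(
    chat: [
        .system("You are a helpful assistant."),
        .user("Describe this image.", images: [.url(imageURL)], videos: []),
    ],
    processing: .init(resize: CGSize(width: 448, height: 448))
)
```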

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 2 additions & 2 deletions
```diff
@@ -103,15 +103,15 @@ private enum Language {
         values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)

         let offset = cache?.offset ?? 0
-        let mask = mask?[0..., 0 ..< keys.dim(-2)]
-
         queries = rotaryEmbedding(queries, offset: offset)
         keys = rotaryEmbedding(keys, offset: offset)

         if let cache {
             (keys, values) = cache.update(keys: keys, values: values)
         }

+        let mask = mask?[.ellipsis, 0 ..< keys.dim(-2)]
+
         let output = MLXFast.scaledDotProductAttention(
             queries: queries, keys: keys, values: values, scale: scale, mask: mask
         )
```
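The reordering appears deliberate: once `cache.update(keys:values:)` has concatenated the cached entries, `keys.dim(-2)` reflects the full key length including the cache, so slicing the mask before the update would truncate it to only the new tokens. Switching from `0...` to `.ellipsis` also keeps the slice valid for masks with additional leading dimensions.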

Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 1 addition & 1 deletion
```diff
@@ -208,7 +208,7 @@ public class VLMModelFactory: ModelFactory {
         let modelDirectory = try await downloadModel(
             hub: hub, configuration: configuration, progressHandler: progressHandler)

-        // load the generic config to unerstand which model and how to load the weights
+        // load the generic config to understand which model and how to load the weights
         let configurationURL = modelDirectory.appending(
             component: "config.json"
         )
```

Tools/llm-tool/Chat.swift

Lines changed: 222 additions & 0 deletions
The entire file is new:

```swift
// Copyright © 2025 Apple Inc.

import ArgumentParser
import Foundation
import MLX
import MLXLLM
import MLXLMCommon
import MLXVLM

struct ChatCommand: AsyncParsableCommand {
    static let configuration = CommandConfiguration(
        commandName: "chat",
        abstract: "interactive chat with model"
    )

    @OptionGroup var args: ModelArguments
    @OptionGroup var memory: MemoryArguments
    @OptionGroup var generate: GenerateArguments
    @OptionGroup var media: MediaArguments

    struct State {
        var parameters: GenerateParameters
        var processing: UserInput.Processing

        var images: [UserInput.Image]
        var videos: [UserInput.Video]

        var chat: [Chat.Message]

        var cache: [KVCache]

        var printStats = false
    }

    @MainActor
    mutating func run() async throws {
        let defaultModel = MLXLLM.LLMRegistry.mistral7B4bit

        // Load the model
        let modelContainer = try await memory.start { [args] in
            do {
                return try await args.load(
                    defaultModel: defaultModel.name, modelFactory: LLMModelFactory.shared)
            } catch ModelFactoryError.unsupportedModelType {
                return try await args.load(
                    defaultModel: defaultModel.name, modelFactory: VLMModelFactory.shared)
            }
        }

        // update the context/configuration with any command line parameters
        await modelContainer.update { [generate] context in
            generate.prepare(&context)
        }

        try await chat(modelContainer: modelContainer)
    }

    func chat(modelContainer: ModelContainer) async throws {
        try await modelContainer.perform { context in
            let parameters = generate.generateParameters
            let initialState = State(
                parameters: parameters,
                processing: media.processing,
                images: media.images, videos: media.videos,
                chat: [.system(generate.system)],
                cache: context.model.newCache(parameters: parameters))

            var state = initialState

            print("> ", terminator: "")
            while let line = readLine() {
                if line.hasPrefix("/") {
                    // handle commands
                    switch command(line: line, state: &state) {
                    case .exit:
                        return
                    case .reset:
                        state = initialState
                        state.cache = context.model.newCache(parameters: parameters)
                        continue
                    case .inference:
                        // continue and run inference
                        break
                    case .handled:
                        print("\n\n> ", terminator: "")
                        continue
                    }
                } else {
                    // chat input
                    state.chat.append(.user(line, images: state.images, videos: state.videos))
                }

                // consume the media, if any
                state.images.removeAll()
                state.videos.removeAll()

                // convert UserInput to LMInput
                let userInput = UserInput(chat: state.chat, processing: state.processing)
                let input = try await context.processor.prepare(input: userInput)

                // generate the output
                var output = ""
                var result: GenerateCompletionInfo?
                for await item in try MLXLMCommon.generate(
                    input: input, cache: state.cache, parameters: parameters, context: context
                ) {
                    switch item {
                    case .chunk(let string):
                        output += string
                        print(string, terminator: "")
                    case .info(let info):
                        result = info
                    }
                }

                // add the assistant response to the chat messages
                state.chat.append(.assistant(output))

                if state.printStats, let result {
                    print(
                        "\ntime to first token: \(result.promptTime.formatted()) tps: \(result.tokensPerSecond.formatted())"
                    )
                }
                print("\n\n> ", terminator: "")
            }
        }
    }

    enum CommandDisposition {
        case exit
        case reset
        case inference
        case handled
    }

    func help() {
        print(
            """
            /help -- this message
            /quit -- terminate the chat
            /memory -- print memory stats
            /stats -- toggle token stats
            /reset -- reset the chat session to initial state
            /image [pathOrURL] -- provide an image
            /video [pathOrURL] -- provide a video
            /again -- rerun inference for last response
            /parameters -- print generation parameters
            /temperature [number] -- set the sampling temperature
            /topP [number] -- set the top p sampling
            /maxTokens [number] -- set the maximum number of tokens to generate or no number to remove limit
            """)
    }

    func command(line: String, state: inout State) -> CommandDisposition {
        let command = line.split(separator: " ")[0]
        let rest = String(
            line.dropFirst(command.count).trimmingCharacters(in: .whitespaces))

        func url(_ string: String) -> URL? {
            if string.hasPrefix("/") {
                URL(filePath: string)
            } else {
                URL(string: string)
            }
        }

        switch command {
        case "/help":
            help()

        case "/quit":
            return .exit

        case "/memory":
            let memory = GPU.snapshot()
            print("Memory size: \(GPU.memoryLimit / 1024)K")
            print("Cache size: \(GPU.cacheLimit / 1024)K")
            print(memory.description)

        case "/stats":
            state.printStats.toggle()
            print("Token stats: \(state.printStats ? "ON" : "OFF")")

        case "/reset":
            return .reset

        case "/image":
            if let url = url(rest) {
                state.images.append(UserInput.Image.url(url))
            }
        case "/video":
            if let url = url(rest) {
                state.videos.append(UserInput.Video.url(url))
            }

        case "/again":
            state.chat.removeLast()
            return .inference

        case "/parameters":
            print(state.parameters)
        case "/temperature":
            if let value = Float(rest) {
                state.parameters.temperature = value
                print(state.parameters)
            }
        case "/topP":
            if let value = Float(rest) {
                state.parameters.topP = value
                print(state.parameters)
            }
        case "/maxTokens":
            state.parameters.maxTokens = Int(rest)
            print(state.parameters)

        default:
            help()
        }

        return .handled
    }
}
```
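The tool reads a line at a time: `/`-prefixed input is dispatched through `command(line:state:)`, anything else becomes a `.user` chat message, and every inference turn reuses `state.cache` via the new `generate(input:cache:parameters:context:)` overload. Note that `/reset` restores the initial state but also asks the model for a fresh cache, since the original `KVCache` instances are reference types that were mutated in place during earlier turns. Invocation should look something like `./mlx-run llm-tool chat --model <model-id>`, where the `mlx-run` helper and the model option come from the existing llm-tool setup.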
