|
| 1 | +// Copyright © 2025 Apple Inc. |
| 2 | + |
| 3 | +import ArgumentParser |
| 4 | +import Foundation |
| 5 | +import MLX |
| 6 | +import MLXLLM |
| 7 | +import MLXLMCommon |
| 8 | +import MLXVLM |
| 9 | + |
/// Interactive REPL-style chat with an MLX language (or vision-language) model.
struct ChatCommand: AsyncParsableCommand {
    static let configuration = CommandConfiguration(
        commandName: "chat",
        abstract: "interactive chat with model"
    )

    @OptionGroup var args: ModelArguments
    @OptionGroup var memory: MemoryArguments
    @OptionGroup var generate: GenerateArguments
    @OptionGroup var media: MediaArguments

    /// Mutable per-session state for the chat loop.
    struct State {
        // Sampling parameters; mutable at runtime via /temperature, /topP, /maxTokens.
        var parameters: GenerateParameters
        var processing: UserInput.Processing

        // Media staged for the next user message; consumed after each turn.
        var images: [UserInput.Image]
        var videos: [UserInput.Video]

        // Conversation transcript, beginning with the system prompt.
        var chat: [Chat.Message]

        // KV cache carried across turns.
        var cache: [KVCache]

        // Toggled by /stats; when true, prints timing info after each response.
        var printStats = false
    }

    @MainActor
    mutating func run() async throws {
        let defaultModel = MLXLLM.LLMRegistry.mistral7B4bit

        // Load the model -- try the LLM factory first and fall back to the
        // VLM factory when the model type is not an LLM.
        let modelContainer = try await memory.start { [args] in
            do {
                return try await args.load(
                    defaultModel: defaultModel.name, modelFactory: LLMModelFactory.shared)
            } catch ModelFactoryError.unsupportedModelType {
                return try await args.load(
                    defaultModel: defaultModel.name, modelFactory: VLMModelFactory.shared)
            }
        }

        // update the context/configuration with any command line parameters
        await modelContainer.update { [generate] context in
            generate.prepare(&context)
        }

        try await chat(modelContainer: modelContainer)
    }

    /// Runs the interactive loop: reads lines, dispatches "/" commands,
    /// otherwise appends user input and streams a model response. Returns
    /// on EOF or /quit.
    func chat(modelContainer: ModelContainer) async throws {
        try await modelContainer.perform { context in
            let parameters = generate.generateParameters
            let initialState = State(
                parameters: parameters,
                processing: media.processing,
                images: media.images, videos: media.videos,
                chat: [.system(generate.system)],
                cache: context.model.newCache(parameters: parameters))

            var state = initialState

            print("> ", terminator: "")
            while let line = readLine() {
                if line.hasPrefix("/") {
                    // handle commands
                    switch command(line: line, state: &state) {
                    case .exit:
                        return
                    case .reset:
                        state = initialState
                        state.cache = context.model.newCache(parameters: parameters)
                        continue
                    case .inference:
                        // continue and run inference
                        break
                    case .handled:
                        print("\n\n> ", terminator: "")
                        continue
                    }
                } else {
                    // chat input
                    state.chat.append(.user(line, images: state.images, videos: state.videos))
                }

                // consume the media, if any
                state.images.removeAll()
                state.videos.removeAll()

                // convert UserInput to LMInput
                let userInput = UserInput(chat: state.chat, processing: state.processing)
                let input = try await context.processor.prepare(input: userInput)

                // generate the output -- pass state.parameters (not the
                // captured `parameters`) so runtime changes made via
                // /temperature, /topP and /maxTokens actually take effect
                var output = ""
                var result: GenerateCompletionInfo?
                for await item in try MLXLMCommon.generate(
                    input: input, cache: state.cache, parameters: state.parameters,
                    context: context
                ) {
                    switch item {
                    case .chunk(let string):
                        output += string
                        print(string, terminator: "")
                    case .info(let info):
                        result = info
                    }
                }

                // add the assistant response to the chat messages
                state.chat.append(.assistant(output))

                if state.printStats, let result {
                    print(
                        "\ntime to first token: \(result.promptTime.formatted()) tps: \(result.tokensPerSecond.formatted())"
                    )
                }
                print("\n\n> ", terminator: "")
            }
        }
    }

    /// What the main loop should do after a "/" command is processed.
    enum CommandDisposition {
        case exit
        case reset
        case inference
        case handled
    }

    /// Prints the list of supported "/" commands.
    func help() {
        print(
            """
            /help -- this message
            /quit -- terminate the chat
            /memory -- print memory stats
            /stats -- toggle token stats
            /reset -- reset the chat session to initial state
            /image [pathOrURL] -- provide an image
            /video [pathOrURL] -- provide a video
            /again -- rerun inference for last response
            /parameters -- print generation parameters
            /temperature [number] -- set the sampling temperature
            /topP [number] -- set the top p sampling
            /maxTokens [number] -- set the maximum number of tokens to generate or no number to remove limit
            """)
    }

    /// Parses and executes a "/" command, mutating `state` as needed.
    ///
    /// - Parameters:
    ///   - line: the full input line, starting with "/"
    ///   - state: session state updated in place
    /// - Returns: a disposition telling the main loop how to proceed
    func command(line: String, state: inout State) -> CommandDisposition {
        let command = line.split(separator: " ")[0]
        let rest = String(
            line.dropFirst(command.count).trimmingCharacters(in: .whitespaces))

        // a leading "/" is treated as a local file path, anything else as a URL
        func url(_ string: String) -> URL? {
            if string.hasPrefix("/") {
                URL(filePath: string)
            } else {
                URL(string: string)
            }
        }

        switch command {
        case "/help":
            help()

        case "/quit":
            return .exit

        case "/memory":
            let memory = GPU.snapshot()
            print("Memory size: \(GPU.memoryLimit / 1024)K")
            print("Cache size: \(GPU.cacheLimit / 1024)K")
            print(memory.description)

        case "/stats":
            state.printStats.toggle()
            print("Token stats: \(state.printStats ? "ON" : "OFF")")

        case "/reset":
            return .reset

        case "/image":
            if let url = url(rest) {
                state.images.append(UserInput.Image.url(url))
            }
        case "/video":
            if let url = url(rest) {
                state.videos.append(UserInput.Video.url(url))
            }

        case "/again":
            // only rerun when there is an assistant response to discard --
            // otherwise removeLast() would strip the system prompt
            if state.chat.count > 1 {
                state.chat.removeLast()
                return .inference
            }
            print("nothing to rerun")

        case "/parameters":
            print(state.parameters)
        case "/temperature":
            if let value = Float(rest) {
                state.parameters.temperature = value
                print(state.parameters)
            }
        case "/topP":
            if let value = Float(rest) {
                state.parameters.topP = value
                print(state.parameters)
            }
        case "/maxTokens":
            state.parameters.maxTokens = Int(rest)
            print(state.parameters)

        default:
            help()
        }

        return .handled
    }
}
0 commit comments