
Commit 6ef303b

davidkoski and awni authored
add VLM support, refactor common LM code into MLXLMCommon. breaking API changes (#151)
* implement VLM - based on models from https://github.com/Blaizzy/mlx-vlm

There are two new libraries:

- `MLXVLM` contains vision language models that combine images and text prompts to produce text results, e.g. `describe this image`
- `MLXLMCommon` contains the `LanguageModel` code that is shared between `MLXLLM` and `MLXVLM`

The API between `LLM` and `VLM` is identical aside from the preparation of the `UserInput`.

```swift
let parameters = GenerateParameters()

// LLM prompt
let input = UserInput(prompt: "tell me a story")

// VLM prompt
let input = UserInput(prompt: "describe the image", images: [.url(url)])

// inference is identical
let result = try await modelContainer.perform { [generate, input] context in
    let input = try await context.processor.prepare(input: input)
    return try generate(input: input, parameters: parameters, context: context) { token in
        // print tokens as they are generated, stop early, etc.
        return .more
    }
}
```

VLM example code is available in the `llm-tool` example:

```
./mlx-run llm-tool eval --help
OVERVIEW: evaluate prompt and images to generate text (VLM)

USAGE: llm-tool eval <options>

OPTIONS:
  --model <model>         Name of the huggingface model or absolute path to directory
  -p, --prompt <prompt>   The message to be processed by the model. Use @path,@path
                          to load from files, e.g. @/tmp/prompt.txt
  --resize <resize>       Resize images to this size (width, height)
  --image <image>         Paths or urls for input images
  ...
```

These changes probably have no effect on code external to this repo:

- the mlx-swift-examples.xcodeproj now references the local `Package.swift` to build the libraries
- the example code now uses naming that matches external uses of mlx-swift-examples, e.g. `import LLM` -> `import MLXLLM`
- the library directories are renamed to match their target names, e.g. `LLM` -> `MLXLLM`

Breaking:

- some code will now need to import both `MLXLLM` and `MLXLMCommon` (particularly code that loads models); `MLXLMCommon` contains the common API between LLM and VLM

```swift
import MLXLLM
import MLXLMCommon
```

- constants for models have moved from `ModelConfiguration` to `ModelRegistry`; this is `MLXLLM.ModelRegistry` and there is also `MLXVLM.ModelRegistry`

```diff
- let modelConfiguration = ModelConfiguration.phi3_5_4bit
+ let modelConfiguration = ModelRegistry.phi3_5_4bit
```

- the `loadModelContainer()` function is now `LLMModelFactory.shared.loadContainer()`
- there is a new `VLMModelFactory` with identical methods for loading VLMs (a sketch of the VLM loading path follows this list)

```diff
- let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
- {
+ let modelContainer = try await LLMModelFactory.shared.loadContainer(
+     configuration: modelConfiguration
+ ) {
```

- `ModelContainer.perform` is now throwing (and in MLXLMCommon):

```diff
- let result = await modelContainer.perform { model, tokenizer in
-     LLM.generate(
+ let result = try await modelContainer.perform { model, tokenizer in
+     try MLXLMCommon.generate(
```

- `ModelConfiguration` previously had a way to register new configurations. This is now on `LLMModelFactory` (and `VLMModelFactory` has the same):

```swift
LLMModelFactory.shared.modelRegistry.register(configurations: [modelConfiguration])
```

An example at the end shows all of these deprecations in context.
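The message above shows how an LLM container is loaded but not the VLM equivalent. Here is a minimal sketch of the VLM path, assuming `VLMModelFactory.shared.loadContainer()` mirrors the LLM factory as described; the model id is an illustrative placeholder, not a constant introduced by this commit (in practice you would use an entry from `MLXVLM.ModelRegistry`):

```swift
import Foundation
import MLXLMCommon
import MLXVLM

// Hypothetical configuration -- substitute a registry entry or a vision-language
// model id that the VLM factory actually supports.
let configuration = ModelConfiguration(id: "mlx-community/Qwen2-VL-2B-Instruct-4bit")

// Load the container via the VLM factory (same shape as LLMModelFactory).
let modelContainer = try await VLMModelFactory.shared.loadContainer(
    configuration: configuration
) { progress in
    print("Downloading: \(Int(progress.fractionCompleted * 100))%")
}

let parameters = GenerateParameters()
let result = try await modelContainer.perform { context in
    // UserInput carries both the prompt and the image, as in the example above.
    let userInput = UserInput(
        prompt: "describe the image",
        images: [.url(URL(fileURLWithPath: "/tmp/test.jpg"))])
    let input = try await context.processor.prepare(input: userInput)
    return try MLXLMCommon.generate(
        input: input, parameters: parameters, context: context
    ) { tokens in
        .more
    }
}
print(result.output)
```

Aside from the factory and the image attached to `UserInput`, nothing here differs from the LLM flow.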
**Prefer to use the `ModelContext.processor` to prepare prompts.** Previously users would pass in a bare `[Int]` of tokens, but in order to support more complex inputs (VLMs) the use of bare `[Int]` is deprecated and callers should use `UserInput` and `LMInput`.

For example, previously callers might have done something like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}
```

Now that should be:

```swift
let input = try await context.processor.prepare(input: .init(prompt: prompt))
```

This initializes a `UserInput` from the prompt text and produces an `LMInput` that can be used to generate tokens.

**This call to `generate()` is now deprecated:**

```swift
public func generate(
    promptTokens: [Int], parameters: GenerateParameters, model: any LanguageModel,
    tokenizer: Tokenizer, extraEOSTokens: Set<String>? = nil,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

This consumed the `[Int]` variety of tokens. Now this is preferred:

```swift
public func generate(
    input: LMInput, parameters: GenerateParameters, context: ModelContext,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

**This method on `ModelContainer` is now deprecated:**

```swift
/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
/// `MLXArray` is not `Sendable`.
@available(*, deprecated, message: "prefer perform(_:) that uses a ModelContext")
public func perform<R>(_ action: @Sendable (any LanguageModel, Tokenizer) throws -> R) rethrows -> R
```

Use this one instead (though the former still works):

```swift
/// Perform an action on the ``ModelContext``. Callers _must_ eval any `MLXArray` before returning as
/// `MLXArray` is not `Sendable`.
public func perform<R>(_ action: @Sendable (ModelContext) async throws -> R) async rethrows -> R
```

Putting all of these deprecations together, previously you might have generated text like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}
let result = await modelContainer.perform { model, tokenizer in
    LLM.generate(
        promptTokens: promptTokens, parameters: generateParameters, model: model,
        tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
    ) { tokens in ... }
}
```

Now do this:

```swift
let result = try await modelContainer.perform { context in
    let input = try await context.processor.prepare(input: .init(prompt: prompt))
    return try MLXLMCommon.generate(
        input: input, parameters: generateParameters, context: context
    ) { tokens in ... }
}
```

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
1 parent 318044f commit 6ef303b


65 files changed (+5152, -2628 lines)

.circleci/config.yml
Lines changed: 4 additions & 3 deletions

```diff
@@ -15,7 +15,7 @@ jobs:
 
   mac_build_and_test:
     macos:
-      xcode: 15.2.0
+      xcode: 16.0.0
     resource_class: macos.m1.medium.gen1
     steps:
      - checkout
@@ -35,8 +35,9 @@ jobs:
           xcrun --show-sdk-build-version
           swift --version
           find . -name Package.resolved -exec rm {} \;
-          xcodebuild -skipPackagePluginValidation -scheme llm-tool
-          xcodebuild -skipPackagePluginValidation -scheme mnist-tool
+          xcodebuild -scheme llm-tool
+          xcodebuild -scheme image-tool
+          xcodebuild -scheme mnist-tool
 
 workflows:
   build_and_test:
```

Applications/LLMEval/ContentView.swift
Lines changed: 13 additions & 17 deletions

```diff
@@ -1,7 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
-import LLM
 import MLX
+import MLXLLM
+import MLXLMCommon
 import MLXRandom
 import MarkdownUI
 import Metal
@@ -159,7 +160,7 @@ class LLMEvaluator {
 
     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
@@ -185,17 +186,17 @@ class LLMEvaluator {
         // limit the buffer cache
         MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
 
-        let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-        {
+        let modelContainer = try await LLMModelFactory.shared.loadContainer(
+            configuration: modelConfiguration
+        ) {
            [modelConfiguration] progress in
            Task { @MainActor in
                self.modelInfo =
                    "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
            }
        }
-        let numParams = await modelContainer.perform {
-            [] model, _ in
-            return model.numParameters()
+        let numParams = await modelContainer.perform { context in
+            context.model.numParameters()
        }
 
        self.modelInfo =
@@ -217,22 +218,17 @@ class LLMEvaluator {
        do {
            let modelContainer = try await load()
 
-            let messages = [["role": "user", "content": prompt]]
-            let promptTokens = try await modelContainer.perform { _, tokenizer in
-                try tokenizer.applyChatTemplate(messages: messages)
-            }
-
            // each time you generate you will get something new
            MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            let result = await modelContainer.perform { model, tokenizer in
-                LLM.generate(
-                    promptTokens: promptTokens, parameters: generateParameters, model: model,
-                    tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
+            let result = try await modelContainer.perform { context in
+                let input = try await context.processor.prepare(input: .init(prompt: prompt))
+                return try MLXLMCommon.generate(
+                    input: input, parameters: generateParameters, context: context
                ) { tokens in
                    // update the output -- this will make the view show the text as it generates
                    if tokens.count % displayEveryNTokens == 0 {
-                        let text = tokenizer.decode(tokens: tokens)
+                        let text = context.tokenizer.decode(tokens: tokens)
                        Task { @MainActor in
                            self.output = text
                        }
```

Applications/LLMEval/README.md
Lines changed: 1 addition & 1 deletion

````diff
@@ -30,7 +30,7 @@ The example application uses Phi2 model by default, see [ContentView.swift](Cont
 let modelConfiguration = ModelConfiguration.phi4bit
 ```
 
-There are some pre-configured models in [LLM/Models.swift](../../Libraries/LLM/Models.swift#L62)
+There are some pre-configured models in [MLXLLM/LLMModelFactory.swift](../../Libraries/MLXLLM/LLMModelFactory.swift#L78)
 and you can load any weights from Hugging Face where there
 is a model architecture defined and you have enough
 memory.
````

Applications/LLMEval/ViewModels/DeviceStat.swift
Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 import Foundation
-import LLM
 import MLX
+import MLXLLM
 
 @Observable
 final class DeviceStat: @unchecked Sendable {
```

Applications/LoRATrainingExample/ContentView.swift
Lines changed: 27 additions & 30 deletions

```diff
@@ -1,7 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
-import LLM
 import MLX
+import MLXLLM
+import MLXLMCommon
 import MLXNN
 import MLXOptimizers
 import MLXRandom
@@ -122,7 +123,7 @@ class LoRAEvaluator {
 
     var output = ""
 
-    private let modelConfiguration = ModelConfiguration.mistral7B4bit
+    private let modelConfiguration = ModelRegistry.mistral7B4bit
     private var model: ModelState = .idle
 
     private let loraLayers = 4
@@ -141,8 +142,9 @@ class LoRAEvaluator {
            progress = .init(title: "Loading \(name)", current: 0, limit: 1)
        }
 
-        let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-        {
+        let modelContainer = try await LLMModelFactory.shared.loadContainer(
+            configuration: modelConfiguration
+        ) {
            progress in
            Task { @MainActor in
                self.progress = .init(
@@ -160,7 +162,7 @@ class LoRAEvaluator {
 
    private func loadLoRAData(name: String) throws -> [String]? {
        if let url = Bundle.main.url(forResource: name, withExtension: "jsonl") {
-            return try LLM.loadLoRAData(url: url)
+            return try MLXLLM.loadLoRAData(url: url)
        }
        return nil
    }
@@ -196,9 +198,9 @@ class LoRAEvaluator {
        let modelContainer = try await loadModel()
 
        // apply LoRA adapters and train
-        await modelContainer.perform { model, _ in
+        await modelContainer.perform { context in
            LoRATrain.convert(
-                model: model, layers: loraLayers(model: model))
+                model: context.model, layers: loraLayers(model: context.model))
        }
 
        let train = try loadLoRAData(name: "train")
@@ -208,11 +210,11 @@ class LoRAEvaluator {
            return
        }
 
-        try await modelContainer.perform { model, tokenizer in
+        try await modelContainer.perform { context in
            let optimizer = Adam(learningRate: learningRate)
            try LoRATrain.train(
-                model: model, train: train, validate: valid, optimizer: optimizer,
-                tokenizer: tokenizer,
+                model: context.model, train: train, validate: valid, optimizer: optimizer,
+                tokenizer: context.tokenizer,
                parameters: parameters
            ) { progress in
                Task { @MainActor in
@@ -240,9 +242,10 @@ class LoRAEvaluator {
            return
        }
 
-        let loss = await modelContainer.perform { model, tokenizer in
+        let loss = await modelContainer.perform { context in
            LoRATrain.evaluate(
-                model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)
+                model: context.model, dataset: test,
+                tokenizer: context.tokenizer, batchSize: 1, batchCount: 0)
        }
 
        self.progress = nil
@@ -269,26 +272,20 @@ class LoRAEvaluator {
 
        let modelContainer = try await loadModel()
 
-        let messages = [["role": "user", "content": prompt]]
-        let promptTokens = try await modelContainer.perform { _, tokenizer in
-            try tokenizer.applyChatTemplate(messages: messages)
-        }
-
        // evaluate
-        let result = await modelContainer.perform { model, tokenizer in
-            LLM.generate(
-                promptTokens: promptTokens, parameters: generateParameters, model: model,
-                tokenizer: tokenizer,
-                extraEOSTokens: modelConfiguration.extraEOSTokens,
-                didGenerate: { tokens in
-                    if tokens.count % evaluateShowEvery == 0 {
-                        let fullOutput = tokenizer.decode(tokens: tokens)
-                        Task { @MainActor in
-                            self.output = fullOutput
+        let result = try await modelContainer.perform { context in
+            let input = try await context.processor.prepare(input: .init(prompt: prompt))
+            return try MLXLMCommon.generate(
+                input: input, parameters: generateParameters, context: context
+            ) { tokens in
+                if tokens.count % evaluateShowEvery == 0 {
+                    let fullOutput = context.tokenizer.decode(tokens: tokens)
+                    Task { @MainActor in
+                        self.output = fullOutput
                    }
-                    return tokens.count >= maxTokens ? .stop : .more
-                })
+                }
+                return tokens.count >= maxTokens ? .stop : .more
+            }
        }
 
        self.output = result.output
```

Applications/MNISTTrainer/ContentView.swift
Lines changed: 1 addition & 1 deletion

```diff
@@ -1,10 +1,10 @@
 // Copyright © 2024 Apple Inc.
 
 import MLX
+import MLXMNIST
 import MLXNN
 import MLXOptimizers
 import MLXRandom
-import MNIST
 import SwiftUI
 
 struct TrainingView: View {
```

Applications/MNISTTrainer/PredictionView.swift
Lines changed: 1 addition & 1 deletion

```diff
@@ -6,8 +6,8 @@
 //
 
 import MLX
+import MLXMNIST
 import MLXNN
-import MNIST
 import SwiftUI
 
 struct Canvas: View {
```
