
Commit 8e41311

support for running all the models (#317)
* support for running all the models
  - add a --download argument
  - add a `list` command
  - support scripts to write a script to run all the models
* updates for models / templates that do not accept system role
1 parent 4831d42 commit 8e41311

12 files changed, +298 -7 lines

Libraries/MLXLLM/LLMModel.swift

Lines changed: 10 additions & 0 deletions
@@ -2,9 +2,15 @@

 import MLX
 import MLXLMCommon
+import Tokenizers

 /// Marker protocol for LLMModels
 public protocol LLMModel: LanguageModel, LoRAModel {
+
+    /// Models can implement this if they need a custom `MessageGenerator`.
+    ///
+    /// The default implementation returns `DefaultMessageGenerator`.
+    func messageGenerator(tokenizer: Tokenizer) -> MessageGenerator
 }

 extension LLMModel {
@@ -30,4 +36,8 @@ extension LLMModel {

         return .tokens(y)
     }
+
+    public func messageGenerator(tokenizer: Tokenizer) -> MessageGenerator {
+        DefaultMessageGenerator()
+    }
 }
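
Because the extension supplies a default that returns `DefaultMessageGenerator`, the new requirement is source-compatible: existing `LLMModel` conformers compile unchanged and keep the old behavior, while models whose chat templates reject certain roles (Gemma, Gemma2, and Llama below) override the hook to substitute their own generator.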

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 8 additions & 1 deletion
@@ -331,11 +331,18 @@ public class LLMModelFactory: ModelFactory {

         let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)

+        let messageGenerator =
+            if let model = model as? LLMModel {
+                model.messageGenerator(tokenizer: tokenizer)
+            } else {
+                DefaultMessageGenerator()
+            }
+
         return .init(
             configuration: configuration, model: model,
             processor: LLMUserInputProcessor(
                 tokenizer: tokenizer, configuration: configuration,
-                messageGenerator: DefaultMessageGenerator()),
+                messageGenerator: messageGenerator),
             tokenizer: tokenizer)
     }

Libraries/MLXLLM/Models/Gemma.swift

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,7 @@ import Foundation
 import MLX
 import MLXLMCommon
 import MLXNN
+import Tokenizers

 // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py

@@ -187,6 +188,10 @@ public class GemmaModel: Module, LLMModel, KVCacheDimensionProvider {
         let out = model(inputs, cache: cache)
         return model.embedTokens.asLinear(out)
     }
+
+    public func messageGenerator(tokenizer: any Tokenizer) -> any MessageGenerator {
+        NoSystemMessageGenerator()
+    }
 }

 public struct GemmaConfiguration: Codable, Sendable {

Libraries/MLXLLM/Models/Gemma2.swift

Lines changed: 6 additions & 1 deletion
@@ -4,6 +4,7 @@ import Foundation
 import MLX
 import MLXLMCommon
 import MLXNN
+import Tokenizers

 // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py

@@ -212,6 +213,10 @@ public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider {
         out = tanh(out / logitSoftCap) * logitSoftCap
         return out
     }
+
+    public func messageGenerator(tokenizer: any Tokenizer) -> any MessageGenerator {
+        NoSystemMessageGenerator()
+    }
 }

 public struct Gemma2Configuration: Codable {
@@ -245,7 +250,7 @@ public struct Gemma2Configuration: Codable {
         case queryPreAttnScalar = "query_pre_attn_scalar"
     }

-    public init(from decoder: Decoder) throws {
+    public init(from decoder: Swift.Decoder) throws {
         // Custom implementation to handle optional keys with required values
         let container: KeyedDecodingContainer<CodingKeys> = try decoder.container(
             keyedBy: CodingKeys.self)

Libraries/MLXLLM/Models/Llama.swift

Lines changed: 19 additions & 1 deletion
@@ -4,6 +4,7 @@ import Foundation
 import MLX
 import MLXLMCommon
 import MLXNN
+import Tokenizers

 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py

@@ -316,6 +317,23 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {
             !$0.key.contains("self_attn.rotary_emb.inv_freq")
         }
     }
+
+    public func messageGenerator(tokenizer: any Tokenizer) -> any MessageGenerator {
+        // some models allow the system role and some do not -- this is enforced
+        // by the chat template (code).
+        do {
+            let probe = [
+                [
+                    "role": "system",
+                    "content": "test",
+                ]
+            ]
+            _ = try tokenizer.applyChatTemplate(messages: probe)
+            return DefaultMessageGenerator()
+        } catch {
+            return NoSystemMessageGenerator()
+        }
+    }
 }

 public struct LlamaConfiguration: Codable, Sendable {
@@ -382,7 +400,7 @@ public struct LlamaConfiguration: Codable, Sendable {
         case mlpBias = "mlp_bias"
     }

-    public init(from decoder: Decoder) throws {
+    public init(from decoder: Swift.Decoder) throws {
         let container = try decoder.container(keyedBy: CodingKeys.self)

         hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
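
Unlike Gemma, which hardcodes `NoSystemMessageGenerator`, the Llama class covers many community fine-tunes with differing chat templates, so it probes the template once at load time: if rendering a system message throws, the system role is unsupported. A sketch of the same probe factored into a free function that other model families could reuse (the function name is hypothetical, not part of this commit):

import MLXLMCommon
import Tokenizers

// Hypothetical helper (illustrative only): render a one-message system chat;
// templates that reject the system role throw, so fall back to the
// system-filtering generator in that case.
func probeMessageGenerator(_ tokenizer: Tokenizer) -> MessageGenerator {
    let probe = [["role": "system", "content": "test"]]
    if (try? tokenizer.applyChatTemplate(messages: probe)) != nil {
        return DefaultMessageGenerator()
    } else {
        return NoSystemMessageGenerator()
    }
}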

Libraries/MLXLMCommon/Chat.swift

Lines changed: 33 additions & 2 deletions
@@ -63,12 +63,25 @@ public enum Chat {
 /// ```
 public protocol MessageGenerator {

+    /// Generates messages from the input.
+    func generate(from input: UserInput) -> [Message]
+
+    /// Returns array of `[String: Any]` aka ``Message``
+    func generate(messages: [Chat.Message]) -> [Message]
+
     /// Returns `[String: Any]` aka ``Message``.
     func generate(message: Chat.Message) -> Message
 }

 extension MessageGenerator {
-    /// Returns array of `[String: Any]` aka ``Message``
+
+    public func generate(message: Chat.Message) -> Message {
+        [
+            "role": message.role.rawValue,
+            "content": message.content,
+        ]
+    }
+
     public func generate(messages: [Chat.Message]) -> [Message] {
         var rawMessages: [Message] = []

@@ -80,7 +93,6 @@ extension MessageGenerator {
         return rawMessages
     }

-    /// Generates messages from the input.
     public func generate(from input: UserInput) -> [Message] {
         switch input.prompt {
         case .text(let text):
@@ -112,3 +124,22 @@ public struct DefaultMessageGenerator: MessageGenerator {
         ]
     }
 }
+
+/// Implementation of ``MessageGenerator`` that produces a
+/// `role` and `content` but omits `system` roles.
+///
+/// ```swift
+/// [
+///     "role": message.role.rawValue,
+///     "content": message.content,
+/// ]
+/// ```
+public struct NoSystemMessageGenerator: MessageGenerator {
+    public init() {}
+
+    public func generate(messages: [Chat.Message]) -> [Message] {
+        messages
+            .filter { $0.role != .system }
+            .map { generate(message: $0) }
+    }
+}
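
A quick check of the filtering behavior. This sketch assumes the `.system`/`.user` conveniences on `Chat.Message`; construct the messages however your call site does:

import MLXLMCommon

let chat: [Chat.Message] = [
    .system("You are a helpful assistant."),
    .user("Hello!"),
]
let raw = NoSystemMessageGenerator().generate(messages: chat)
// raw == [["role": "user", "content": "Hello!"]] -- the system entry is dropped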

Tools/llm-tool/LLMTool.swift

Lines changed: 20 additions & 2 deletions
@@ -14,7 +14,10 @@ import Tokenizers
 struct LLMTool: AsyncParsableCommand {
     static let configuration = CommandConfiguration(
         abstract: "Command line tool for generating text and manipulating LLMs",
-        subcommands: [EvaluateCommand.self, ChatCommand.self, LoRACommand.self],
+        subcommands: [
+            EvaluateCommand.self, ChatCommand.self, LoRACommand.self,
+            ListCommands.self,
+        ],
         defaultSubcommand: EvaluateCommand.self)
 }

@@ -24,6 +27,9 @@ struct ModelArguments: ParsableArguments, Sendable {
     @Option(name: .long, help: "Name of the Hugging Face model or absolute path to directory")
     var model: String?

+    @Option(help: "Hub download directory")
+    var download: URL?
+
     @Sendable
     func load(defaultModel: String, modelFactory: ModelFactory) async throws -> ModelContainer {
         let modelConfiguration: ModelConfiguration
@@ -39,7 +45,15 @@ struct ModelArguments: ParsableArguments, Sendable {
             // identifier
             modelConfiguration = modelFactory.configuration(id: modelName)
         }
-        return try await modelFactory.loadContainer(configuration: modelConfiguration)
+
+        let hub =
+            if let download {
+                HubApi(downloadBase: download)
+            } else {
+                HubApi()
+            }
+
+        return try await modelFactory.loadContainer(hub: hub, configuration: modelConfiguration)
     }
 }

@@ -313,6 +327,10 @@ struct EvaluateCommand: AsyncParsableCommand {
             return try await generate.generate(input: input, context: context)
         }

+        // wait for any asynchronous cleanup, e.g. tearing down compiled functions,
+        // before the task exits -- this would race with mlx::core shutdown
+        try await Task.sleep(for: .milliseconds(30))
+
         if !generate.quiet {
             print("------")
             print(result.summary())
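
With `--download`, repeated runs share one Hub cache directory instead of re-fetching weights into the default location; for example (the model id is a placeholder, any id printed by `list llms` works):

./mlx-run llm-tool eval --download ~/Downloads/huggingface --model <model-id>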

Tools/llm-tool/ListCommands.swift

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+// Copyright © 2025 Apple Inc.
+
+import ArgumentParser
+import Foundation
+import MLXLLM
+import MLXVLM
+
+struct ListCommands: AsyncParsableCommand {
+
+    static let configuration = CommandConfiguration(
+        commandName: "list",
+        abstract: "List registered model configurations",
+        subcommands: [
+            ListLLMCommand.self, ListVLMCommand.self,
+        ]
+    )
+}
+
+struct ListLLMCommand: AsyncParsableCommand {
+
+    static let configuration = CommandConfiguration(
+        commandName: "llms",
+        abstract: "List registered LLM model configurations"
+    )
+
+    func run() async throws {
+        for configuration in LLMRegistry.shared.models {
+            switch configuration.id {
+            case .id(let id): print(id)
+            case .directory: break
+            }
+        }
+    }
+}
+
+struct ListVLMCommand: AsyncParsableCommand {
+
+    static let configuration = CommandConfiguration(
+        commandName: "vlms",
+        abstract: "List registered VLM model configurations"
+    )
+
+    func run() async throws {
+        for configuration in VLMRegistry.shared.models {
+            switch configuration.id {
+            case .id(let id): print(id)
+            case .directory: break
+            }
+        }
+    }
+}
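
Both subcommands print one model id per line and skip directory-based configurations, so the output pipes cleanly into scripts such as the generator below:

./mlx-run llm-tool list llms
./mlx-run llm-tool list vlms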

mlx-swift-examples.xcodeproj/project.pbxproj

Lines changed: 4 additions & 0 deletions
@@ -57,6 +57,7 @@
 		C36BF0082BC5CE56002D4AFE /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C36BF0072BC5CE56002D4AFE /* Assets.xcassets */; };
 		C36BF00C2BC5CE56002D4AFE /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C36BF00B2BC5CE56002D4AFE /* Preview Assets.xcassets */; };
 		C36BF0352BC70F11002D4AFE /* Arguments.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BF0342BC70F11002D4AFE /* Arguments.swift */; };
+		C37133A22DD6524B00D19830 /* ListCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = C37133A12DD6524B00D19830 /* ListCommands.swift */; };
 		C38BA3AA2DB8321600BAFA88 /* Chat.swift in Sources */ = {isa = PBXBuildFile; fileRef = C38BA3A92DB8321600BAFA88 /* Chat.swift */; };
 		C392737D2B606A1D00368D5D /* Tutorial.swift in Sources */ = {isa = PBXBuildFile; fileRef = C392737C2B606A1D00368D5D /* Tutorial.swift */; };
 		C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C397C59B2B62C6D0004B084D /* ArgumentParser */; };
@@ -232,6 +233,7 @@
 		C36BF0092BC5CE56002D4AFE /* StableDiffusionExample.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = StableDiffusionExample.entitlements; sourceTree = "<group>"; };
 		C36BF00B2BC5CE56002D4AFE /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
 		C36BF0342BC70F11002D4AFE /* Arguments.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Arguments.swift; sourceTree = "<group>"; };
+		C37133A12DD6524B00D19830 /* ListCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ListCommands.swift; sourceTree = "<group>"; };
 		C38BA3A92DB8321600BAFA88 /* Chat.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Chat.swift; sourceTree = "<group>"; };
 		C39273742B606A0A00368D5D /* Tutorial */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = Tutorial; sourceTree = BUILT_PRODUCTS_DIR; };
 		C392737C2B606A1D00368D5D /* Tutorial.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Tutorial.swift; sourceTree = "<group>"; };
@@ -481,6 +483,7 @@
 				C36BEFB32BBDEA69002D4AFE /* LoraCommands.swift */,
 				C36BEFB62BBDECBC002D4AFE /* Arguments.swift */,
 				C38BA3A92DB8321600BAFA88 /* Chat.swift */,
+				C37133A12DD6524B00D19830 /* ListCommands.swift */,
 			);
 			path = "llm-tool";
 			sourceTree = "<group>";
@@ -1148,6 +1151,7 @@
 				C38BA3AA2DB8321600BAFA88 /* Chat.swift in Sources */,
 				C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */,
 				C36BEFB52BBDEAD8002D4AFE /* LoraCommands.swift in Sources */,
+				C37133A22DD6524B00D19830 /* ListCommands.swift in Sources */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};

support/generate-run-all-llms.sh

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+#!/bin/sh
+
+echo "#!/bin/sh"
+echo "# NOTE: GENERATED BY generate-run-all-llms.sh -- DO NOT MODIFY BY HAND"
+
+./mlx-run llm-tool list llms | \
+    awk '{printf "./mlx-run llm-tool eval --download ~/Downloads/huggingface --model %s\n", $0}' | \
+    awk '{printf "echo\necho ======\necho '\''%s'\''\n%s\n", $0, $0}'
+
+./mlx-run llm-tool list vlms | \
+    awk '{printf "./mlx-run llm-tool eval --download ~/Downloads/huggingface --model %s --resize 512 --image support/test.jpg\n", $0}' | \
+    awk '{printf "echo\necho ======\necho '\''%s'\''\n%s\n", $0, $0}'
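
The generator emits the combined script on stdout rather than writing a file, so a typical invocation redirects it (the output file name here is illustrative):

./support/generate-run-all-llms.sh > run-all-llms.sh
sh run-all-llms.sh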
