
Commit 5a7a1a4

Use chat template (#135)

* Use chat template
* Update packages

1 parent: 4e5977d

8 files changed: +49 additions, −116 deletions

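In short: instead of each `ModelConfiguration` hand-formatting prompts through a `preparePrompt` closure, the prompt is now wrapped as a chat message and tokenized with the tokenizer's own chat template from swift-transformers. A minimal sketch of the new flow, assuming an already-loaded `ModelContainer` and the `LLM` module name from this repo's Libraries/LLM (the other names match the diffs below):

```swift
import LLM  // assumed module name for Libraries/LLM in this repo
import Tokenizers  // swift-transformers

/// Sketch only: tokenize a single user turn via the chat template.
func tokenize(prompt: String, in modelContainer: ModelContainer) async throws -> [Int] {
    // One user message; the template supplies the model-specific special
    // tokens (<|im_start|>, [INST], <start_of_turn>, ...) that the removed
    // preparePrompt closures used to hard-code per model.
    let messages = [["role": "user", "content": prompt]]
    return try await modelContainer.perform { _, tokenizer in
        try tokenizer.applyChatTemplate(messages: messages)
    }
}
```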

Applications/LLMEval/ContentView.swift

Lines changed: 7 additions & 6 deletions

```diff
@@ -159,7 +159,10 @@ class LLMEvaluator {
 
     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+
+    // let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    // let modelConfiguration = ModelConfiguration.mistral7B4bit
+    let modelConfiguration = ModelConfiguration.llama3_2_3B_4bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
@@ -217,11 +220,9 @@ class LLMEvaluator {
         do {
             let modelContainer = try await load()
 
-            // augment the prompt as needed
-            let prompt = modelConfiguration.prepare(prompt: prompt)
-
-            let promptTokens = await modelContainer.perform { _, tokenizer in
-                tokenizer.encode(text: prompt)
+            let messages = [["role": "user", "content": prompt]]
+            let promptTokens = try await modelContainer.perform { _, tokenizer in
+                try tokenizer.applyChatTemplate(messages: messages)
             }
 
             // each time you generate you will get something new
```

Applications/LoRATrainingExample/ContentView.swift

Lines changed: 3 additions & 4 deletions

```diff
@@ -269,10 +269,9 @@ class LoRAEvaluator {
 
         let modelContainer = try await loadModel()
 
-        // prepare the prompt
-        let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
-        let promptTokens = await modelContainer.perform { _, tokenizer in
-            tokenizer.encode(text: preparedPrompt)
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         // evaluate
```

Libraries/LLM/LLMModel.swift

Lines changed: 4 additions & 3 deletions

```diff
@@ -8,12 +8,13 @@ import Tokenizers
 
 /// Container for models that guarantees single threaded access.
 ///
-/// Wrap models used by e.g. the UI in a ModelContainer.  Callers can access
+/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
 /// the model and/or tokenizer:
 ///
 /// ```swift
-/// let promptTokens = await modelContainer.perform { _, tokenizer in
-///     tokenizer.encode(text: prompt)
+/// let messages = [["role": "user", "content": prompt]]
+/// let promptTokens = try await modelContainer.perform { _, tokenizer in
+///     try tokenizer.applyChatTemplate(messages: messages)
 /// }
 /// ```
 ///
```

Libraries/LLM/Models.swift

Lines changed: 15 additions & 73 deletions

```diff
@@ -39,11 +39,6 @@ public struct ModelConfiguration: Sendable {
     /// Additional tokens to use for end of string
     public let extraEOSTokens: Set<String>
 
-    /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
-    /// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
-    /// format
-    private let preparePrompt: (@Sendable (String) -> String)?
-
     public init(
         id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
         defaultPrompt: String = "hello",
@@ -55,25 +50,18 @@
         self.overrideTokenizer = overrideTokenizer
         self.defaultPrompt = defaultPrompt
         self.extraEOSTokens = extraEOSTokens
-        self.preparePrompt = preparePrompt
     }
 
     public init(
         directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
         defaultPrompt: String = "hello",
-        extraEOSTokens: Set<String> = [],
-        preparePrompt: (@Sendable (String) -> String)? = nil
+        extraEOSTokens: Set<String> = []
     ) {
         self.id = .directory(directory)
         self.tokenizerId = tokenizerId
         self.overrideTokenizer = overrideTokenizer
         self.defaultPrompt = defaultPrompt
         self.extraEOSTokens = extraEOSTokens
-        self.preparePrompt = preparePrompt
-    }
-
-    public func prepare(prompt: String) -> String {
-        preparePrompt?(prompt) ?? prompt
     }
 
     public func modelDirectory(hub: HubApi = HubApi()) -> URL {
@@ -116,40 +104,26 @@ extension ModelConfiguration {
     public static let smolLM_135M_4bit = ModelConfiguration(
         id: "mlx-community/SmolLM-135M-Instruct-4bit",
         defaultPrompt: "Tell me about the history of Spain."
-    ) {
-        prompt in
-        "<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
-    }
+    )
 
     public static let mistralNeMo4bit = ModelConfiguration(
         id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit",
         defaultPrompt: "Explain quaternions."
-    ) { prompt in
-        "<s>[INST] \(prompt) [/INST] "
-    }
+    )
 
     public static let mistral7B4bit = ModelConfiguration(
         id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
         defaultPrompt: "Describe the Swift language."
-    ) { prompt in
-        "<s>[INST] \(prompt) [/INST] "
-    }
+    )
 
     public static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
         overrideTokenizer: "PreTrainedTokenizer",
         defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
-    ) { prompt in
-        // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
-        // the python code produces this (via its custom tokenizer):
-        // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
-
-        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
-    }
+    )
 
     public static let phi4bit = ModelConfiguration(
         id: "mlx-community/phi-2-hf-4bit-mlx",
-
         // https://www.promptingguide.ai/models/phi-2
         defaultPrompt: "Why is the sky blue?"
     )
@@ -158,92 +132,60 @@ extension ModelConfiguration {
         id: "mlx-community/Phi-3.5-mini-instruct-4bit",
         defaultPrompt: "What is the gravity on Mars and the moon?",
         extraEOSTokens: ["<|end|>"]
-    ) {
-        prompt in
-        "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
-    }
+    )
 
     public static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
         overrideTokenizer: "PreTrainedTokenizer",
-
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "what is the difference between lettuce and cabbage?"
-
-    ) { prompt in
-        "<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
-    }
+    )
 
     public static let gemma_2_9b_it_4bit = ModelConfiguration(
         id: "mlx-community/gemma-2-9b-it-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
-
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "What is the difference between lettuce and cabbage?"
-
-    ) { prompt in
-        "<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
-    }
+    )
 
     public static let gemma_2_2b_it_4bit = ModelConfiguration(
         id: "mlx-community/gemma-2-2b-it-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
-
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "What is the difference between lettuce and cabbage?"
-
-    ) { prompt in
-        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
-    }
+    )
 
     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
         defaultPrompt: "why is the sky blue?"
-    ) { prompt in
-        "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
-    }
+    )
 
     public static let openelm270m4bit = ModelConfiguration(
         id: "mlx-community/OpenELM-270M-Instruct",
-
         // https://huggingface.co/apple/OpenELM
         defaultPrompt: "Once upon a time there was"
-    ) { prompt in
-        "\(prompt)"
-    }
+    )
 
     public static let llama3_1_8B_4bit = ModelConfiguration(
         id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     public static let llama3_8B_4bit = ModelConfiguration(
         id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     public static let llama3_2_1B_4bit = ModelConfiguration(
         id: "mlx-community/Llama-3.2-1B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     public static let llama3_2_3B_4bit = ModelConfiguration(
         id: "mlx-community/Llama-3.2-3B-Instruct-4bit",
         defaultPrompt: "What is the difference between a fruit and a vegetable?"
-    ) {
-        prompt in
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
-    }
+    )
 
     private enum BootstrapState: Sendable {
         case idle
```
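With `preparePrompt` and `prepare(prompt:)` removed, a `ModelConfiguration` is pure metadata: id, optional tokenizer overrides, a default prompt, and extra EOS tokens. A hypothetical new entry (not part of this commit) now needs no trailing formatting closure:

```swift
extension ModelConfiguration {
    // Hypothetical model id, shown only to illustrate the post-change shape;
    // there is no `{ prompt in ... }` closure because the tokenizer's chat
    // template now does the formatting.
    public static let myChatModel4bit = ModelConfiguration(
        id: "mlx-community/My-Chat-Model-4bit",
        defaultPrompt: "Tell me a story.",
        extraEOSTokens: ["<|end|>"]  // only if the model emits a nonstandard EOS
    )
}
```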

Tools/llm-tool/LLMTool.swift

Lines changed: 4 additions & 15 deletions

```diff
@@ -84,18 +84,6 @@ struct GenerateArguments: ParsableArguments, Sendable {
         }
     }
 
-    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
-        String, [Int]
-    ) {
-        MLXRandom.seed(seed)
-
-        let prompt = try resolvePrompt(configuration: configuration)
-        let preparedPrompt = configuration.prepare(prompt: prompt)
-        let promptTokens = tokenizer.encode(text: preparedPrompt)
-
-        return (prompt, promptTokens)
-    }
-
     func generate(
         promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
         extraEOSTokens: Set<String>? = nil
@@ -221,9 +209,10 @@ struct EvaluateCommand: AsyncParsableCommand {
             print("Model loaded -> \(modelConfiguration.id)")
         }
 
-        let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
-            try generate.tokenizePrompt(
-                configuration: modelConfiguration, tokenizer: tokenizer)
+        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         if !generate.quiet {
```

Tools/llm-tool/LoraCommands.swift

Lines changed: 4 additions & 3 deletions

```diff
@@ -291,9 +291,10 @@ struct LoRAEvalCommand: AsyncParsableCommand {
 
         memory.start()
 
-        let (prompt, promptTokens) = try await modelContainer.perform { [generate] _, tokenizer in
-            try generate.tokenizePrompt(
-                configuration: modelConfiguration, tokenizer: tokenizer)
+        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
+        let messages = [["role": "user", "content": prompt]]
+        let promptTokens = try await modelContainer.perform { _, tokenizer in
+            try tokenizer.applyChatTemplate(messages: messages)
         }
 
         if !generate.quiet {
```

mlx-swift-examples.xcodeproj/project.pbxproj

Lines changed: 2 additions & 2 deletions

```diff
@@ -3555,7 +3555,7 @@
 			repositoryURL = "https://github.com/huggingface/swift-transformers";
 			requirement = {
 				kind = upToNextMajorVersion;
-				minimumVersion = 0.1.12;
+				minimumVersion = 0.1.13;
 			};
 		};
 		C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */ = {
@@ -3571,7 +3571,7 @@
 			repositoryURL = "https://github.com/ml-explore/mlx-swift";
 			requirement = {
 				kind = upToNextMajorVersion;
-				minimumVersion = 0.16.1;
+				minimumVersion = 0.18.0;
 			};
 		};
 /* End XCRemoteSwiftPackageReference section */
```
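The bumps pull in swift-transformers 0.1.13, presumably the version that provides the `applyChatTemplate` API this commit adopts, and mlx-swift 0.18.0. For SwiftPM consumers, a hypothetical Package.swift would pin the same versions like this (this repo itself pins them through the Xcode project, as above):

```swift
// swift-tools-version:5.9
// Hypothetical Package.swift counterpart of the Xcode pins above.
import PackageDescription

let package = Package(
    name: "Example",
    dependencies: [
        // `from:` matches the pbxproj's upToNextMajorVersion requirement
        .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"),
        .package(url: "https://github.com/ml-explore/mlx-swift", from: "0.18.0"),
    ]
)
```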

mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default.
