@@ -148,9 +148,9 @@ struct ContentView: View {
 }
 
 @Observable
+@MainActor
 class LLMEvaluator {
 
-    @MainActor
     var running = false
 
     var output = ""
@@ -172,91 +172,87 @@ class LLMEvaluator {
 
     enum LoadState {
         case idle
-        case loaded(LLMModel, Tokenizers.Tokenizer)
+        case loaded(ModelContainer)
     }
 
     var loadState = LoadState.idle
 
     /// load and return the model -- can be called multiple times, subsequent calls will
     /// just return the loaded model
-    func load() async throws -> (LLMModel, Tokenizers.Tokenizer) {
+    func load() async throws -> ModelContainer {
         switch loadState {
         case .idle:
             // limit the buffer cache
             MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
 
-            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
+            let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
+            {
                 [modelConfiguration] progress in
-                DispatchQueue.main.sync {
+                Task { @MainActor in
                     self.modelInfo =
                         "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
                 }
             }
             self.modelInfo =
                 "Loaded \(modelConfiguration.id).  Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
-            loadState = .loaded(model, tokenizer)
-            return (model, tokenizer)
+            loadState = .loaded(modelContainer)
+            return modelContainer
 
-        case .loaded(let model, let tokenizer):
-            return (model, tokenizer)
+        case .loaded(let modelContainer):
+            return modelContainer
         }
     }
 
     func generate(prompt: String) async {
-        let canGenerate = await MainActor.run {
-            if running {
-                return false
-            } else {
-                running = true
-                self.output = ""
-                return true
-            }
-        }
+        guard !running else { return }
 
-        guard canGenerate else { return }
+        running = true
+        self.output = ""
 
         do {
-            let (model, tokenizer) = try await load()
+            let modelContainer = try await load()
+
             // augment the prompt as needed
             let prompt = modelConfiguration.prepare(prompt: prompt)
-            let promptTokens = tokenizer.encode(text: prompt)
+
+            let promptTokens = await modelContainer.perform { _, tokenizer in
+                tokenizer.encode(text: prompt)
+            }
 
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            let result = await LLM.generate(
-                promptTokens: promptTokens, parameters: generateParameters, model: model,
-                tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
-            ) { tokens in
-                // update the output -- this will make the view show the text as it generates
-                if tokens.count % displayEveryNTokens == 0 {
-                    let text = tokenizer.decode(tokens: tokens)
-                    await MainActor.run {
-                        self.output = text
+            let result = await modelContainer.perform { model, tokenizer in
+                LLM.generate(
+                    promptTokens: promptTokens, parameters: generateParameters, model: model,
+                    tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
+                ) { tokens in
+                    // update the output -- this will make the view show the text as it generates
+                    if tokens.count % displayEveryNTokens == 0 {
+                        let text = tokenizer.decode(tokens: tokens)
+                        Task { @MainActor in
+                            self.output = text
+                        }
                     }
-                }
 
-                if tokens.count >= maxTokens {
-                    return .stop
-                } else {
-                    return .more
+                    if tokens.count >= maxTokens {
+                        return .stop
+                    } else {
+                        return .more
+                    }
                 }
             }
 
             // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
-            await MainActor.run {
-                if result.output != self.output {
-                    self.output = result.output
-                }
-                running = false
-                self.stat = "Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
+            if result.output != self.output {
+                self.output = result.output
             }
+            self.stat = "Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
 
         } catch {
-            await MainActor.run {
-                running = false
-                output = "Failed: \(error)"
-            }
+            output = "Failed: \(error)"
         }
+
+        running = false
     }
 }
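
The core of this change: instead of handing callers a raw `(LLMModel, Tokenizers.Tokenizer)` tuple, loading now yields a `ModelContainer`, and every use of the model or tokenizer goes through `perform`. Below is a minimal sketch of what such a container could look like, inferred only from the call sites in this diff; the real type ships with the MLX Swift examples and may differ in detail.

    import Tokenizers  // swift-transformers: provides the Tokenizer protocol

    // Sketch only: an actor that owns the model and tokenizer so all inference
    // runs isolated from the UI. `LLMModel` is assumed to be the example
    // library's model type; this is inferred, not the actual source.
    actor ModelContainer {
        let model: LLMModel
        let tokenizer: Tokenizer

        init(model: LLMModel, tokenizer: Tokenizer) {
            self.model = model
            self.tokenizer = tokenizer
        }

        /// Run `action` with exclusive access to the model and tokenizer.
        /// Callers hop onto this actor, so every use of the non-Sendable
        /// model state is serialized and never escapes to another thread.
        func perform<R>(_ action: @Sendable (LLMModel, Tokenizer) throws -> R) rethrows -> R {
            try action(model, tokenizer)
        }
    }

This is also what lets the diff hoist `@MainActor` from individual properties up to the whole class and delete the `DispatchQueue.main.sync` / `MainActor.run` hops: UI state (`running`, `output`, `stat`) lives on the main actor, the model stays inside the container, and the streaming token callback posts display updates back with `Task { @MainActor in ... }`.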