Commit fb5ee82

attempt 2 at preparing for strict concurrency (#90)
* attempt 2 at preparing for strict concurrency; see also #83
* marks many things as Sendable (which I think we can take regardless)
* creates an actor container for models and tokenizers, which are not Sendable (though perhaps Tokenizers could be)
1 parent 885e520 commit fb5ee82
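The ContentView diff below swaps the `(LLMModel, Tokenizers.Tokenizer)` tuple for a `ModelContainer` and routes all model/tokenizer access through a `perform` closure. The container's definition is not part of this excerpt; as a rough sketch (the type shape and the exact `perform` signature are inferred from the call sites below, not confirmed by this page), it could be an actor that owns the non-Sendable values:

```swift
// Hedged sketch only -- the real ModelContainer lives elsewhere in this commit
// and may differ. The idea: the non-Sendable model and tokenizer are owned by
// an actor, and callers can only touch them inside `perform`, so the values
// never cross a concurrency boundary themselves.
import Tokenizers

public actor ModelContainer {
    let model: LLMModel
    let tokenizer: Tokenizer

    public init(model: LLMModel, tokenizer: Tokenizer) {
        self.model = model
        self.tokenizer = tokenizer
    }

    /// Run `action` while isolated to this actor; only the result escapes.
    public func perform<R>(_ action: @Sendable (LLMModel, Tokenizer) throws -> R) rethrows -> R {
        try action(model, tokenizer)
    }
}
```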

29 files changed, +506 / -278 lines.

Applications/LLMEval/ContentView.swift

Lines changed: 41 additions & 45 deletions
@@ -148,9 +148,9 @@ struct ContentView: View {
 }
 
 @Observable
+@MainActor
 class LLMEvaluator {
 
-    @MainActor
     var running = false
 
     var output = ""
@@ -172,91 +172,87 @@ class LLMEvaluator {
 
     enum LoadState {
         case idle
-        case loaded(LLMModel, Tokenizers.Tokenizer)
+        case loaded(ModelContainer)
     }
 
     var loadState = LoadState.idle
 
     /// load and return the model -- can be called multiple times, subsequent calls will
     /// just return the loaded model
-    func load() async throws -> (LLMModel, Tokenizers.Tokenizer) {
+    func load() async throws -> ModelContainer {
         switch loadState {
         case .idle:
             // limit the buffer cache
             MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
 
-            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
+            let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
+            {
                 [modelConfiguration] progress in
-                DispatchQueue.main.sync {
+                Task { @MainActor in
                     self.modelInfo =
                         "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
                 }
             }
             self.modelInfo =
                 "Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
-            loadState = .loaded(model, tokenizer)
-            return (model, tokenizer)
+            loadState = .loaded(modelContainer)
+            return modelContainer
 
-        case .loaded(let model, let tokenizer):
-            return (model, tokenizer)
+        case .loaded(let modelContainer):
+            return modelContainer
         }
     }
 
     func generate(prompt: String) async {
-        let canGenerate = await MainActor.run {
-            if running {
-                return false
-            } else {
-                running = true
-                self.output = ""
-                return true
-            }
-        }
+        guard !running else { return }
 
-        guard canGenerate else { return }
+        running = true
+        self.output = ""
 
         do {
-            let (model, tokenizer) = try await load()
+            let modelContainer = try await load()
+
             // augment the prompt as needed
             let prompt = modelConfiguration.prepare(prompt: prompt)
-            let promptTokens = tokenizer.encode(text: prompt)
+
+            let promptTokens = await modelContainer.perform { _, tokenizer in
+                tokenizer.encode(text: prompt)
+            }
 
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            let result = await LLM.generate(
-                promptTokens: promptTokens, parameters: generateParameters, model: model,
-                tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
-            ) { tokens in
-                // update the output -- this will make the view show the text as it generates
-                if tokens.count % displayEveryNTokens == 0 {
-                    let text = tokenizer.decode(tokens: tokens)
-                    await MainActor.run {
-                        self.output = text
+            let result = await modelContainer.perform { model, tokenizer in
+                LLM.generate(
+                    promptTokens: promptTokens, parameters: generateParameters, model: model,
+                    tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
+                ) { tokens in
+                    // update the output -- this will make the view show the text as it generates
+                    if tokens.count % displayEveryNTokens == 0 {
+                        let text = tokenizer.decode(tokens: tokens)
+                        Task { @MainActor in
+                            self.output = text
+                        }
                     }
-                }
 
-                if tokens.count >= maxTokens {
-                    return .stop
-                } else {
-                    return .more
+                    if tokens.count >= maxTokens {
+                        return .stop
+                    } else {
+                        return .more
+                    }
                 }
             }
 
             // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
-            await MainActor.run {
-                if result.output != self.output {
-                    self.output = result.output
-                }
-                running = false
-                self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
+            if result.output != self.output {
+                self.output = result.output
             }
+            self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
 
         } catch {
-            await MainActor.run {
-                running = false
-                output = "Failed: \(error)"
-            }
+            output = "Failed: \(error)"
         }
+
+        running = false
    }
 }
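With `@MainActor` now applied to the whole `LLMEvaluator` class, its observable state (`running`, `output`, `stat`) is read and written on the main actor directly, without the `MainActor.run` hops the old code needed. A minimal, hypothetical SwiftUI call site (not part of this commit; the view name and prompt text are illustrative, and it simply mirrors the app's existing pattern) might look like:

```swift
import SwiftUI

// Illustrative only. Because LLMEvaluator is @MainActor-isolated and
// @Observable, the view can read its properties in `body` and kick off
// generation from a button action without any explicit actor hopping.
struct EvalView: View {
    @State private var llm = LLMEvaluator()
    @State private var prompt = "why is the sky blue?"

    var body: some View {
        VStack(alignment: .leading) {
            TextField("Prompt", text: $prompt)
            Text(llm.output)
            Text(llm.stat)
            Button("Generate") {
                Task {
                    await llm.generate(prompt: prompt)
                }
            }
            .disabled(llm.running)
        }
        .padding()
    }
}
```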

Applications/LLMEval/ViewModels/DeviceStat.swift

Lines changed: 7 additions & 18 deletions
@@ -3,33 +3,22 @@ import LLM
 import MLX
 
 @Observable
-class DeviceStat {
+final class DeviceStat: @unchecked Sendable {
+
+    @MainActor
     var gpuUsage = GPU.snapshot()
-    private var initialGPUSnapshot = GPU.snapshot()
+
+    private let initialGPUSnapshot = GPU.snapshot()
     private var timer: Timer?
 
     init() {
-        startTimer()
-    }
-
-    deinit {
-        stopTimer()
-    }
-
-    private func startTimer() {
-        timer?.invalidate()
         timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in
-            self?.updateStats()
+            self?.updateGPUUsages()
        }
     }
 
-    private func stopTimer() {
+    deinit {
         timer?.invalidate()
-        timer = nil
-    }
-
-    private func updateStats() {
-        updateGPUUsages()
     }
 
     private func updateGPUUsages() {
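The `DeviceStat` change follows the other pattern the commit message mentions: the class is marked `@unchecked Sendable` so the `Timer` callback can capture it, while the UI-facing property is isolated to the main actor. The body of `updateGPUUsages()` is truncated in the excerpt above, so the sketch below is only an assumed shape of that pattern (type and property names are illustrative, and the main-actor hop for publishing the value is an assumption, not taken from the diff):

```swift
import Foundation
import MLX
import Observation

// Pattern sketch only -- not the project's DeviceStat implementation.
@Observable
final class StatPublisher: @unchecked Sendable {
    @MainActor
    var latest = GPU.snapshot()  // UI-facing state, isolated to the main actor

    private var timer: Timer?

    init() {
        // The Timer callback is not main-actor isolated; it only publishes the
        // new value by hopping onto the main actor, which is what keeps the
        // @unchecked Sendable claim honest.
        timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in
            Task { @MainActor in
                self?.latest = GPU.snapshot()
            }
        }
    }

    deinit {
        timer?.invalidate()
    }
}
```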

0 commit comments
