-
Notifications
You must be signed in to change notification settings - Fork 270
Fix thread safety issues in MLX concurrent inference (Samplers + TokenIterator) #351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2dfd884
44befc5
eb345ef
ed02db9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -133,6 +133,8 @@ public struct ArgMaxSampler: LogitSampler { | |
|
||
/// Sampler that uses `topP` and `temperature` to sample the logits. | ||
public struct TopPSampler: LogitSampler { | ||
private static let randomStateLock = NSLock() | ||
|
||
let temp: MLXArray | ||
let topP: MLXArray | ||
|
||
|
@@ -166,6 +168,10 @@ public struct TopPSampler: LogitSampler { | |
logits = logits.asType(.float32) | ||
} | ||
|
||
// Thread-safe sampling to prevent concurrent access to global random state | ||
TopPSampler.randomStateLock.lock() | ||
defer { TopPSampler.randomStateLock.unlock() } | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. FWIW, the typical way to use a lock like this is: `lock.withLock { compiledTopPSampling(...) }` -- but see my other comment on the use of locks to guard this. |
||
|
||
return compiledTopPSampling(logits, topP, temp) | ||
} | ||
} | ||
|
@@ -174,6 +180,9 @@ public struct TopPSampler: LogitSampler { | |
public struct CategoricalSampler: LogitSampler { | ||
let temp: MLXArray | ||
|
||
// Thread-safe sampling using a lock to protect global state access | ||
private static let randomStateLock = NSLock() | ||
|
||
public init(temperature: Float) { | ||
self.temp = MLXArray(temperature) | ||
} | ||
|
@@ -185,7 +194,10 @@ public struct CategoricalSampler: LogitSampler { | |
}() | ||
|
||
public func sample(logits: MLXArray) -> MLXArray { | ||
compiledCategorical(logits, temp) | ||
// Synchronize access to global random state to prevent concurrency issues | ||
CategoricalSampler.randomStateLock.lock() | ||
defer { CategoricalSampler.randomStateLock.unlock() } | ||
return compiledCategorical(logits, temp) | ||
} | ||
} | ||
|
||
|
@@ -267,6 +279,9 @@ public struct RepetitionContext: LogitProcessor { | |
/// | ||
/// Note: this uses `asyncEval()` and there may be an async evaluation running after a call to `next()`. | ||
public struct TokenIterator: Sequence, IteratorProtocol { | ||
// Global lock to protect MLX evaluation operations | ||
private static let mlxEvalLock = NSLock() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. See: [link lost in extraction]. This guards only concurrent calls in [code reference lost in extraction]. The [text lost in extraction]. Anyway, that said, this does not guard against concurrent use of the same model (which would still be a problem), nor does it add anything over the lock already in mlx-swift (that I can see, anyway). |
||
|
||
let model: any LanguageModel | ||
var state: LMOutput.State? | ||
|
||
|
@@ -383,11 +398,19 @@ public struct TokenIterator: Sequence, IteratorProtocol { | |
// evaluate the remainder of the prompt -- this primes the pump | ||
let token = step(previous: y) | ||
y = .init(tokens: token) | ||
|
||
// Protect asyncEval with the global lock | ||
TokenIterator.mlxEvalLock.lock() | ||
asyncEval(y.tokens) | ||
TokenIterator.mlxEvalLock.unlock() | ||
|
||
case .logits(let result): | ||
y = .init(tokens: convertToToken(logits: result.logits)) | ||
|
||
// Protect asyncEval with the global lock | ||
TokenIterator.mlxEvalLock.lock() | ||
asyncEval(y.tokens) | ||
TokenIterator.mlxEvalLock.unlock() | ||
|
||
break | ||
} | ||
|
@@ -434,7 +457,11 @@ public struct TokenIterator: Sequence, IteratorProtocol { | |
// compute the next state and async eval the next token | ||
let token = step(previous: previousY) | ||
y = .init(tokens: token) | ||
|
||
// Protect asyncEval with the global lock to prevent concurrent access | ||
TokenIterator.mlxEvalLock.lock() | ||
asyncEval(token) | ||
TokenIterator.mlxEvalLock.unlock() | ||
|
||
tokenCount += 1 | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't protect it -- randomState is global and this is only protecting callers of TopPSampler. For example it will not guard against concurrent use in CategoricalSampler.
The better way to fix this would be to have random state scoped to the sampler itself, see:
To use locks, all callers of Random would have to use the same lock. Actually, it is more complicated than that -- the calls to `globalState` are themselves thread safe, but the calls to evaluate the resulting MLXArrays are not -- you need to guard the eval sites.