Fix thread safety issues in MLX concurrent inference (Samplers + TokenIterator) #351

Open · wants to merge 4 commits into base: main
29 changes: 28 additions & 1 deletion Libraries/MLXLMCommon/Evaluate.swift
@@ -133,6 +133,8 @@ public struct ArgMaxSampler: LogitSampler {

/// Sampler that uses `topP` and `temperature` to sample the logits.
public struct TopPSampler: LogitSampler {
private static let randomStateLock = NSLock()
Collaborator:
This won't protect it -- randomState is global, and this lock only protects callers of TopPSampler. For example, it will not guard against concurrent use in CategoricalSampler.

The better way to fix this would be to have random state scoped to the sampler itself, see:

To use locks, all callers of Random would have to use the same lock. Actually it is more complicated than that -- the calls to globalState are themselves thread safe:

but the calls to evaluate the resulting MLXArrays are not -- you need to guard the eval sites.
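The "random state scoped to the sampler itself" suggestion could look like the following sketch, assuming mlx-swift's explicit-key APIs in MLXRandom (`key(_:)`, `split(key:)`, and `categorical(_:key:)`). `KeyedCategoricalSampler` is a hypothetical name, and note it does not conform to `LogitSampler` as written, since `sample` must be mutating to advance the key:

```swift
import MLX
import MLXRandom

// Hypothetical sketch: each sampler owns its own PRNG key chain, so no
// global random state is shared between concurrent generations.
public struct KeyedCategoricalSampler {
    let temp: MLXArray
    private var key: MLXArray

    public init(temperature: Float, seed: UInt64 = 0) {
        self.temp = MLXArray(temperature)
        self.key = MLXRandom.key(seed)
    }

    // Mutating: advances the instance-local key on every draw.
    public mutating func sample(logits: MLXArray) -> MLXArray {
        let (next, subkey) = MLXRandom.split(key: key)
        key = next
        // Passing an explicit `key:` bypasses MLXRandom's global state.
        return MLXRandom.categorical(logits / temp, key: subkey)
    }
}
```

Because each instance derives its draws from its own key, two samplers running on different threads never touch shared random state, so no lock is needed at the sampling call site.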


let temp: MLXArray
let topP: MLXArray

@@ -166,6 +168,10 @@ public struct TopPSampler: LogitSampler {
logits = logits.asType(.float32)
}

// Thread-safe sampling to prevent concurrent access to global random state
TopPSampler.randomStateLock.lock()
defer { TopPSampler.randomStateLock.unlock() }
Collaborator:
FWIW the typical way to use a lock like this is:

lock.withLock {
    compiledTopPSampling(...)
}

but see my other comment on the use of locks to guard this


return compiledTopPSampling(logits, topP, temp)
}
}
@@ -174,6 +180,9 @@ public struct TopPSampler: LogitSampler {
public struct CategoricalSampler: LogitSampler {
let temp: MLXArray

// Thread-safe sampling using a lock to protect global state access
private static let randomStateLock = NSLock()

public init(temperature: Float) {
self.temp = MLXArray(temperature)
}
@@ -185,7 +194,10 @@ public struct CategoricalSampler: LogitSampler {
}()

public func sample(logits: MLXArray) -> MLXArray {
compiledCategorical(logits, temp)
// Synchronize access to global random state to prevent concurrency issues
CategoricalSampler.randomStateLock.lock()
defer { CategoricalSampler.randomStateLock.unlock() }
return compiledCategorical(logits, temp)
}
}

@@ -267,6 +279,9 @@ public struct RepetitionContext: LogitProcessor {
///
/// Note: this uses `asyncEval()` and there may be an async evaluation running after a call to `next()`.
public struct TokenIterator: Sequence, IteratorProtocol {
// Global lock to protect MLX evaluation operations
private static let mlxEvalLock = NSLock()
Collaborator:
See:

This guards only concurrent calls in TokenIterator. In theory calls to eval() and asyncEval() should be thread safe as long as callers are using entirely distinct MLXArrays / compute graphs. In practice, that was never really a guarantee from mlx::core and in mlx-swift 0.25.1 we found new issues around this (changes on the core side).

The evalLock in mlx-swift is wider than just eval -- it has to guard a number of calls. It may be removed sometime in the future if we can restore the thread safe behavior in mlx::core.

Anyway, that said, this does not guard against concurrent use of the same model (which would still be a problem), nor does it add anything over the evalLock already in mlx-swift (that I can see, anyway).
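One way to address the remaining problem (concurrent use of the same model) without adding another lock is to give each model a single serial owner. A minimal sketch using a Swift actor, with `InferenceWorker` as a hypothetical name:

```swift
// Hypothetical sketch: an actor that serializes all inference on one model,
// giving thread safety by construction rather than per-call-site locks.
actor InferenceWorker<Model> {
    private let model: Model

    init(model: Model) {
        self.model = model
    }

    // Runs `body` with exclusive access to the model; concurrent callers
    // are queued by the actor, so their compute graphs never interleave.
    func withModel<R: Sendable>(_ body: @Sendable (Model) throws -> R) rethrows -> R {
        try body(model)
    }
}
```

With this shape, a TokenIterator (and its eval sites) would run entirely inside one actor turn, so no two threads ever share the model's state or KV cache.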


let model: any LanguageModel
var state: LMOutput.State?

@@ -383,11 +398,19 @@ public struct TokenIterator: Sequence, IteratorProtocol {
// evaluate the remainder of the prompt -- this primes the pump
let token = step(previous: y)
y = .init(tokens: token)

// Protect asyncEval with the global lock
TokenIterator.mlxEvalLock.lock()
asyncEval(y.tokens)
TokenIterator.mlxEvalLock.unlock()

case .logits(let result):
y = .init(tokens: convertToToken(logits: result.logits))

// Protect asyncEval with the global lock
TokenIterator.mlxEvalLock.lock()
asyncEval(y.tokens)
TokenIterator.mlxEvalLock.unlock()

break
}
@@ -434,7 +457,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
// compute the next state and async eval the next token
let token = step(previous: previousY)
y = .init(tokens: token)

// Protect asyncEval with the global lock to prevent concurrent access
TokenIterator.mlxEvalLock.lock()
asyncEval(token)
TokenIterator.mlxEvalLock.unlock()

tokenCount += 1
