From 2dfd884ce4c4d8642c8f598649762f434326afe7 Mon Sep 17 00:00:00 2001 From: sxy-trans-n Date: Mon, 7 Jul 2025 15:56:10 +0900 Subject: [PATCH 1/5] Fix thread safety in CategoricalSampler and TopPSampler --- Libraries/MLXLMCommon/Evaluate.swift | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index f29a43f1..71e934ce 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -133,6 +133,8 @@ public struct ArgMaxSampler: LogitSampler { /// Sampler that uses `topP` and `temperature` to sample the logits. public struct TopPSampler: LogitSampler { + private static let randomStateLock = NSLock() + let temp: MLXArray let topP: MLXArray @@ -165,7 +167,11 @@ public struct TopPSampler: LogitSampler { if logits.dtype == .bfloat16 { logits = logits.asType(.float32) } - + + // Thread-safe sampling to prevent concurrent access to global random state + TopPSampler.randomStateLock.lock() + defer { TopPSampler.randomStateLock.unlock() } + return compiledTopPSampling(logits, topP, temp) } } @@ -173,6 +179,9 @@ public struct TopPSampler: LogitSampler { /// Processor that uses `temperature` to sample the logits public struct CategoricalSampler: LogitSampler { let temp: MLXArray + + // Thread-safe sampling using a lock to protect global state access + private static let randomStateLock = NSLock() public init(temperature: Float) { self.temp = MLXArray(temperature) @@ -185,7 +194,10 @@ public struct CategoricalSampler: LogitSampler { }() public func sample(logits: MLXArray) -> MLXArray { - compiledCategorical(logits, temp) + // Synchronize access to global random state to prevent concurrency issues + CategoricalSampler.randomStateLock.lock() + defer { CategoricalSampler.randomStateLock.unlock() } + return compiledCategorical(logits, temp) } } From 44befc59425c851f0a8d6359efc2338a838387ce Mon Sep 17 00:00:00 2001 From: sxy-trans-n Date: Mon, 7 Jul 2025 16:38:13 +0900 Subject: [PATCH 2/5] Add thread safety to TokenIterator MLX evaluation operations --- Libraries/MLXLMCommon/Evaluate.swift | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index 71e934ce..21ea1f26 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -279,6 +279,9 @@ public struct RepetitionContext: LogitProcessor { /// /// Note: this uses `asyncEval()` and there may be an async evaluation running after a call to `next()`. public struct TokenIterator: Sequence, IteratorProtocol { + // Global lock to protect MLX evaluation operations + private static let mlxEvalLock = NSLock() + let model: any LanguageModel var state: LMOutput.State? @@ -393,13 +396,17 @@ public struct TokenIterator: Sequence, IteratorProtocol { y = tokens // evaluate the remainder of the prompt -- this primes the pump + TokenIterator.mlxEvalLock.lock() let token = step(previous: y) y = .init(tokens: token) asyncEval(y.tokens) + TokenIterator.mlxEvalLock.unlock() case .logits(let result): + TokenIterator.mlxEvalLock.lock() y = .init(tokens: convertToToken(logits: result.logits)) asyncEval(y.tokens) + TokenIterator.mlxEvalLock.unlock() break } @@ -444,9 +451,11 @@ public struct TokenIterator: Sequence, IteratorProtocol { let previousY = y // compute the next state and async eval the next token + TokenIterator.mlxEvalLock.lock() let token = step(previous: previousY) y = .init(tokens: token) asyncEval(token) + TokenIterator.mlxEvalLock.unlock() tokenCount += 1 From eb345efca73ab81566395086ff263f58a2951c3e Mon Sep 17 00:00:00 2001 From: sxy-trans-n Date: Mon, 7 Jul 2025 16:48:43 +0900 Subject: [PATCH 3/5] format code --- Libraries/MLXLMCommon/Evaluate.swift | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index 21ea1f26..c149131c 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -134,7 +134,7 @@ public struct ArgMaxSampler: LogitSampler { /// Sampler that uses `topP` and `temperature` to sample the logits. public struct TopPSampler: LogitSampler { private static let randomStateLock = NSLock() - + let temp: MLXArray let topP: MLXArray @@ -167,11 +167,11 @@ public struct TopPSampler: LogitSampler { if logits.dtype == .bfloat16 { logits = logits.asType(.float32) } - + // Thread-safe sampling to prevent concurrent access to global random state TopPSampler.randomStateLock.lock() defer { TopPSampler.randomStateLock.unlock() } - + return compiledTopPSampling(logits, topP, temp) } } @@ -179,7 +179,7 @@ public struct TopPSampler: LogitSampler { /// Processor that uses `temperature` to sample the logits public struct CategoricalSampler: LogitSampler { let temp: MLXArray - + // Thread-safe sampling using a lock to protect global state access private static let randomStateLock = NSLock() @@ -281,7 +281,7 @@ public struct RepetitionContext: LogitProcessor { public struct TokenIterator: Sequence, IteratorProtocol { // Global lock to protect MLX evaluation operations private static let mlxEvalLock = NSLock() - + let model: any LanguageModel var state: LMOutput.State? From ed02db95946e8e3229f6ad38b583a811ab804536 Mon Sep 17 00:00:00 2001 From: sxy-trans-n Date: Tue, 8 Jul 2025 11:47:38 +0900 Subject: [PATCH 4/5] Fix MLX concurrency issues by protecting asyncEval with global lock --- Libraries/MLXLMCommon/Evaluate.swift | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index c149131c..fee7f83d 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -396,15 +396,19 @@ public struct TokenIterator: Sequence, IteratorProtocol { y = tokens // evaluate the remainder of the prompt -- this primes the pump - TokenIterator.mlxEvalLock.lock() let token = step(previous: y) y = .init(tokens: token) + + // Protect asyncEval with the global lock + TokenIterator.mlxEvalLock.lock() asyncEval(y.tokens) TokenIterator.mlxEvalLock.unlock() case .logits(let result): - TokenIterator.mlxEvalLock.lock() y = .init(tokens: convertToToken(logits: result.logits)) + + // Protect asyncEval with the global lock + TokenIterator.mlxEvalLock.lock() asyncEval(y.tokens) TokenIterator.mlxEvalLock.unlock() @@ -451,9 +455,11 @@ public struct TokenIterator: Sequence, IteratorProtocol { let previousY = y // compute the next state and async eval the next token - TokenIterator.mlxEvalLock.lock() let token = step(previous: previousY) y = .init(tokens: token) + + // Protect asyncEval with the global lock to prevent concurrent access + TokenIterator.mlxEvalLock.lock() asyncEval(token) TokenIterator.mlxEvalLock.unlock() From 550613aa480eb111f97fdd60708515e9b5bf536b Mon Sep 17 00:00:00 2001 From: sxy-trans-n Date: Thu, 10 Jul 2025 23:47:31 +0900 Subject: [PATCH 5/5] fix: Add concurrent-safe random state isolation to samplers --- Libraries/MLXLMCommon/Evaluate.swift | 54 +++---------- Tests/MLXLMTests/EvalTests.swift | 114 +++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 45 deletions(-) diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index fee7f83d..41297385 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -133,8 +133,6 @@ public struct ArgMaxSampler: LogitSampler { /// Sampler that uses `topP` and `temperature` to sample the logits. public struct TopPSampler: LogitSampler { - private static let randomStateLock = NSLock() - let temp: MLXArray let topP: MLXArray @@ -143,9 +141,13 @@ public struct TopPSampler: LogitSampler { self.topP = MLXArray(topP) } - private let compiledTopPSampling: (MLXArray, MLXArray, MLXArray) -> MLXArray = { - compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { - logits, topP, temp in + public func sample(logits: MLXArray) -> MLXArray { + var logits = logits + if logits.dtype == .bfloat16 { + logits = logits.asType(.float32) + } + + return withRandomState(MLXRandom.RandomState()) { let probs = softmax(logits / temp, axis: -1) let sortedIndices = argSort(probs, axis: -1) @@ -160,19 +162,6 @@ public struct TopPSampler: LogitSampler { let sortedToken = categorical(log(topProbs)) return sortedIndices.squeezed(axis: 0)[sortedToken] } - }() - - public func sample(logits: MLXArray) -> MLXArray { - var logits = logits - if logits.dtype == .bfloat16 { - logits = logits.asType(.float32) - } - - // Thread-safe sampling to prevent concurrent access to global random state - TopPSampler.randomStateLock.lock() - defer { TopPSampler.randomStateLock.unlock() } - - return compiledTopPSampling(logits, topP, temp) } } @@ -180,24 +169,14 @@ public struct TopPSampler: LogitSampler { public struct CategoricalSampler: LogitSampler { let temp: MLXArray - // Thread-safe sampling using a lock to protect global state access - private static let randomStateLock = NSLock() - public init(temperature: Float) { self.temp = MLXArray(temperature) } - private let compiledCategorical: (MLXArray, MLXArray) -> MLXArray = { - compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { logits, temp in + public func sample(logits: MLXArray) -> MLXArray { + return withRandomState(MLXRandom.RandomState()) { categorical(logits * (1 / temp)) } - }() - - public func sample(logits: MLXArray) -> MLXArray { - // Synchronize access to global random state to prevent concurrency issues - CategoricalSampler.randomStateLock.lock() - defer { CategoricalSampler.randomStateLock.unlock() } - return compiledCategorical(logits, temp) } } @@ -279,9 +258,6 @@ public struct RepetitionContext: LogitProcessor { /// /// Note: this uses `asyncEval()` and there may be an async evaluation running after a call to `next()`. public struct TokenIterator: Sequence, IteratorProtocol { - // Global lock to protect MLX evaluation operations - private static let mlxEvalLock = NSLock() - let model: any LanguageModel var state: LMOutput.State? @@ -398,19 +374,11 @@ public struct TokenIterator: Sequence, IteratorProtocol { // evaluate the remainder of the prompt -- this primes the pump let token = step(previous: y) y = .init(tokens: token) - - // Protect asyncEval with the global lock - TokenIterator.mlxEvalLock.lock() asyncEval(y.tokens) - TokenIterator.mlxEvalLock.unlock() case .logits(let result): y = .init(tokens: convertToToken(logits: result.logits)) - - // Protect asyncEval with the global lock - TokenIterator.mlxEvalLock.lock() asyncEval(y.tokens) - TokenIterator.mlxEvalLock.unlock() break } @@ -457,11 +425,7 @@ public struct TokenIterator: Sequence, IteratorProtocol { // compute the next state and async eval the next token let token = step(previous: previousY) y = .init(tokens: token) - - // Protect asyncEval with the global lock to prevent concurrent access - TokenIterator.mlxEvalLock.lock() asyncEval(token) - TokenIterator.mlxEvalLock.unlock() tokenCount += 1 diff --git a/Tests/MLXLMTests/EvalTests.swift b/Tests/MLXLMTests/EvalTests.swift index 1be306ec..aea1291d 100644 --- a/Tests/MLXLMTests/EvalTests.swift +++ b/Tests/MLXLMTests/EvalTests.swift @@ -57,6 +57,120 @@ public class EvalTests: XCTestCase { XCTAssertEqual(output.shape, [1, 5, 100]) } + func testConcurrentEvaluation() async throws { + let config = LlamaConfiguration( + hiddenSize: 64, hiddenLayers: 4, intermediateSize: 128, attentionHeads: 8, + rmsNormEps: 0.00001, vocabularySize: 100, kvHeads: 4) + let model = LlamaModel(config) + quantize(model: model, groupSize: 64, bits: 4) + + let numTasks = 3 + let results = await withTaskGroup(of: MLXArray.self) { group in + var allResults: [MLXArray] = [] + + for taskId in 0..