diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index f29a43f1..41297385 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -141,9 +141,13 @@ public struct TopPSampler: LogitSampler { self.topP = MLXArray(topP) } - private let compiledTopPSampling: (MLXArray, MLXArray, MLXArray) -> MLXArray = { - compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { - logits, topP, temp in + public func sample(logits: MLXArray) -> MLXArray { + var logits = logits + if logits.dtype == .bfloat16 { + logits = logits.asType(.float32) + } + + return withRandomState(MLXRandom.RandomState()) { let probs = softmax(logits / temp, axis: -1) let sortedIndices = argSort(probs, axis: -1) @@ -158,15 +162,6 @@ public struct TopPSampler: LogitSampler { let sortedToken = categorical(log(topProbs)) return sortedIndices.squeezed(axis: 0)[sortedToken] } - }() - - public func sample(logits: MLXArray) -> MLXArray { - var logits = logits - if logits.dtype == .bfloat16 { - logits = logits.asType(.float32) - } - - return compiledTopPSampling(logits, topP, temp) } } @@ -178,14 +173,10 @@ public struct CategoricalSampler: LogitSampler { self.temp = MLXArray(temperature) } - private let compiledCategorical: (MLXArray, MLXArray) -> MLXArray = { - compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { logits, temp in + public func sample(logits: MLXArray) -> MLXArray { + return withRandomState(MLXRandom.RandomState()) { categorical(logits * (1 / temp)) } - }() - - public func sample(logits: MLXArray) -> MLXArray { - compiledCategorical(logits, temp) } } diff --git a/Tests/MLXLMTests/EvalTests.swift b/Tests/MLXLMTests/EvalTests.swift index 1be306ec..aea1291d 100644 --- a/Tests/MLXLMTests/EvalTests.swift +++ b/Tests/MLXLMTests/EvalTests.swift @@ -57,6 +57,120 @@ public class EvalTests: XCTestCase { XCTAssertEqual(output.shape, [1, 5, 100]) } + func testConcurrentEvaluation() async throws { + let config = LlamaConfiguration( + hiddenSize: 64, hiddenLayers: 4, intermediateSize: 128, attentionHeads: 8, + rmsNormEps: 0.00001, vocabularySize: 100, kvHeads: 4) + let model = LlamaModel(config) + quantize(model: model, groupSize: 64, bits: 4) + + let numTasks = 3 + let results = await withTaskGroup(of: MLXArray.self) { group in + var allResults: [MLXArray] = [] + + for taskId in 0..