
Fix thread safety issues in MLX concurrent inference (Samplers + TokenIterator) #351


Open · wants to merge 4 commits into main

Conversation

@sxy-trans-n (Author)

🐛 Problem

The MLX Swift Examples library suffers from multiple thread safety issues when used in concurrent inference scenarios. The issues manifest at two levels:

  1. Sampler Level: CategoricalSampler and TopPSampler race on MLXRandom.globalState
  2. Evaluation Level: TokenIterator instances race on MLX's internal evaluation engine

Error Symptoms

  • Crash: [eval] Attempting to eval an array without a primitive
  • Inconsistent results: Different outputs for identical inputs
  • Memory corruption: Occasional segmentation faults during concurrent sampling
  • Evaluation errors: MLX internal assertions when multiple asyncEval() calls overlap

Root Cause Analysis

1. Sampler Race Conditions

Both samplers use compiled functions that implicitly read and write the global random state, so two threads sampling at once mutate the same shared object:

```swift
compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { ... }
```

2. MLX Evaluation Race Conditions

Multiple TokenIterator instances calling asyncEval() concurrently cause race conditions in MLX's internal evaluation engine, even when samplers are properly synchronized.

When multiple threads call these operations concurrently, they race to access and modify shared state, causing undefined behavior.
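
A hedged sketch of the pattern that triggers this (the `makeIterator` and `consume` helpers are placeholders, not library API):

```swift
// Two independent generation loops, each backed by its own TokenIterator.
let first = Task { for token in makeIterator(promptA) { consume(token) } }
let second = Task { for token in makeIterator(promptB) { consume(token) } }
_ = await (first.value, second.value)
// Each next() call issues asyncEval() internally; with both loops live,
// those evaluations can overlap inside MLX's evaluation engine.
```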

🔧 Solution

Added comprehensive thread safety at both the sampler and evaluation levels using NSLock to serialize access to shared MLX resources.

Key Changes

  1. CategoricalSampler: Added randomStateLock to protect global random state access
  2. TopPSampler: Added randomStateLock to protect global random state access
  3. TokenIterator: Added mlxEvalLock to serialize MLX evaluation operations (asyncEval, model.prepare)

Design Decisions

  • Minimal code changes: no breaking API changes; maintains backward compatibility
  • Static locks per component: Minimize memory overhead while ensuring proper synchronization
  • Focused critical sections: Lock only the essential operations to minimize contention

📊 Performance Impact

  • Sampler Level: Minimal impact (~0.001% overhead); sampling is <1% of total inference time
  • TokenIterator Level: Moderate impact; model evaluations are serialized, but throughput remains 3-4x better than pure serial processing

Observed Pattern: 10 concurrent requests complete in 3-4 batches rather than full parallelism, but with 100% stability.

```diff
@@ -133,6 +133,8 @@ public struct ArgMaxSampler: LogitSampler {
 
 /// Sampler that uses `topP` and `temperature` to sample the logits.
 public struct TopPSampler: LogitSampler {
+    private static let randomStateLock = NSLock()
```
Collaborator:

This won't protect it -- randomState is global and this is only protecting callers of TopPSampler. For example it will not guard against concurrent use in CategoricalSampler.

The better way to fix this would be to have random state scoped to the sampler itself, see:

To use locks, all callers of Random would have to use the same lock. Actually it is more complicated than that -- the calls to globalState are themselves thread safe:

but the calls to evaluate the resulting MLXArrays are not -- you need to guard the eval sites.
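
A minimal sketch of the scoped-random-state direction suggested here, assuming MLXRandom's `RandomState`/`withRandomState` API (the sampler type itself is illustrative, not the library's actual implementation):

```swift
import MLX
import MLXRandom

// Illustrative sampler that owns its random state instead of touching
// MLXRandom.globalState; each thread of execution creates its own instance.
struct ScopedCategoricalSampler {
    let temperature: Float
    private let state = MLXRandom.RandomState()

    func sample(logits: MLXArray) -> MLXArray {
        // Random ops inside the closure draw from this sampler's private
        // state, so concurrent samplers share nothing mutable.
        withRandomState(state) {
            categorical(logits * (1 / temperature))
        }
    }
}
```

Per the comment above, evaluating the sampled array would still need to be guarded, but the samplers would no longer contend on a single global state object.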

```diff
@@ -166,6 +168,10 @@ public struct TopPSampler: LogitSampler {
             logits = logits.asType(.float32)
         }
 
+        // Thread-safe sampling to prevent concurrent access to global random state
+        TopPSampler.randomStateLock.lock()
+        defer { TopPSampler.randomStateLock.unlock() }
```
Collaborator:

FWIW the typical way to use a lock like this is:

```swift
lock.withLock {
    compiledTopPSampling(...)
}
```

but see my other comment on the use of locks to guard this

```diff
@@ -267,6 +279,9 @@ public struct RepetitionContext: LogitProcessor {
 ///
 /// Note: this uses `asyncEval()` and there may be an async evaluation running after a call to `next()`.
 public struct TokenIterator: Sequence, IteratorProtocol {
+    // Global lock to protect MLX evaluation operations
+    private static let mlxEvalLock = NSLock()
```
Collaborator:

See:

This guards only concurrent calls in TokenIterator. In theory calls to eval() and asyncEval() should be thread safe as long as callers are using entirely distinct MLXArrays / compute graphs. In practice, that was never really a guarantee from mlx::core and in mlx-swift 0.25.1 we found new issues around this (changes on the core side).

The evalLock in mlx-swift is wider than just eval -- it has to guard a number of calls. It may be removed sometime in the future if we can restore the thread safe behavior in mlx::core.

Anyway, that said, this does not guard against concurrent use of the same model (which would still be a problem), nor does it add anything over the lock already in mlx-swift (that I can see, anyway).

@davidkoski (Collaborator)

I am curious about the use case where you encountered errors/crashes. I don't think the locks added here are the correct way to protect the state -- they are either too narrow (sampling) or redundant (evals in the iterator). If the use case is multiple threads evaluating the same model, then I don't think these are sufficient.

I do agree with your Problem statement -- there are thread safety concerns here, but I think we need different approaches if these are important to guard against. Many of the threading issues are guarded against, but perhaps not all. Can you please describe how you are encountering these?

@sxy-trans-n (Author)

@davidkoski

Thank you for the detailed explanation! This clarifies why our locks are insufficient.

Our Use Case

We're running Swama as an OpenAI-compatible HTTP server where multiple concurrent requests hit the same model instance. We currently serialize all model access to avoid crashes, but we want to enable parallelism for better throughput.
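
The serialization we do today is roughly the following shape (simplified sketch; `MyModel` and `runGeneration` are placeholders for our actual types, not library API):

```swift
// A single actor owns the model, so requests are handled one at a time
// and MLX evaluations never overlap.
actor SerializedModel {
    private let model: MyModel

    init(model: MyModel) { self.model = model }

    func respond(to prompt: String) async throws -> String {
        // Crash-free, but throughput is strictly serial.
        try await runGeneration(model: model, prompt: prompt)
    }
}
```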

The Problem

We hit [eval] Attempting to eval an array without a primitive crashes in two scenarios:

  1. Non-deterministic sampling (temp > 0): CategoricalSampler/TopPSampler concurrently access MLXRandom.globalState
  2. Even deterministic sampling (temp = 0): Multiple ArgMaxSampler requests cause concurrent asyncEval() calls to interfere

Questions

  1. Should we refactor samplers to use withRandomState instead of compile(inputs: [MLXRandom.globalState], ...)?
  2. For our HTTP server use case, what's the recommended approach - separate model instances per request or something else?
  3. Are the asyncEval() issues with ArgMaxSampler expected to be handled by MLX-Swift's existing evalLock?

Next Steps

If withRandomState is the right direction, we'd be happy to implement and test the fix in our fork, then submit a PR. We have a stress test that consistently reproduces the issues if that would be helpful.
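
For reference, the stress test is essentially this shape (simplified; `runInference` is a placeholder for our generation call against the shared model instance):

```swift
// Fire N concurrent requests at one model and wait for all of them;
// with temperature > 0 this reliably reproduces the crash for us.
func stressTest(concurrency: Int = 10) async {
    await withTaskGroup(of: Void.self) { group in
        for i in 0..<concurrency {
            group.addTask {
                _ = try? await runInference(prompt: "request \(i)", temperature: 0.8)
            }
        }
    }
}
```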

We appreciate your guidance on the proper solution!

@ronaldmannak (Contributor)

I'm very interested in concurrency support like what Ollama does, too. While there are some Swift concurrency issues, I wonder whether Swift concurrency is really the root cause of the problems you've encountered. I haven't looked into it in detail, but I assume shared state (e.g. the KV cache) is the main reason MLX Swift LM can't handle concurrent requests correctly. Let me know if I'm mistaken here.

@davidkoski (Collaborator)

> If withRandomState is the right direction, we'd be happy to implement and test the fix in our fork, then submit a PR. We have a stress test that consistently reproduces the issues if that would be helpful.

I think that is the right direction -- that will give the samplers independent random state, you just need to make sure that each thread of execution has its own samplers.

I think any issues you can find with the stress test would be awesome.

I believe that:

  • fully evaluated model weights
  • independent random state
  • independent KVCache state

Should give us multithreaded evaluation -- I have done something like this in the past where I had two VLMs running at once. You need to be careful of the prompt processing because that is submitting larger batches and those need to finish before the next piece of work can queue up.
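
A hedged sketch of that recipe (the `makeKVCache` and `generate` helpers are placeholders; the per-task sampler follows the scoped-random-state idea from the review comments above):

```swift
// Force-evaluate the shared weights once, before any concurrency begins.
eval(model)

await withTaskGroup(of: String.self) { group in
    for prompt in prompts {
        group.addTask {
            // Each task gets its own sampler and KV cache; only the
            // (read-only) weights are shared across tasks.
            let sampler = ScopedCategoricalSampler(temperature: 0.7)
            let cache = makeKVCache(for: model)
            return generate(prompt: prompt, model: model, sampler: sampler, cache: cache)
        }
    }
    for await text in group {
        print(text)
    }
}
```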

We can also set up some integration tests like this:

That can easily do some multithreaded evaluation and we can use this to show 1) how to do it and 2) make sure it keeps working as expected.
