[Core] Avoid KVCacheBlock.__eq__ invocations in FreeKVCacheBlockQueue #21005

Open: 3 commits to be merged into base: main

@JialinOuyang-Meta JialinOuyang-Meta commented Jul 15, 2025

Summary:

Optimizations

A common trick in doubly linked list implementations is to introduce fake (sentinel) head and tail nodes. This significantly reduces implementation overhead and makes it easy to get rid of the dataclass.__eq__ comparison.

  • No dataclass.__eq__ invocation
  • Shorter code
  • Branchless

Combined, these should yield a significant performance improvement for this piece of code.
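
The sentinel-node trick can be sketched as follows. This is a minimal illustration, not the code in this PR: `Block`, `SentinelFreeQueue`, `iter_free`, and the `EQ_CALLS` counter are hypothetical names chosen for the example. Because the fake head and tail always exist, every linked block has two real neighbours, so unlinking and appending need no `None` checks and no comparisons against the head or tail.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterator, Optional

EQ_CALLS = 0  # counts how often Block.__eq__ runs (illustration only)


@dataclass
class Block:
    """Hypothetical stand-in for a KV cache block, not vLLM's class."""

    block_id: int
    prev: Optional[Block] = field(default=None, repr=False)
    next: Optional[Block] = field(default=None, repr=False)

    def __eq__(self, other: object) -> bool:
        # Explicit __eq__ so the sketch can count invocations.
        global EQ_CALLS
        EQ_CALLS += 1
        return isinstance(other, Block) and self.block_id == other.block_id


class SentinelFreeQueue:
    """Doubly linked free list with fake head and tail nodes."""

    def __init__(self, blocks: list[Block]) -> None:
        self.num_free = len(blocks)
        self.head = Block(block_id=-1)  # fake head, never handed out
        self.tail = Block(block_id=-1)  # fake tail, never handed out
        prev = self.head
        for block in blocks:
            prev.next = block
            block.prev = prev
            prev = block
        prev.next = self.tail
        self.tail.prev = prev

    def remove(self, block: Block) -> None:
        # Both neighbours always exist (they may be sentinels),
        # so relinking needs no branches and no == comparisons.
        block.prev.next = block.next
        block.next.prev = block.prev
        block.prev = block.next = None
        self.num_free -= 1

    def popleft(self) -> Block:
        if self.num_free == 0:
            raise ValueError("no free blocks")
        first = self.head.next
        self.remove(first)
        return first

    def append(self, block: Block) -> None:
        last = self.tail.prev
        last.next = block
        block.prev = last
        block.next = self.tail
        self.tail.prev = block
        self.num_free += 1

    def iter_free(self) -> Iterator[Block]:
        # Walk from the first real node up to (not including) the tail.
        node = self.head.next
        while node is not self.tail:
            yield node
            node = node.next
```

In this sketch the only remaining branch is the empty-queue guard in `popleft`; the pointer updates themselves are unconditional, and `EQ_CALLS` stays at zero because nodes are only ever compared by identity (`is`).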

Observations

Per vLLM profiling, kv_cache_manager.allocate_slots accounts for a non-negligible share of the cost of each prefill.
[Screenshot: profiling trace]

Zooming in, the FreeKVCacheBlockQueue.popleft stack is non-trivial: popleft -> remove -> string.__eq__, which mainly comes from the dataclass (i.e. KVCacheBlock) equality comparison.

Per the Python dataclasses library documentation (https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass):

dataclasses.dataclass(*, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False)

eq: If true (the default), an __eq__() method will be generated. This method compares the class as if it were a tuple of its fields, in order. Both instances in the comparison must be of the identical type.

If the class already defines __eq__(), this parameter is ignored.
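
The documented behaviour quoted above can be checked with a small, self-contained example. `Point` and `IdentityPoint` are hypothetical classes used only for illustration. The example shows that `==` on a default dataclass compares field values, while `is` (and `==` on a class declared with `eq=False`) compares object identity and never touches the fields.

```python
from dataclasses import dataclass


@dataclass
class Point:
    """Default dataclass: __eq__ is generated (eq=True)."""

    x: int
    name: str


@dataclass(eq=False)
class IdentityPoint:
    """No generated __eq__: falls back to object identity."""

    x: int
    name: str


a, b = Point(1, "p"), Point(1, "p")
value_equal = a == b            # generated __eq__ compares (x, name) tuples
same_object = a is b            # identity check, no field access
different_type = a == (1, "p")  # a tuple is not a Point, so not equal

c, d = IdentityPoint(1, "p"), IdentityPoint(1, "p")
identity_equal = c == d         # object.__eq__: equal only if same object

print(value_equal, same_object, different_type, identity_equal)
# prints: True False False False
```

An identity check with `is` is enough whenever the question is "is this the very same node?", which is the case for linked-list bookkeeping.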

Test Plan:

Result

Typically, block_size is set to 16, so in production usage we would likely allocate 10-1000 blocks. In this range, the optimization saved up to ~1ms of TTFT (the improvement is more significant on long inputs).
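
To relate the 10-1000 block range to input length, the arithmetic can be written out as below. The helper names `blocks_needed` and `tokens_covered` are hypothetical, and the sketch assumes each block holds exactly block_size tokens with the last block possibly partially filled.

```python
BLOCK_SIZE = 16  # typical block_size, as stated in the text


def blocks_needed(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    """Blocks required to hold num_tokens (ceiling division)."""
    return (num_tokens + block_size - 1) // block_size


def tokens_covered(num_blocks: int, block_size: int = BLOCK_SIZE) -> int:
    """Maximum number of tokens that num_blocks can hold."""
    return num_blocks * block_size


print(tokens_covered(10), tokens_covered(1000))  # prints: 160 16000
print(blocks_needed(2048))                       # prints: 128
```

Under these assumptions, the 10-1000 block range corresponds to inputs of up to roughly 160 to 16,000 tokens.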

Benchmark

After: [screenshot]
Before: [screenshot]

Stack

After: [screenshot]
Before: [screenshot]

Rollback Plan:

Reviewed By: CuiCoco

Differential Revision: D78292345

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@facebook-github-bot
This pull request was exported from Phabricator. Differential Revision: D78292345

@mergify mergify bot added the performance (Performance-related issues) and v1 labels Jul 15, 2025

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant performance optimization to the FreeKVCacheBlockQueue by implementing a doubly linked list with sentinel nodes. This change effectively removes expensive __eq__ comparisons on KVCacheBlock dataclasses, which should improve performance as demonstrated by the new benchmark. The implementation is a classic and well-executed approach.

My review focuses on ensuring the robustness of this new implementation. I've identified a couple of areas where adding validation checks could prevent potential crashes from state inconsistencies, making the system more resilient. These changes should have a negligible performance impact while significantly improving debuggability and correctness guarantees.

JialinOuyang-Meta added a commit to JialinOuyang-Meta/vllm-jialino that referenced this pull request Jul 15, 2025
…project#21005)

JialinOuyang-Meta added a commit to JialinOuyang-Meta/vllm-jialino that referenced this pull request Jul 15, 2025
…project#21005)

Summary:
Pull Request resolved: vllm-project#21005

@JialinOuyang-Meta JialinOuyang-Meta changed the title from "Avoid KVCacheBlock.__eq__ invocations in FreeKVCacheBlockQueue" to "[Core] Avoid KVCacheBlock.__eq__ invocations in FreeKVCacheBlockQueue" Jul 15, 2025
Signed-off-by: Jialin Ouyang <jialino@meta.com>
Labels
performance Performance-related issues v1
2 participants