Skip to content

Commit 071c639

Browse files
vMaroon authored and py-andy-c committed
[Prefix Cache] Add reproducible prefix-cache block hashing using SHA-256 + CBOR (64bit) (vllm-project#20511)
Signed-off-by: Maroon Ayoub <maroon.ayoub@ibm.com>
1 parent e545a3a commit 071c639

File tree

8 files changed

+88
-28
lines changed

8 files changed

+88
-28
lines changed

requirements/common.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,4 @@ python-json-logger # Used by logging as per examples/others/logging_configuratio
4747
scipy # Required for phi-4-multimodal-instruct
4848
ninja # Required for xgrammar, rocm, tpu, xpu
4949
pybase64 # fast base64 implementation
50+
cbor2 # Required for cross-language serialization of hashable objects

requirements/docs.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ ruff
1111
# Required for argparse hook only
1212
-f https://download.pytorch.org/whl/cpu
1313
cachetools
14+
cbor2
1415
cloudpickle
1516
fastapi
1617
msgspec

tests/v1/core/test_kv_cache_utils.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,16 @@
88
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
99
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
1010
from vllm.sampling_params import SamplingParams
11-
from vllm.utils import GiB_bytes, sha256
11+
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
1212
from vllm.v1.core.kv_cache_manager import KVCacheManager
1313
# disable yapf here as it formats differently than isort such that both fail
1414
# yapf: disable
1515
from vllm.v1.core.kv_cache_utils import (
1616
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
1717
estimate_max_model_len, generate_block_hash_extra_keys,
1818
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
19-
hash_block_tokens, hash_request_tokens, unify_kv_cache_configs)
19+
hash_block_tokens, hash_request_tokens, init_none_hash,
20+
unify_kv_cache_configs)
2021
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
2122
KVCacheGroupSpec, KVCacheTensor,
2223
SlidingWindowSpec)
@@ -78,24 +79,27 @@ def new_sliding_window_spec(block_size=16,
7879
sliding_window=sliding_window)
7980

8081

81-
def test_none_hash(monkeypatch):
82+
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
83+
def test_none_hash(monkeypatch, hash_fn):
8284
import vllm.v1.core.kv_cache_utils
8385

8486
# case 1: PYTHONHASHSEED is not set, use random
8587
with monkeypatch.context() as m:
8688
m.delenv('PYTHONHASHSEED', raising=False)
8789
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
90+
reloaded_kv_cache_utils.init_none_hash(hash_fn)
8891
assert reloaded_kv_cache_utils.NONE_HASH is not None
8992
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
9093
assert reloaded_kv_cache_utils.NONE_HASH != 0
9194

92-
# case 2: PYTHONHASHSEED is set, use the seed
95+
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
9396
with monkeypatch.context() as m:
9497
m.setenv('PYTHONHASHSEED', 'python hash seed')
9598
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
99+
reloaded_kv_cache_utils.init_none_hash(hash_fn)
96100
assert reloaded_kv_cache_utils.NONE_HASH is not None
97101
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
98-
assert sha256('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
102+
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
99103

100104

101105
def test_kv_cache_block():
@@ -287,9 +291,10 @@ def test_generate_block_hash_extra_keys_cache_salt():
287291
assert next_mm_idx == 1
288292

289293

290-
@pytest.mark.parametrize("hash_fn", [sha256, hash])
294+
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
291295
def test_hash_block_tokens(hash_fn):
292296
import vllm.v1.core.kv_cache_utils
297+
init_none_hash(hash_fn)
293298
parent_block_hash = 123
294299
curr_block_token_ids = (1, 2, 3)
295300
extra_keys = ("key1", "key2")
@@ -303,9 +308,10 @@ def test_hash_block_tokens(hash_fn):
303308
assert block_hash.extra_keys == extra_keys
304309

305310

306-
@pytest.mark.parametrize("hash_fn", [sha256, hash])
311+
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
307312
def test_hash_request_tokens(hash_fn):
308313
import vllm.v1.core.kv_cache_utils
314+
init_none_hash(hash_fn)
309315
request = make_request(
310316
request_id=0,
311317
prompt_token_ids=[_ for _ in range(6)],
@@ -332,8 +338,10 @@ def test_hash_request_tokens(hash_fn):
332338
assert block_hashes[1].extra_keys == ("hash2", )
333339

334340

335-
@pytest.mark.parametrize("hash_fn", [sha256, hash])
341+
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
336342
def test_hash_tokens_different_mm_input(hash_fn):
343+
init_none_hash(hash_fn)
344+
337345
request1 = make_request(
338346
request_id=0,
339347
prompt_token_ids=[_ for _ in range(6)],
@@ -359,8 +367,10 @@ def test_hash_tokens_different_mm_input(hash_fn):
359367
assert block_hashes1[1] != block_hashes2[1]
360368

361369

362-
@pytest.mark.parametrize("hash_fn", [sha256, hash])
370+
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
363371
def test_hash_request_tokens_no_mm_inputs(hash_fn):
372+
init_none_hash(hash_fn)
373+
364374
request = make_request(
365375
request_id=0,
366376
prompt_token_ids=[_ for _ in range(6)],
@@ -916,4 +926,4 @@ def test_get_kv_cache_config():
916926
],
917927
kv_cache_groups=[
918928
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
919-
])
929+
])

tests/v1/core/test_prefix_caching.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
1212
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
1313
from vllm.sampling_params import SamplingParams
14-
from vllm.utils import sha256
14+
from vllm.utils import sha256, sha256_cbor_64bit
1515
from vllm.v1.core.block_pool import BlockPool
1616
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
1717
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
18-
KVCacheBlock, hash_block_tokens)
18+
KVCacheBlock, hash_block_tokens,
19+
init_none_hash)
1920
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
2021
KVCacheGroupSpec, SlidingWindowSpec)
2122

@@ -91,7 +92,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
9192
)
9293

9394

94-
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
95+
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
9596
def test_prefill(hash_algo):
9697
manager = KVCacheManager(
9798
make_kv_cache_config(16, 11),
@@ -101,7 +102,8 @@ def test_prefill(hash_algo):
101102
)
102103

103104
# choose the hash function according to the parameter
104-
hash_fn = sha256 if hash_algo == "sha256" else hash
105+
hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
106+
sha256 if hash_algo == "sha256" else hash)
105107

106108
# Complete 3 blocks (48 tokens)
107109
common_token_ids = [i for i in range(3) for _ in range(16)]
@@ -696,12 +698,14 @@ def test_basic_prefix_caching_disabled():
696698
assert not blocks
697699

698700

699-
@pytest.mark.parametrize("hash_fn", [sha256, hash])
701+
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
700702
def test_cache_blocks(hash_fn):
701703
"""
702704
This is a unit test that tests the correctness of the _cache_full_blocks
703705
function of KVCacheManager.
704706
"""
707+
init_none_hash(hash_fn)
708+
705709
block_size = 4
706710
block_pool = BlockPool(
707711
num_gpu_blocks=5,

vllm/config.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,7 +1564,7 @@ def get_and_verify_max_len(self, max_model_len: int):
15641564

15651565
BlockSize = Literal[1, 8, 16, 32, 64, 128]
15661566
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
1567-
PrefixCachingHashAlgo = Literal["builtin", "sha256"]
1567+
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
15681568

15691569

15701570
@config
@@ -1609,7 +1609,12 @@ class CacheConfig:
16091609
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
16101610
"""Set the hash algorithm for prefix caching:\n
16111611
- "builtin" is Python's built-in hash.\n
1612-
- "sha256" is collision resistant but with certain overheads."""
1612+
- "sha256" is collision resistant but with certain overheads.
1613+
This option uses Pickle for object serialization before hashing.\n
1614+
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
1615+
hash. It serializes objects using canonical CBOR and hashes them with
1616+
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
1617+
digest."""
16131618
cpu_offload_gb: float = 0
16141619
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
16151620
no offloading. Intuitively, this argument can be seen as a virtual way to

vllm/utils/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from uuid import uuid4
5353

5454
import cachetools
55+
import cbor2
5556
import cloudpickle
5657
import numpy as np
5758
import numpy.typing as npt
@@ -3177,6 +3178,29 @@ def sha256(input) -> int:
31773178
byteorder="big")
31783179

31793180

3181+
def sha256_cbor_64bit(input) -> int:
    """
    Compute a 64-bit hash of an object from its canonical CBOR encoding.

    The object is serialized with ``cbor2.dumps(..., canonical=True)``, the
    resulting bytes are hashed with SHA-256, and only the lower 64 bits of
    the digest (interpreted as a big-endian integer) are returned.

    This option is useful for non-Python-dependent serialization and hashing.

    Args:
        input: Object to be serialized and hashed. It must be serializable
            by ``cbor2``; custom classes must implement CBOR serialization
            methods.

    Returns:
        An integer in the range [0, 2^64-1] representing the lower 64 bits
        of the SHA-256 hash of the CBOR serialized input.
    """
    digest = hashlib.sha256(cbor2.dumps(input, canonical=True)).digest()
    # The final 8 bytes of the big-endian digest are exactly its 64
    # low-order bits, so slicing is equivalent to masking the full integer.
    return int.from_bytes(digest[-8:], byteorder="big")
3202+
3203+
31803204
def is_torch_equal_or_newer(target: str) -> bool:
31813205
"""Check if the installed torch version is >= the target version.
31823206

vllm/v1/core/kv_cache_manager.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
from vllm.distributed.kv_events import KVCacheEvent
99
from vllm.logger import init_logger
10-
from vllm.utils import sha256
10+
from vllm.utils import sha256, sha256_cbor_64bit
1111
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
1212
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
13-
hash_request_tokens)
13+
hash_request_tokens, init_none_hash)
1414
from vllm.v1.kv_cache_interface import KVCacheConfig
1515
from vllm.v1.metrics.stats import PrefixCacheStats
1616
from vllm.v1.request import Request, RequestStatus
@@ -79,7 +79,10 @@ def __init__(
7979
self.max_model_len = max_model_len
8080

8181
self.enable_caching = enable_caching
82-
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
82+
self.caching_hash_fn = (
83+
sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
84+
sha256 if caching_hash_algo == "sha256" else hash)
85+
init_none_hash(self.caching_hash_fn)
8386
self.use_eagle = use_eagle
8487
self.log_stats = log_stats
8588
# FIXME: make prefix cache stats conditional on log_stats

vllm/v1/core/kv_cache_utils.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from vllm.config import VllmConfig
1212
from vllm.logger import init_logger
13-
from vllm.utils import GiB_bytes, cdiv, sha256
13+
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
1414
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
1515
KVCacheGroupSpec, KVCacheSpec,
1616
KVCacheTensor, SlidingWindowSpec)
@@ -46,18 +46,30 @@ def get_hash_value(self) -> int:
4646
return self.block_hash.hash_value
4747

4848

49-
# The hash seed for the first block of the prefix block sequence.
50-
#
51-
# Even if the hash function is the builtin hash(), we use sha256 to generate
52-
# the initial hash to simplify the code. This is not performance critical
53-
# as it is done one per process.
49+
# The hash seed for the first block of any prefix block sequence.
5450
#
5551
# We use a random value to avoid hash collisions, or the PYTHONHASHSEED
5652
# environment variable if set, so that processes can share the seed if needed.
5753
# This aligns with the behavior of Python's hash() function, which also uses
5854
# a random seed if PYTHONHASHSEED is not set.
59-
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
60-
"PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))
55+
#
56+
# The function `init_none_hash` initializes this variable globally.
57+
NONE_HASH: int
58+
59+
60+
def init_none_hash(hash_fn: Callable) -> None:
    """Initialize the module-level ``NONE_HASH`` seed value.

    ``NONE_HASH`` is set as follows:

    - If the ``PYTHONHASHSEED`` environment variable is set, its string value
      is passed to ``hash_fn`` and the result becomes ``NONE_HASH``.
    - Otherwise, ``NONE_HASH`` is an integer built from 32 bytes of
      ``os.urandom``, so it differs on every call.

    A warning is logged when ``PYTHONHASHSEED`` is unset and ``hash_fn`` is
    ``sha256_cbor_64bit``, because the random seed then makes the block
    hashes non-reproducible.

    Args:
        hash_fn: Callable applied to the ``PYTHONHASHSEED`` string to derive
            the seed hash.
    """
    global NONE_HASH

    hash_seed = os.getenv("PYTHONHASHSEED")
    if hash_seed is None:
        if hash_fn is sha256_cbor_64bit:
            # Adjacent literals are concatenated verbatim, so each fragment
            # must carry its own trailing space.
            logger.warning(
                "PYTHONHASHSEED is not set. This will lead to non-reproducible "
                "block-hashes when using sha256_cbor_64bit as the hash "
                "function. Consider setting PYTHONHASHSEED to a fixed value "
                "for reproducibility.")
        NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big")
    else:
        NONE_HASH = hash_fn(hash_seed)
6173

6274

6375
class PrefixCachingMetrics:

0 commit comments

Comments
 (0)