Skip to content

Commit af72202

Browse files
lgeiger authored and wwl2755-google committed
[Core] Do not copy array during hashing (vllm-project#19484)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
1 parent 9f814c9 commit af72202

File tree

3 files changed

+25
-11
lines changed

3 files changed

+25
-11
lines changed

tests/multimodal/test_hasher.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,3 +60,15 @@ def test_hash_collision_array_shape():
6060

6161
hasher = MultiModalHasher
6262
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
63+
64+
65+
def test_hash_non_contiguous_array():
66+
arr = np.arange(24).reshape(4, 6).T
67+
assert not arr.flags.c_contiguous
68+
69+
arr_c = np.ascontiguousarray(arr)
70+
assert arr_c.flags.c_contiguous
71+
72+
hasher = MultiModalHasher
73+
# Both should be hashable and produce the same hashes
74+
assert hasher.hash_kwargs(data=arr) == hasher.hash_kwargs(data=arr_c)

vllm/multimodal/hasher.py

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
import pickle
55
from collections.abc import Iterable, Mapping
6+
from typing import Union
67

78
import numpy as np
89
import torch
@@ -23,11 +24,11 @@
2324
class MultiModalHasher:
2425

2526
@classmethod
26-
def serialize_item(cls, obj: object) -> bytes:
27+
def serialize_item(cls, obj: object) -> Union[bytes, memoryview]:
2728
# Simple cases
2829
if isinstance(obj, str):
2930
return obj.encode("utf-8")
30-
if isinstance(obj, bytes):
31+
if isinstance(obj, (bytes, memoryview)):
3132
return obj
3233
if isinstance(obj, (int, float)):
3334
return np.array(obj).tobytes()
@@ -38,12 +39,13 @@ def serialize_item(cls, obj: object) -> bytes:
3839
if isinstance(obj, torch.Tensor):
3940
return cls.item_to_bytes("tensor", obj.numpy())
4041
if isinstance(obj, np.ndarray):
41-
return cls.item_to_bytes(
42-
"ndarray", {
43-
"dtype": obj.dtype.str,
44-
"shape": obj.shape,
45-
"data": obj.tobytes(),
46-
})
42+
# If the array is non-contiguous, we need to copy it first
43+
arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
44+
return cls.item_to_bytes("ndarray", {
45+
"dtype": obj.dtype.str,
46+
"shape": obj.shape,
47+
"data": arr_data,
48+
})
4749

4850
logger.warning(
4951
"No serialization method found for %s. "
@@ -64,7 +66,7 @@ def iter_item_to_bytes(
6466
cls,
6567
key: str,
6668
obj: object,
67-
) -> Iterable[tuple[bytes, bytes]]:
69+
) -> Iterable[tuple[bytes, Union[bytes, memoryview]]]:
6870
# Recursive cases
6971
if isinstance(obj, (list, tuple)):
7072
for i, elem in enumerate(obj):
@@ -73,7 +75,7 @@ def iter_item_to_bytes(
7375
for k, v in obj.items():
7476
yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
7577
else:
76-
key_bytes = cls.serialize_item(key)
78+
key_bytes = key.encode("utf-8")
7779
value_bytes = cls.serialize_item(obj)
7880
yield key_bytes, value_bytes
7981

vllm/v1/serial_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,7 @@ def _encode_ndarray(
140140
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
141141
assert self.aux_buffers is not None
142142
# If the array is non-contiguous, we need to copy it first
143-
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
143+
arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
144144
if not obj.shape or obj.nbytes < self.size_threshold:
145145
# Encode small arrays and scalars inline. Using this extension type
146146
# ensures we can avoid copying when decoding.

0 commit comments

Comments
 (0)