
Commit 1e9438e

[MISC] Move bind_kv_cache to worker module (#20900)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 697ef76 commit 1e9438e
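
In short, bind_kv_cache moves from vllm/v1/utils.py into vllm/v1/worker/utils.py with its body unchanged, so callers only switch the import path, as the per-file diffs below show:

    # before this commit
    from vllm.v1.utils import bind_kv_cache

    # after this commit
    from vllm.v1.worker.utils import bind_kv_cache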

File tree

6 files changed: +57 −55 lines

  tests/v1/test_utils.py
  vllm/v1/utils.py
  vllm/v1/worker/gpu_model_runner.py
  vllm/v1/worker/tpu_model_runner.py
  vllm/v1/worker/tpu_worker.py
  vllm/v1/worker/utils.py

tests/v1/test_utils.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 
 import torch
 
-from vllm.v1.utils import bind_kv_cache
+from vllm.v1.worker.utils import bind_kv_cache
 
 
 def test_bind_kv_cache():

vllm/v1/utils.py

Lines changed: 0 additions & 48 deletions

@@ -4,7 +4,6 @@
 import multiprocessing
 import time
 import weakref
-from collections import defaultdict
 from collections.abc import Sequence
 from multiprocessing import connection
 from multiprocessing.process import BaseProcess
@@ -14,14 +13,12 @@
 import torch
 
 from vllm.logger import init_logger
-from vllm.model_executor.models.utils import extract_layer_index
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
 from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
                         kill_process_tree)
 
 if TYPE_CHECKING:
-    from vllm.attention.layer import Attention
     from vllm.v1.engine.coordinator import DPCoordinator
     from vllm.v1.engine.utils import (CoreEngineActorManager,
                                       CoreEngineProcManager)
@@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
             kill_process_tree(pid)
 
 
-def bind_kv_cache(
-    kv_caches: dict[str, torch.Tensor],
-    forward_context: dict[str, "Attention"],
-    runner_kv_caches: list[torch.Tensor],
-) -> None:
-    """
-    Bind the allocated KV cache to both ModelRunner and forward context so
-    that the KV cache can be used in the forward pass.
-
-    This function:
-      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
-         kv_caches.
-      2) Associates each attention layer in the `forward_context` with its
-         corresponding KV cache in kv_caches.
-
-    Args:
-        kv_caches: The allocated kv_caches with layer names as keys.
-        forward_context: The global forward context containing all Attention
-            layers with layer names as keys.
-        runner_kv_caches: The kv_cache declared by ModelRunner.
-    """
-    # Bind kv_caches to ModelRunner
-    assert len(runner_kv_caches) == 0
-
-    # Convert kv_caches dict to a list of tensors in the order of layer_index.
-    index2name = defaultdict(list)
-    for layer_name in kv_caches:
-        index2name[extract_layer_index(layer_name)].append(layer_name)
-
-    for layer_index in sorted(index2name.keys()):
-        layer_names = index2name[layer_index]
-        if len(layer_names) > 1:
-            # One typical case is encoder-decoder model, e.g., bart.
-            # The cross attention and self attention in the same decoder layer
-            # has different layer_name but the same layer_index.
-            raise NotImplementedError
-        layer_name = layer_names[0]
-        runner_kv_caches.append(kv_caches[layer_name])
-
-    # Bind kv_caches to forward context
-    for layer_name, kv_cache in kv_caches.items():
-        # NOTE: Use list because of v0 PP virtual engine.
-        forward_context[layer_name].kv_cache = [kv_cache]
-
-
 def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
                length: int) -> torch.Tensor:
     """

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 2 deletions

@@ -62,13 +62,13 @@
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
+from .utils import (bind_kv_cache, gather_mm_placeholders,
+                    initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
 
 if TYPE_CHECKING:

vllm/v1/worker/tpu_model_runner.py

Lines changed: 1 addition & 2 deletions

@@ -42,11 +42,10 @@
                              LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
 
-from .utils import (initialize_kv_cache_for_kv_sharing,
+from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs)
 
 if TYPE_CHECKING:

vllm/v1/worker/tpu_worker.py

Lines changed: 2 additions & 1 deletion

@@ -25,8 +25,9 @@
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache, report_usage_stats
+from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+from vllm.v1.worker.utils import bind_kv_cache
 
 logger = init_logger(__name__)
 

vllm/v1/worker/utils.py

Lines changed: 51 additions & 1 deletion

@@ -1,12 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from collections import defaultdict
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec
 
+if TYPE_CHECKING:
+    from vllm.attention.layer import Attention
+
 
 def sanity_check_mm_encoder_outputs(
     mm_embeddings: MultiModalEmbeddings,
@@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
         kv_caches[layer_name] = kv_caches[target_layer_name]
         group_idx = layer_to_kv_cache_group_idx[target_layer_name]
         kv_cache_groups[group_idx].layer_names.append(layer_name)
+
+
+def bind_kv_cache(
+    kv_caches: dict[str, torch.Tensor],
+    forward_context: dict[str, "Attention"],
+    runner_kv_caches: list[torch.Tensor],
+) -> None:
+    """
+    Bind the allocated KV cache to both ModelRunner and forward context so
+    that the KV cache can be used in the forward pass.
+
+    This function:
+      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
+         kv_caches.
+      2) Associates each attention layer in the `forward_context` with its
+         corresponding KV cache in kv_caches.
+
+    Args:
+        kv_caches: The allocated kv_caches with layer names as keys.
+        forward_context: The global forward context containing all Attention
+            layers with layer names as keys.
+        runner_kv_caches: The kv_cache declared by ModelRunner.
+    """
+    # Bind kv_caches to ModelRunner
+    assert len(runner_kv_caches) == 0
+
+    # Convert kv_caches dict to a list of tensors in the order of layer_index.
+    index2name = defaultdict(list)
+    for layer_name in kv_caches:
+        index2name[extract_layer_index(layer_name)].append(layer_name)
+
+    for layer_index in sorted(index2name.keys()):
+        layer_names = index2name[layer_index]
+        if len(layer_names) > 1:
+            # One typical case is encoder-decoder model, e.g., bart.
+            # The cross attention and self attention in the same decoder layer
+            # has different layer_name but the same layer_index.
+            raise NotImplementedError
+        layer_name = layer_names[0]
+        runner_kv_caches.append(kv_caches[layer_name])
+
+    # Bind kv_caches to forward context
+    for layer_name, kv_cache in kv_caches.items():
+        # NOTE: Use list because of v0 PP virtual engine.
+        forward_context[layer_name].kv_cache = [kv_cache]
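
For orientation, here is a minimal, hypothetical usage sketch of the relocated helper. Only the new import path and the function's contract (runner_kv_caches must start empty; one tensor per layer index; each forward-context entry gets its tensor wrapped in a list) come from the diff above; the layer names, tensor shape, and the SimpleNamespace stand-ins for Attention modules are illustrative assumptions, not code from the vLLM runners.

    # Hypothetical sketch (assumes vLLM is installed); names and shapes are made up.
    from types import SimpleNamespace

    import torch

    from vllm.v1.worker.utils import bind_kv_cache  # new import path

    # Allocated KV caches keyed by layer name; each name carries exactly one
    # integer segment so extract_layer_index() can order them.
    kv_caches = {
        "model.layers.0.self_attn.attn": torch.zeros(2, 16, 8, 64),
        "model.layers.1.self_attn.attn": torch.zeros(2, 16, 8, 64),
    }

    # Stand-ins for the Attention modules in the global forward context;
    # bind_kv_cache only assigns their .kv_cache attribute.
    forward_context = {name: SimpleNamespace(kv_cache=None) for name in kv_caches}

    # Must start empty; bind_kv_cache asserts this before filling it in layer order.
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_caches, forward_context, runner_kv_caches)

    assert len(runner_kv_caches) == 2
    assert forward_context["model.layers.0.self_attn.attn"].kv_cache[0] is \
        kv_caches["model.layers.0.self_attn.attn"]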
