Skip to content

Commit 1ff7eeb

Browse files
refactor
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent 88cb796 commit 1ff7eeb

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

vllm/utils/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2918,13 +2918,15 @@ def bind_kv_cache(
29182918
extract_layer_index(layer_name)
29192919
for layer_name in layer_need_kv_cache))
29202920

2921+
# Map from layer_name to the kv cache layer idx.
2922+
layer_name_2_kv_cache_index = dict()
29212923
for layer_name in layer_need_kv_cache:
2922-
# 1. Get the kv_cache_idx of the target_layer_name.
29232924
target_layer_name = shared_kv_cache_layers.get(layer_name, layer_name)
29242925
kv_cache_idx = layer_index_sorted.index(
29252926
extract_layer_index(target_layer_name))
2927+
layer_name_2_kv_cache_index[layer_name] = kv_cache_idx
29262928

2927-
# 2. Bind kv_cache to forward_ctx.
2929+
for layer_name, kv_cache_idx in layer_name_2_kv_cache_index.items():
29282930
forward_ctx = ctx[layer_name]
29292931
assert len(forward_ctx.kv_cache) == len(kv_cache)
29302932
for ve, ve_kv_cache in enumerate(kv_cache):

0 commit comments

Comments
 (0)