Skip to content

Commit 4e92996

Browse files
refactor
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
1 parent 4303b78 commit 4e92996

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

vllm/utils/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2909,13 +2909,15 @@ def bind_kv_cache(
29092909
extract_layer_index(layer_name)
29102910
for layer_name in layer_need_kv_cache))
29112911

2912+
# Map from layer_name to the kv cache layer idx.
2913+
layer_name_2_kv_cache_index = dict()
29122914
for layer_name in layer_need_kv_cache:
2913-
# 1. Get the kv_cache_idx of the target_layer_name.
29142915
target_layer_name = shared_kv_cache_layers.get(layer_name, layer_name)
29152916
kv_cache_idx = layer_index_sorted.index(
29162917
extract_layer_index(target_layer_name))
2918+
layer_name_2_kv_cache_index[layer_name] = kv_cache_idx
29172919

2918-
# 2. Bind kv_cache to forward_ctx.
2920+
for layer_name, kv_cache_idx in layer_name_2_kv_cache_index.items():
29192921
forward_ctx = ctx[layer_name]
29202922
assert len(forward_ctx.kv_cache) == len(kv_cache)
29212923
for ve, ve_kv_cache in enumerate(kv_cache):

0 commit comments

Comments
 (0)