File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -2918,13 +2918,15 @@ def bind_kv_cache(
2918
2918
extract_layer_index (layer_name )
2919
2919
for layer_name in layer_need_kv_cache ))
2920
2920
2921
+ # Map from layer_name to the kv cache layer idx.
2922
+ layer_name_2_kv_cache_index = dict ()
2921
2923
for layer_name in layer_need_kv_cache :
2922
- # 1. Get the kv_cache_idx of the target_layer_name.
2923
2924
target_layer_name = shared_kv_cache_layers .get (layer_name , layer_name )
2924
2925
kv_cache_idx = layer_index_sorted .index (
2925
2926
extract_layer_index (target_layer_name ))
2927
+ layer_name_2_kv_cache_index [layer_name ] = kv_cache_idx
2926
2928
2927
- # 2. Bind kv_cache to forward_ctx.
2929
+ for layer_name , kv_cache_idx in layer_name_2_kv_cache_index . items ():
2928
2930
forward_ctx = ctx [layer_name ]
2929
2931
assert len (forward_ctx .kv_cache ) == len (kv_cache )
2930
2932
for ve , ve_kv_cache in enumerate (kv_cache ):
You can’t perform that action at this time.
0 commit comments