File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -2909,13 +2909,15 @@ def bind_kv_cache(
2909
2909
extract_layer_index (layer_name )
2910
2910
for layer_name in layer_need_kv_cache ))
2911
2911
2912
+ # Map from layer_name to the kv cache layer idx.
2913
+ layer_name_2_kv_cache_index = dict ()
2912
2914
for layer_name in layer_need_kv_cache :
2913
- # 1. Get the kv_cache_idx of the target_layer_name.
2914
2915
target_layer_name = shared_kv_cache_layers .get (layer_name , layer_name )
2915
2916
kv_cache_idx = layer_index_sorted .index (
2916
2917
extract_layer_index (target_layer_name ))
2918
+ layer_name_2_kv_cache_index [layer_name ] = kv_cache_idx
2917
2919
2918
- # 2. Bind kv_cache to forward_ctx.
2920
+ for layer_name , kv_cache_idx in layer_name_2_kv_cache_index . items ():
2919
2921
forward_ctx = ctx [layer_name ]
2920
2922
assert len (forward_ctx .kv_cache ) == len (kv_cache )
2921
2923
for ve , ve_kv_cache in enumerate (kv_cache ):
You can’t perform that action at this time.
0 commit comments