@@ -54,6 +54,7 @@ def get_kv_cache_shape(
54
54
) -> Tuple [int , ...]:
55
55
if block_size % 16 != 0 :
56
56
raise ValueError ("Block size must be a multiple of 16." )
57
+ assert num_kv_heads % 2 == 0 , "num_kv_heads must be divisible by 2"
57
58
return (2 , 2 , num_blocks , block_size , num_kv_heads // 2 , head_size )
58
59
59
60
@staticmethod
@@ -872,7 +873,7 @@ def forward(
872
873
k1 , k2 = self .split_heads (k )
873
874
v1 , v2 = self .split_heads (v )
874
875
875
- # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size)
876
+ # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)
876
877
# Split by half along the first dimension.
877
878
kv_cache1 , kv_cache2 = self .split_kv_cache (kv_cache )
878
879
assert kv_cache1 .is_contiguous (), "kv_cache1 is not contiguous"
@@ -909,7 +910,7 @@ def forward(
909
910
else : # re-use the kv cache, full attention
910
911
q = q .view (- 1 , self .num_heads , self .head_size )
911
912
q1 , q2 = self .split_heads (q )
912
- # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size)
913
+ # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)
913
914
kv_cache1 , kv_cache2 = self .split_kv_cache (kv_cache )
914
915
key_cache1 , value_cache1 = kv_cache1 [0 ], kv_cache1 [1 ]
915
916
key_cache2 , value_cache2 = kv_cache2 [0 ], kv_cache2 [1 ]
0 commit comments