
Commit f11e4d3

[Executorch][llm] Fix ring kv cache when used with quantized kv cache and sdpa (#12143)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #12132 by @kimishpatel (use this as the source of truth for the PR details, comments, and reviews)
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/kimishpatel/196/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/kimishpatel/196/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/kimishpatel/195/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/kimishpatel/196/orig

@diff-train-skip-merge

Co-authored-by: Kimish Patel <kimishpatel@fb.com>
1 parent 9905026 commit f11e4d3

File tree: 1 file changed (+13, -1 lines)

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 13 additions & 1 deletion
@@ -43,6 +43,7 @@ def __init__(
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
+        return_float_values: bool = True,
     ):
         super().__init__()
         if cache_type not in (
@@ -57,7 +58,7 @@ def __init__(
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
-        self.return_float_values = True
+        self.return_float_values = return_float_values
         self.max_context_length = max_context_length
         cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         scale_shape = (max_batch_size, max_context_length, n_heads, 1)
@@ -400,6 +401,7 @@ def __init__(
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
+        return_float_values: bool = True,
     ):
         # Look at attention.py for explanation on why max_context_length * 2
         super().__init__(
@@ -409,6 +411,7 @@ def __init__(
             head_dim,
             cache_type,
             use_custom_update_cache_op,
+            return_float_values,
         )
         self.cache_positions_manager = CachePositionsManager(self.max_context_length)
         self.is_ring_buffer = True
@@ -459,6 +462,7 @@ def from_quantized_kv_cache(
             head_dim,
             kv_cache.cache_type,
             kv_cache.use_custom_update_cache_op,
+            kv_cache.return_float_values,
         )


@@ -583,4 +587,12 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         # it is not doing causal attention
         if "SDPACustom" in attention.SDPA.__class__.__name__:
             attention.SDPA.use_attention_mask = True
+        # QuantizedSDPA has to store kv_cache in order to obtain
+        # scales and zero points for k and v cache.
+        # So if we replaced the attention module's quantized kv cache with
+        # QuantizedRingKVCache, then we also have to replace attention's
+        # SDPA module kv_cache so that it refers to the same kv_cache.
+        if "QuantizedSDPA" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
+            attention.SDPA.kv_cache = attention.kv_cache
     return module
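
For illustration only, here is a minimal, self-contained Python sketch (not ExecuTorch code) of the aliasing problem the last hunk fixes: QuantizedSDPA keeps its own reference to the attention module's kv_cache so it can read the scales and zero points, so once the cache is swapped for a ring variant that reference must be re-bound or SDPA keeps reading the stale cache. The Fake* class names below are hypothetical stand-ins, not the real QuantizedKVCache / QuantizedRingKVCache / QuantizedSDPA classes.

# Minimal sketch of the stale-reference issue, assuming simplified stand-in classes.

class FakeQuantizedKVCache:
    def __init__(self, tag):
        # "tag" stands in for the quantized k/v buffers, scales, and zero points.
        self.tag = tag

class FakeQuantizedSDPA:
    # Like QuantizedSDPA, this module stores its own reference to the kv_cache
    # so it can look up scales and zero points for the quantized k/v tensors.
    def __init__(self, kv_cache):
        self.kv_cache = kv_cache

class FakeAttention:
    def __init__(self):
        self.kv_cache = FakeQuantizedKVCache("original")
        self.SDPA = FakeQuantizedSDPA(self.kv_cache)

attention = FakeAttention()

# Swap the attention module's cache for a ring variant, analogous to what
# replace_kv_cache_with_ring_kv_cache does.
attention.kv_cache = FakeQuantizedKVCache("ring")

# Without the fix, SDPA still points at the stale cache (and stale scales):
assert attention.SDPA.kv_cache.tag == "original"

# The fix re-binds SDPA's reference so both modules see the same ring cache:
attention.SDPA.kv_cache = attention.kv_cache
assert attention.SDPA.kv_cache is attention.kv_cache

The real transformation performs the same re-binding via attention.SDPA.kv_cache = attention.kv_cache, as shown in the diff above.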
