@@ -43,6 +43,7 @@ def __init__(
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
+        return_float_values: bool = True,
     ):
         super().__init__()
         if cache_type not in (
@@ -57,7 +58,7 @@ def __init__(
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
-        self.return_float_values = True
+        self.return_float_values = return_float_values
         self.max_context_length = max_context_length
         cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         scale_shape = (max_batch_size, max_context_length, n_heads, 1)
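The new return_float_values flag simply makes the previously hard-coded behavior configurable: the cache stores int8 values together with per-token scales and zero points, and the flag governs whether a read hands back dequantized float tensors. The sketch below only illustrates that int8-plus-qparams round trip; the helper names and the exact quantization scheme are assumptions for illustration, not the module's real API.

import torch

# Illustrative sketch only (hypothetical helpers, not the cache's real API):
# the cache keeps int8 values plus per-token scale/zero-point tensors, and a
# float-returning read dequantizes them back to fp32 on the way out.
def quantize_per_token(x: torch.Tensor):
    # Per-token qparams computed over the last dim (head_dim).
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scale = (x_max - x_min).clamp(min=1e-6) / 255.0
    zero_point = torch.round(-x_min / scale) - 128
    q = torch.clamp(torch.round(x / scale + zero_point), -128, 127).to(torch.int8)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(1, 4, 8, 16)      # (batch, seq_len, n_heads, head_dim)
q, s, zp = quantize_per_token(x)
x_hat = dequantize(q, s, zp)      # what a return_float_values=True read yields
print(torch.allclose(x, x_hat, atol=0.05))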
@@ -400,6 +401,7 @@ def __init__(
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
+        return_float_values: bool = True,
     ):
         # Look at attention.py for explanation on why max_context_length * 2
         super().__init__(
@@ -409,6 +411,7 @@ def __init__(
             head_dim,
             cache_type,
             use_custom_update_cache_op,
+            return_float_values,
         )
         self.cache_positions_manager = CachePositionsManager(self.max_context_length)
         self.is_ring_buffer = True
@@ -459,6 +462,7 @@ def from_quantized_kv_cache(
             head_dim,
             kv_cache.cache_type,
             kv_cache.use_custom_update_cache_op,
+            kv_cache.return_float_values,
         )

@@ -583,4 +587,12 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         # it is not doing causal attention
         if "SDPACustom" in attention.SDPA.__class__.__name__:
             attention.SDPA.use_attention_mask = True
+        # QuantizedSDPA has to store kv_cache in order to obtain
+        # scales and zero points for the k and v caches.
+        # So if we replaced the attention module's quantized kv cache with
+        # QuantizedRingKVCache, then we also have to replace the attention's
+        # SDPA module's kv_cache so that it refers to the same kv_cache.
+        if "QuantizedSDPA" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
+            attention.SDPA.kv_cache = attention.kv_cache
     return module
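The trailing comment captures the key invariant: QuantizedSDPA reads scales and zero points through its own kv_cache attribute, so that reference has to be re-pointed whenever attention.kv_cache is swapped for the ring-buffer variant. A toy sketch of the stale-reference problem, using plain Python stand-ins rather than the real ExecuTorch classes:

# Toy classes only; they stand in for the attention module, its SDPA module,
# and the quantized cache object the diff is concerned with.
class Cache:
    def __init__(self, kind):
        self.kind = kind

class SDPA:
    def __init__(self, kv_cache):
        self.kv_cache = kv_cache   # SDPA keeps its own reference to the cache

class Attention:
    def __init__(self):
        self.kv_cache = Cache("QuantizedKVCache")
        self.SDPA = SDPA(self.kv_cache)

attn = Attention()
attn.kv_cache = Cache("QuantizedRingKVCache")  # swap the attention's cache
print(attn.SDPA.kv_cache.kind)   # "QuantizedKVCache" -> stale, wrong qparams
attn.SDPA.kv_cache = attn.kv_cache             # re-point, as the diff does
print(attn.SDPA.kv_cache.kind)   # "QuantizedRingKVCache"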