@@ -678,7 +678,6 @@ def __init__(
         self.attn_type = attn_type
 
         self.lambda_full = None
-        # self.subln = nn.RMSNorm(2 * self.head_size, eps=1e-5, elementwise_affine=True)
         self.subln = self.differential_flash_attention_config["subln"]
 
     def split_heads(self, x):
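The comment deleted above hard-coded the sub-layer norm that `subln` is now taken from the config instead. As a minimal sketch (not the vLLM implementation) of how such an RMSNorm is typically used in differential attention, with `head_size`, `lambda_full`, `lambda_init`, and the attention outputs as hypothetical stand-ins:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins; the real values come from the model config.
head_size, lambda_full, lambda_init = 64, 0.8, 0.8

# The deleted comment built the norm over both halves of a head pair,
# i.e. 2 * head_size features:
subln = nn.RMSNorm(2 * head_size, eps=1e-5, elementwise_affine=True)

# attn1 / attn2: outputs of the two softmax attention maps, shaped
# (batch, num_heads, seq_len, 2 * head_size).
attn1 = torch.randn(1, 8, 16, 2 * head_size)
attn2 = torch.randn(1, 8, 16, 2 * head_size)

# Differential Transformer formulation: subtract the second map scaled by
# lambda_full, normalize per head, then rescale by (1 - lambda_init).
out = subln(attn1 - lambda_full * attn2) * (1 - lambda_init)
```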
@@ -705,9 +704,6 @@ def populate_kv_cache(self,
         if (kv_cache.numel() > 0):
             if (key is not None) and (value is not None):
                 updated_slot_mapping = attn_metadata.slot_mapping
-                # previous_key_cache_sum = key_cache.sum()
-                # previous_value_cache_sum = value_cache.sum()
-
                 torch.ops._C_cache_ops.reshape_and_cache_flash(
                     key,
                     value,
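For orientation, a pure-PyTorch sketch of what the fused `reshape_and_cache_flash` kernel does with the slot mapping (an illustrative reference only, ignoring the fp8 `_k_scale`/`_v_scale` arguments):

```python
import torch

def reshape_and_cache_flash_ref(key, value, key_cache, value_cache,
                                slot_mapping):
    # key / value: [num_tokens, num_heads, head_size]
    # key_cache / value_cache: [num_blocks, block_size, num_heads, head_size]
    # slot_mapping: [num_tokens], flat slot index per token; negative
    # entries mark padded tokens and are skipped.
    block_size = key_cache.shape[1]
    valid = slot_mapping >= 0
    slots = slot_mapping[valid]
    key_cache[slots // block_size, slots % block_size] = key[valid]
    value_cache[slots // block_size, slots % block_size] = value[valid]
```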
@@ -718,12 +714,6 @@ def populate_kv_cache(self,
                     layer._k_scale,
                     layer._v_scale,
                 )
-                # assert key_cache.sum() - previous_key_cache_sum == key.sum(), "key_cache sum mismatch"
-                # assert value_cache.sum() - previous_value_cache_sum == value.sum(), "value_cache sum mismatch"
-                # if key_cache.sum() - previous_key_cache_sum != key.sum():
-                #     print("key_cache sum mismatch")
-                # if value_cache.sum() - previous_value_cache_sum != value.sum():
-                #     print("value_cache sum mismatch")
 
     def forward_generate_kv_cache(
         self,
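The deleted debug lines checked that each cache write changed the cache's elementwise sum by exactly the sum of the tensor written. As a standalone sketch of that check (names are illustrative, not vLLM API):

```python
import torch

def check_cache_delta(cache_sum_before: torch.Tensor,
                      cache: torch.Tensor,
                      written: torch.Tensor) -> None:
    # The cache sum should grow by the sum of the newly written entries.
    # allclose instead of ==: float reductions are not bitwise stable
    # across kernels, which makes exact-equality asserts brittle.
    delta = cache.sum() - cache_sum_before
    if not torch.allclose(delta, written.sum(), rtol=1e-3, atol=1e-3):
        raise AssertionError("cache sum mismatch after cache write")
```

That brittleness is a plausible reason the original asserts were first commented out and are now removed outright.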