@@ -353,8 +353,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
             attn_metadata.decode_wrapper = self._get_decode_wrapper()
             if not FlashInferBackend.use_trtllm_decode_attention(
                     num_decodes, attn_metadata.max_seq_len,
-                    attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
+                    self.cache_config.cache_dtype,
+                    attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
+                    attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
                     attn_metadata.paged_kv_indptr[:num_decodes + 1],
                     attn_metadata.paged_kv_indices,
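The substantive change in this hunk is the third argument: the gate now receives the engine's string cache dtype (self.cache_config.cache_dtype, e.g. "auto" or "fp8") instead of the resolved torch dtype in attn_metadata.kv_data_type. A minimal sketch of why the value's type matters, using a hypothetical string-keyed predicate (not the actual body of use_trtllm_decode_attention):

import torch

def supports_fp8_kv(cache_dtype) -> bool:
    # Hypothetical check: a string membership test is always False for a
    # torch.dtype such as torch.float8_e4m3fn, so passing
    # attn_metadata.kv_data_type here would silently disable the path.
    return cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2")

assert supports_fp8_kv("fp8")
assert not supports_fp8_kv(torch.float8_e4m3fn)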
@@ -539,10 +540,10 @@ def forward(
             query: shape = [num_tokens, num_heads, head_size]
             key: shape = [num_tokens, num_kv_heads, head_size]
             value: shape = [num_tokens, num_kv_heads, head_size]
-            kv_cache: shape -
+            kv_cache: shape -
             # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
             # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
-
+

             attn_metadata: Metadata for attention.
         Returns:
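A quick orientation on the two kv_cache layouts named in the docstring: NHD and HND differ only in the order of the block and head axes. A self-contained shape check, with illustrative sizes:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64
kv_nhd = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size)
# HND swaps the block_size and num_kv_heads axes of NHD:
kv_hnd = kv_nhd.permute(0, 1, 3, 2, 4)
assert kv_hnd.shape == (num_blocks, 2, num_kv_heads, block_size, head_size)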
@@ -614,6 +615,7 @@ def forward(
         num_prefill_tokens = attn_metadata.num_prefill_tokens

         stride_order = FlashInferBackend.get_kv_cache_stride_order()
+        kv_cache_permute = kv_cache.permute(*stride_order)
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
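Hoisting the permute out of the three call sites below is a consistency win rather than a copy saving: torch's permute() returns a view over the same storage. A small check of that behavior:

import torch

kv_cache = torch.empty(4, 2, 16, 8, 64)
stride_order = (0, 1, 3, 2, 4)  # illustrative; the backend supplies the real order
kv_cache_permute = kv_cache.permute(*stride_order)
# permute() creates a strided view, not a copy, so computing it once up
# front is cheap and keeps prefill and decode reading the same layout.
assert kv_cache_permute.data_ptr() == kv_cache.data_ptr()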
@@ -628,7 +630,7 @@ def forward(
             assert prefill_wrapper._sm_scale == self.scale
             prefill_wrapper.run(
                 prefill_query,
-                kv_cache.permute(*stride_order),
+                kv_cache_permute,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -647,27 +649,37 @@ def forward(
                 assert decode_wrapper._sm_scale == self.scale
                 decode_wrapper.run(
                     decode_query,
-                    kv_cache.permute(*stride_order),
+                    kv_cache_permute,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output[:num_decode_tokens],
                 )
             else:
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 if num_decode_tokens > 0:
+                    # decode_query may be non-contiguous
+                    decode_query = decode_query.contiguous()
+                    block_tables_decode = attn_metadata.block_table_tensor[:
+                        num_decode_tokens]
+                    seq_lens_decode = attn_metadata.seq_lens[:
+                        num_decode_tokens]
+
                     assert get_kv_cache_layout() == "HND"
+                    assert decode_query.is_contiguous()
+                    assert kv_cache_permute.is_contiguous()
+                    assert block_tables_decode.is_contiguous()
+                    assert seq_lens_decode.is_contiguous()
+
                     output[:num_decode_tokens] = (
                         trtllm_batch_decode_with_kv_cache(
                             query=decode_query,
-                            kv_cache=kv_cache.permute(*stride_order),
+                            kv_cache=kv_cache_permute,
                             workspace_buffer=attn_metadata.workspace_buffer,
                             num_heads=self.num_heads,
                             num_kv_heads=self.num_kv_heads,
                             scale=self.scale,
-                            block_tables=attn_metadata.
-                            block_table_tensor[:num_decode_tokens],
-                            seq_lens=attn_metadata.
-                            seq_lens[:num_decode_tokens],
+                            block_tables=block_tables_decode,
+                            seq_lens=seq_lens_decode,
                             block_size=attn_metadata.page_size,
                             max_seq_len=attn_metadata.max_seq_len,
                             kv_cache_dtype=self.kv_cache_dtype,
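On the new contiguity handling: the hunk's own comment notes decode_query may be non-contiguous, and the asserts make the TRT-LLM kernel's layout assumption fail loudly instead of producing garbage. The defensive contiguous() call is free when the tensor already qualifies; a quick check of the relevant torch behavior:

import torch

q = torch.empty(32, 8, 64)
head = q[:10]  # slicing the leading dim preserves contiguity
assert head.is_contiguous()
# contiguous() returns the tensor itself when no copy is needed, so the
# call above only pays for a copy when decode_query really is strided.
assert head.contiguous() is head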