
Commit 8dfb45c

[Bugfix] Fix the tensor non-contiguous issue for Flashinfer TRT-LLM backend attention kernel (#21133)
1 parent 8a8fc94 commit 8dfb45c
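
Background on the bug: permuting a tensor, as kv_cache.permute(*stride_order) does, or slicing it along a non-leading dimension returns a strided view whose memory layout no longer matches its logical shape, and fused kernels such as the TRT-LLM decode attention path generally require dense (contiguous) inputs. A minimal illustrative sketch of that behavior in plain PyTorch (toy shapes, not vLLM code):

import torch

# Toy KV cache in NHD layout: [num_blocks, 2, block_size, num_kv_heads, head_size]
kv_cache = torch.randn(4, 2, 16, 8, 64)

# Permuting to HND order returns a view with swapped strides, not a copy,
# so the data is no longer laid out contiguously in memory.
kv_hnd = kv_cache.permute(0, 1, 3, 2, 4)
print(kv_hnd.is_contiguous())               # False

# .contiguous() materializes the view into a dense tensor with the new layout.
print(kv_hnd.contiguous().is_contiguous())  # True

# Slicing the leading dimension preserves contiguity...
q = torch.randn(32, 8, 64)
print(q[:10].is_contiguous())               # True
# ...but slicing a trailing dimension does not.
print(q[:, :4].is_contiguous())             # False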

File tree: 1 file changed (+23, -11 lines)


vllm/v1/attention/backends/flashinfer.py

Lines changed: 23 additions & 11 deletions
@@ -353,8 +353,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
             attn_metadata.decode_wrapper = self._get_decode_wrapper()
             if not FlashInferBackend.use_trtllm_decode_attention(
                     num_decodes, attn_metadata.max_seq_len,
-                    attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
+                    self.cache_config.cache_dtype,
+                    attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
+                    attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
                     attn_metadata.paged_kv_indptr[:num_decodes + 1],
                     attn_metadata.paged_kv_indices,

@@ -539,10 +540,10 @@ def forward(
             query: shape = [num_tokens, num_heads, head_size]
             key: shape = [num_tokens, num_kv_heads, head_size]
             value: shape = [num_tokens, num_kv_heads, head_size]
-            kv_cache: shape -
+            kv_cache: shape -
             # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
             # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
-
+
 
             attn_metadata: Metadata for attention.
         Returns:

@@ -614,6 +615,7 @@ def forward(
         num_prefill_tokens = attn_metadata.num_prefill_tokens
 
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
+        kv_cache_permute = kv_cache.permute(*stride_order)
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()

@@ -628,7 +630,7 @@ def forward(
             assert prefill_wrapper._sm_scale == self.scale
             prefill_wrapper.run(
                 prefill_query,
-                kv_cache.permute(*stride_order),
+                kv_cache_permute,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],

@@ -647,27 +649,37 @@ def forward(
                 assert decode_wrapper._sm_scale == self.scale
                 decode_wrapper.run(
                     decode_query,
-                    kv_cache.permute(*stride_order),
+                    kv_cache_permute,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output[:num_decode_tokens],
                 )
             else:
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 if num_decode_tokens > 0:
+                    # decode_query may be non-contiguous
+                    decode_query = decode_query.contiguous()
+                    block_tables_decode = attn_metadata.block_table_tensor[:
+                                                                           num_decode_tokens]
+                    seq_lens_decode = attn_metadata.seq_lens[:
+                                                             num_decode_tokens]
+
                     assert get_kv_cache_layout() == "HND"
+                    assert decode_query.is_contiguous()
+                    assert kv_cache_permute.is_contiguous()
+                    assert block_tables_decode.is_contiguous()
+                    assert seq_lens_decode.is_contiguous()
+
                     output[:num_decode_tokens] = (
                         trtllm_batch_decode_with_kv_cache(
                             query=decode_query,
-                            kv_cache=kv_cache.permute(*stride_order),
+                            kv_cache=kv_cache_permute,
                             workspace_buffer=attn_metadata.workspace_buffer,
                             num_heads=self.num_heads,
                             num_kv_heads=self.num_kv_heads,
                             scale=self.scale,
-                            block_tables=attn_metadata.
-                            block_table_tensor[:num_decode_tokens],
-                            seq_lens=attn_metadata.
-                            seq_lens[:num_decode_tokens],
+                            block_tables=block_tables_decode,
+                            seq_lens=seq_lens_decode,
                             block_size=attn_metadata.page_size,
                             max_seq_len=attn_metadata.max_seq_len,
                             kv_cache_dtype=self.kv_cache_dtype,
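
The fix follows a pattern that generalizes beyond this file: hoist the permute out of the call sites, materialize the decode-only slices once, and assert contiguity before handing anything to the kernel. A hedged sketch of that pattern, using a hypothetical launch_trtllm_decode as a stand-in for the real trtllm_batch_decode_with_kv_cache call and simplified metadata plumbing:

import torch


def launch_trtllm_decode(query, kv_cache, block_tables, seq_lens):
    # Hypothetical stand-in for a fused decode kernel that rejects
    # non-contiguous inputs.
    for name, t in (("query", query), ("kv_cache", kv_cache),
                    ("block_tables", block_tables), ("seq_lens", seq_lens)):
        assert t.is_contiguous(), f"{name} must be contiguous"
    return torch.empty_like(query)  # placeholder output


def trtllm_decode_forward(query, kv_cache, block_table_tensor, seq_lens,
                          num_decode_tokens, stride_order):
    # Permute once and reuse the view at every call site
    # (kv_cache_permute in the actual change); permute() returns a view.
    kv_cache_permute = kv_cache.permute(*stride_order)

    # Materialize the decode-only slices up front; .contiguous() is a
    # no-op when the tensor is already dense and a copy otherwise.
    decode_query = query[:num_decode_tokens].contiguous()
    block_tables_decode = block_table_tensor[:num_decode_tokens].contiguous()
    seq_lens_decode = seq_lens[:num_decode_tokens].contiguous()

    # The real fix only asserts that the permuted cache is contiguous,
    # because VLLM_KV_CACHE_LAYOUT=HND already stores it in kernel order;
    # this sketch assumes the same precondition holds.
    assert kv_cache_permute.is_contiguous()

    return launch_trtllm_decode(decode_query, kv_cache_permute,
                                block_tables_decode, seq_lens_decode)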
