
Commit 22dd9c2

[Kernel] Optimize Prefill Attention in Unified Triton Attention Kernel (#20308)
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
1 parent a6d795d commit 22dd9c2

1 file changed: 13 additions, 1 deletion

vllm/attention/ops/triton_unified_attention.py

Lines changed: 13 additions & 1 deletion

@@ -145,7 +145,19 @@ def kernel_unified_attention_2d(
                     mask=query_mask_1,
                     other=0.0)
 
-    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
+        BLOCK_M - 1) // num_queries_per_kv + 1
+
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+
+    # calculate the number of tiles (blocks) that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, blocks beyond
+    # this prefix can be skipped)
+    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
 
     # iterate through tiles
     for j in range(0, num_blocks):
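
For intuition, here is a small illustrative sketch (not part of the commit) that replays the new arithmetic in plain Python. The parameter values (seq_len, BLOCK_SIZE, BLOCK_M, num_queries_per_kv) are hypothetical choices for the example, and it assumes BLOCK_Q = BLOCK_M // num_queries_per_kv as in the kernel. It compares how many KV tiles a causal prefill touches in total before and after the change.

# Illustrative sketch (not from the commit): compare KV tiles processed
# per q_block before and after the change, for hypothetical parameters.

def cdiv(a: int, b: int) -> int:
    # ceiling division, mirroring the kernel's cdiv_fn
    return (a + b - 1) // b

seq_len = 4096          # total sequence length (context + new queries)
context_len = 0         # pure prefill: no cached context
BLOCK_SIZE = 16         # KV-cache block size (tile width along the sequence)
BLOCK_M = 64            # query rows handled per program instance
num_queries_per_kv = 4  # GQA ratio
BLOCK_Q = BLOCK_M // num_queries_per_kv  # query tokens per q_block

old_tiles = 0
new_tiles = 0
for q_block_local_idx in range(cdiv(seq_len - context_len, BLOCK_Q)):
    # before: every q_block scanned all KV tiles of the sequence
    old_tiles += cdiv(seq_len, BLOCK_SIZE)

    # after: only tiles covering the longest prefix any query token in this
    # q_block can attend to (causal masking makes later tiles fully masked)
    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
        BLOCK_M - 1) // num_queries_per_kv + 1
    max_seq_prefix_len = min(max_seq_prefix_len, seq_len)
    new_tiles += cdiv(max_seq_prefix_len, BLOCK_SIZE)

print(f"KV tiles before: {old_tiles}, after: {new_tiles}")
# 65536 before vs. 32896 after with the parameters above

With these example parameters the tile count roughly halves, since each q_block only scans the sequence prefix it can actually attend to; the exact saving depends on sequence length, block shape, and how much of the sequence is cached context.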
