File tree Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -145,7 +145,19 @@ def kernel_unified_attention_2d(
145
145
mask = query_mask_1 ,
146
146
other = 0.0 )
147
147
148
- num_blocks = cdiv_fn (seq_len , BLOCK_SIZE )
148
+ # compute the length of the longest sequence prefix spanned by any
149
+ # query token in the current q_block (q_block_local_idx)
150
+ max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
151
+ BLOCK_M - 1 ) // num_queries_per_kv + 1
152
+
153
+ # adjust for potential padding in the last q_block by considering the
154
+ # actual sequence length
155
+ max_seq_prefix_len = tl .minimum (max_seq_prefix_len , seq_len )
156
+
157
+ # calculate the number of tiles (blocks) that need to be processed to
158
+ # cover the longest sequence prefix (due to causal masking, blocks beyond
159
+ # this prefix can be skipped)
160
+ num_blocks = cdiv_fn (max_seq_prefix_len , BLOCK_SIZE )
149
161
150
162
# iterate through tiles
151
163
for j in range (0 , num_blocks ):
You can’t perform that action at this time.
0 commit comments