@@ -641,6 +641,7 @@ def batch_prefill_paged_kv(
641
641
if T .tvm_thread_invariant (batch_idx [0 ] < batch_size ):
642
642
b_idx : T .int32 = batch_idx [0 ]
643
643
LH_start : T .int32 = tile_id [0 ] * tile_x
644
+ q_indptr_val : T .int32 = q_indptr [b_idx ]
644
645
645
646
cur_page_indptr_begin : T .int32 = page_indptr [b_idx ]
646
647
cur_page_indptr_end : T .int32 = page_indptr [b_idx + 1 ]
@@ -670,7 +671,7 @@ def batch_prefill_paged_kv(
670
671
i , j = T .axis .remap ("SS" , [li , lj ])
671
672
T .reads ()
672
673
T .writes ()
673
- cur_L = q_indptr [ b_idx ] + (LH_start + i ) // group_size
674
+ cur_L = q_indptr_val + (LH_start + i ) // group_size
674
675
cur_H_qo = by * group_size + (LH_start + i ) % group_size
675
676
if cur_L < q_indptr [b_idx + 1 ]:
676
677
Q_smem [i , j ] = T .if_then_else (
@@ -1316,6 +1317,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
1316
1317
1317
1318
if T .tvm_thread_invariant (batch_idx [0 ] < batch_size ):
1318
1319
b_idx : T .int32 = batch_idx [0 ]
1320
+ q_indptr_val : T .int32 = q_indptr [b_idx ]
1319
1321
LH_start : T .int32 = tile_id [0 ] * tile_x
1320
1322
1321
1323
kv_chunk_len [0 ] = kv_indptr [b_idx + 1 ] - kv_indptr [b_idx ]
@@ -1340,7 +1342,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
1340
1342
i , j = T .axis .remap ("SS" , [li , lj ])
1341
1343
T .reads ()
1342
1344
T .writes ()
1343
- cur_L = q_indptr [ b_idx ] + (LH_start + i ) // group_size
1345
+ cur_L = q_indptr_val + (LH_start + i ) // group_size
1344
1346
cur_H_qo = by * group_size + (LH_start + i ) % group_size
1345
1347
if cur_L < q_indptr [b_idx + 1 ]:
1346
1348
Q_smem [i , j ] = T .if_then_else (
0 commit comments