Skip to content

Commit 6bd049e

Browse files
authored
[Fix] Fix attn kernel build issue (#2545)
This PR fixes TIR issues in the attn kernels.
1 parent fcb50a2 commit 6bd049e

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

python/mlc_llm/nn/kv_cache.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,7 @@ def batch_prefill_paged_kv(
641641
if T.tvm_thread_invariant(batch_idx[0] < batch_size):
642642
b_idx: T.int32 = batch_idx[0]
643643
LH_start: T.int32 = tile_id[0] * tile_x
644+
q_indptr_val: T.int32 = q_indptr[b_idx]
644645

645646
cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
646647
cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
@@ -670,7 +671,7 @@ def batch_prefill_paged_kv(
670671
i, j = T.axis.remap("SS", [li, lj])
671672
T.reads()
672673
T.writes()
673-
cur_L = q_indptr[b_idx] + (LH_start + i) // group_size
674+
cur_L = q_indptr_val + (LH_start + i) // group_size
674675
cur_H_qo = by * group_size + (LH_start + i) % group_size
675676
if cur_L < q_indptr[b_idx + 1]:
676677
Q_smem[i, j] = T.if_then_else(
@@ -1316,6 +1317,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
13161317

13171318
if T.tvm_thread_invariant(batch_idx[0] < batch_size):
13181319
b_idx: T.int32 = batch_idx[0]
1320+
q_indptr_val: T.int32 = q_indptr[b_idx]
13191321
LH_start: T.int32 = tile_id[0] * tile_x
13201322

13211323
kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
@@ -1340,7 +1342,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
13401342
i, j = T.axis.remap("SS", [li, lj])
13411343
T.reads()
13421344
T.writes()
1343-
cur_L = q_indptr[b_idx] + (LH_start + i) // group_size
1345+
cur_L = q_indptr_val + (LH_start + i) // group_size
13441346
cur_H_qo = by * group_size + (LH_start + i) % group_size
13451347
if cur_L < q_indptr[b_idx + 1]:
13461348
Q_smem[i, j] = T.if_then_else(

python/mlc_llm/op/tree_attn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches
170170
if T.tvm_thread_invariant(batch_idx[0] < batch_size):
171171
b_idx: T.int32 = batch_idx[0]
172172
LH_start: T.int32 = tile_id[0] * tile_x
173+
q_indptr_val: T.int32 = q_indptr[b_idx]
173174

174175
kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
175176
T.tvm_storage_sync("shared")
@@ -193,7 +194,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches
193194
i, j = T.axis.remap("SS", [li, lj])
194195
T.reads()
195196
T.writes()
196-
cur_L = q_indptr[b_idx] + (LH_start + i) // group_size
197+
cur_L = q_indptr_val + (LH_start + i) // group_size
197198
cur_H_qo = by * group_size + (LH_start + i) % group_size
198199
if cur_L < q_indptr[b_idx + 1]:
199200
Q_smem[i, j] = T.if_then_else(

0 commit comments

Comments
 (0)