Commit b24e6e9

Remove dynamic grid (#8896)

1 parent 9287337 commit b24e6e9

3 files changed: +5 -18 lines changed
test/test_pallas.py

Lines changed: 0 additions & 9 deletions
```diff
@@ -669,10 +669,6 @@ def _test_ragged_paged_attention(
     page_indices_xla = page_indices.to("xla")
     cu_q_lens_xla = cu_q_lens.to("xla")
     num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to("xla")
-    sliding_window = sliding_window
-    soft_cap = soft_cap
-    # Test mask_value
-    mask_value = None
 
     if use_dynamo:
 
@@ -686,7 +682,6 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
-        mask_value=mask_value,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -701,7 +696,6 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
-        mask_value=mask_value,
         use_kernel=use_kernel,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -722,7 +716,6 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
-        mask_value=mask_value,
         use_kernel=True,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
@@ -738,7 +731,6 @@ def ragged_paged_attention_wrapper(
         sm_scale=sm_scale,
         sliding_window=sliding_window,
         soft_cap=soft_cap,
-        mask_value=mask_value,
         use_kernel=False,
     )
 
@@ -778,7 +770,6 @@ def ragged_paged_attention_wrapper(
             sm_scale=sm_scale,
             sliding_window=sliding_window,
             soft_cap=soft_cap,
-            mask_value=mask_value,
         )[:cu_q_lens[num_seqs]].astype(jnp.float32))).to(dtype)
     jax_kernel_output_cpu = jax_kernel_output.cpu()
 
```
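With the explicit `mask_value = None` plumbing removed, the test now relies on the kernel's own defaulting, i.e. the `if mask_value is None` branch visible in the kernel diff further down. A minimal, self-contained sketch of that fallback; the exact constant here is an assumption based on the large-negative fill commonly used by these attention kernels, not a value confirmed by this commit:

```python
import numpy as np

# Assumed default: a large negative float32 fill for masked logits.
DEFAULT_MASK_VALUE = 0.7 * float(np.finfo(np.float32).min)

def resolve_mask_value(mask_value=None):
  # Mirrors the kernel's `if mask_value is None` defaulting branch.
  if mask_value is None:
    mask_value = DEFAULT_MASK_VALUE
  return mask_value

print(resolve_mask_value())      # default fill when the caller passes nothing
print(resolve_mask_value(-1e9))  # an explicit value still takes precedence
```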
torch_xla/experimental/custom_kernel.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -1015,14 +1015,8 @@ def ragged_paged_attention(
   )
 
   seq_buf_idx = torch.tensor([0, 0], dtype=torch.int32).to("xla")
-  num_q_blks = torch.tensor(
-      [(cu_q_lens[num_seqs[0]] + num_queries_per_block - 1) //
-       num_queries_per_block],
-      dtype=torch.int32).to("xla")
-
   output = torch_xla._XLAC._xla_tpu_custom_call(
       [
-          num_q_blks,  # dynamic grid
           kv_lens,
           page_indices,
           cu_q_lens,
```
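The deleted operand was the runtime-computed number of query blocks that drove the kernel's dynamic grid. After this commit the grid is sized from the padded query count, which is known statically from `q.shape`, so no extra scalar operand is needed. A minimal sketch of the two computations, with illustrative numbers; the block size and token counts below are assumptions, not values from the repo:

```python
def cdiv(a: int, b: int) -> int:
  # Ceiling division, as the kernel uses to size its query-block grid.
  return (a + b - 1) // b

num_queries_per_block = 128  # illustrative block size
num_q = 512                  # padded query count, i.e. q.shape[0]
total_tokens = 300           # illustrative cu_q_lens[num_seqs]

# Before: the grid depended on runtime data and had to be shipped to
# the kernel as an extra operand (the removed `num_q_blks` tensor).
dynamic_num_q_blks = cdiv(total_tokens, num_queries_per_block)  # -> 3

# After: the grid depends only on the static padded shape.
static_num_q_blks = cdiv(num_q, num_queries_per_block)  # -> 4
```

The trade-off is that the static grid can launch blocks covering only padding; the predicate change in the Pallas kernel below is what makes those blocks exit early.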

torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -295,7 +295,9 @@ def prefetch_first_kv_blk():
 
   def is_cur_q_blk_needed(q_states):
     done, cur_seq_idx, _ = q_states
-    return jnp.logical_and(done == 0, cur_seq_idx < num_seqs)
+    should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs],
+                                 cur_seq_idx < num_seqs)
+    return jnp.logical_and(done == 0, should_run)
 
   def compute_with_cur_q_blk(q_states):
     done, cur_seq_idx, cur_buf_idx = q_states
@@ -640,14 +642,14 @@ def ragged_paged_attention(
   check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
-  _, num_q_heads, head_dim = q.shape
+  num_q, num_q_heads, head_dim = q.shape
   _, page_size, num_combined_kv_heads, _ = kv_pages.shape
   assert num_combined_kv_heads % 2 == 0
   num_kv_heads = num_combined_kv_heads // 2
   num_q_per_blk = num_queries_per_block
   num_kv_pages_per_blk = num_kv_pages_per_block
   num_q_heads_per_kv_head = num_q_heads // num_kv_heads
-  num_q_blks = cdiv(cu_q_lens[num_seqs[0]], num_q_per_blk)
+  num_q_blks = cdiv(num_q, num_q_per_blk)
   num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk(
       num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype)
   assert num_combined_kv_heads_per_blk % 2 == 0
```