Commit c4b45a9

Use new tuned table (#9041)
1 parent 14256e6 commit c4b45a9

4 files changed: +533 −137 lines

test/test_pallas.py (+3 −23)

@@ -637,8 +637,8 @@ def _test_ragged_paged_attention(
       sm_scale=1.0,
       sliding_window=None,
       soft_cap=None,
-      num_kv_pages_per_block=16,
-      num_queries_per_block=128,
+      num_kv_pages_per_block=None,
+      num_queries_per_block=None,
       pad_tokens_and_seqs=False,
       use_dynamo=True,
   ):
@@ -751,16 +751,6 @@ def ragged_paged_attention_wrapper(
     num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32)

     from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
-    from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
-    if num_kv_pages_per_block is None:
-      assert num_queries_per_block is None
-      token_num, q_head_num, _ = q.shape
-      _, page_size, num_combined_kv_heads, _ = kv_pages.shape
-      _, pages_per_seq = page_indices.shape
-      num_kv_heads = num_combined_kv_heads // 2
-      max_model_len = pages_per_seq * page_size
-      num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-          q_head_num, num_kv_heads, token_num, max_model_len)
     jax_kernel_output = torch.from_numpy(
         np.array(
             jax_ragged_paged_attention(
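
Editorial note: the second hunk above removes the caller-side derivation of the block sizes from the test helper (the same lines are deleted from custom_kernel.py below), leaving resolution of None block sizes to the kernel. For reference, the deleted code computed the tuned-table lookup keys purely from the input shapes; a minimal standalone restatement follows (the function name is hypothetical, and it stops short of the actual table lookup, which the old code did via get_ragged_attention_tuned_block_size):

def tuned_block_size_keys(q, kv_pages, page_indices):
  # Shapes, as implied by the unpacking in the removed code:
  #   q:            [token_num, q_head_num, head_dim]
  #   kv_pages:     [num_pages, page_size, num_combined_kv_heads, head_dim]
  #   page_indices: [max_num_seqs, pages_per_seq]
  token_num, q_head_num, _ = q.shape
  _, page_size, num_combined_kv_heads, _ = kv_pages.shape
  _, pages_per_seq = page_indices.shape
  num_kv_heads = num_combined_kv_heads // 2  # K and V heads are packed together in kv_pages
  max_model_len = pages_per_seq * page_size  # longest sequence the page table can address
  # The removed code then called:
  #   get_ragged_attention_tuned_block_size(q_head_num, num_kv_heads, token_num, max_model_len)
  return q_head_num, num_kv_heads, token_num, max_model_len
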
@@ -790,8 +780,7 @@ def ragged_paged_attention_wrapper(
       sm_scale=[1.0, 0.5],
       sliding_window=[None, 128],
       soft_cap=[None, 10.0],
-      pad_tokens_and_seqs=[False, True],
-      block_sizes=[(16, 128), (None, None)])
+      pad_tokens_and_seqs=[False, True])
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_ragged_paged_attention_wrapper_with_dynamo(
@@ -803,12 +792,10 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
       sliding_window,
       soft_cap,
       pad_tokens_and_seqs,
-      block_sizes,
   ):
     head_dim = 128
     page_size = 16
     num_pages = 1000
-    num_kv_pages_per_block, num_queries_per_block = block_sizes

     self._test_ragged_paged_attention(
         seq_lens,
@@ -822,8 +809,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
         soft_cap=soft_cap,
         pad_tokens_and_seqs=pad_tokens_and_seqs,
         use_dynamo=True,
-        num_kv_pages_per_block=num_kv_pages_per_block,
-        num_queries_per_block=num_queries_per_block,
     )

   @parameterized.product(
@@ -834,7 +819,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
       sliding_window=[None, 128],
       soft_cap=[None, 10.0],
       pad_tokens_and_seqs=[False, True],
-      block_sizes=[(16, 128), (None, None)],
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
@@ -847,12 +831,10 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
       sliding_window,
       soft_cap,
       pad_tokens_and_seqs,
-      block_sizes,
   ):
     head_dim = 128
     page_size = 16
     num_pages = 1000
-    num_kv_pages_per_block, num_queries_per_block = block_sizes

     self._test_ragged_paged_attention(
         seq_lens,
@@ -866,8 +848,6 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
         soft_cap=soft_cap,
         pad_tokens_and_seqs=pad_tokens_and_seqs,
         use_dynamo=False,
-        num_kv_pages_per_block=num_kv_pages_per_block,
-        num_queries_per_block=num_queries_per_block,
     )

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
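
Net effect on the tests: block_sizes is dropped from both parameterized matrices, so the dynamo and non-dynamo wrappers always leave num_kv_pages_per_block and num_queries_per_block as None and exercise only the tuned defaults.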

torch_xla/experimental/custom_kernel.py (−11)

@@ -9,7 +9,6 @@
 from torch_xla.distributed.spmd import Mesh
 import torch_xla.distributed.spmd as xs
 from torch_xla._internal.jax_workarounds import requires_jax
-from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size

 # Re-expose this API used that is referenced by docs
 from torch_xla._internal.jax_workarounds import jax_import_guard  # noqa: F401, pylint: disable=unused-import
@@ -990,16 +989,6 @@ def ragged_paged_attention(
   # in the global scope which could cause problems for xmp.spawn.
   from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_attention

-  if num_kv_pages_per_block is None:
-    assert num_queries_per_block is None
-    token_num, q_head_num, _ = q.shape
-    _, page_size, num_combined_kv_heads, _ = kv_pages.shape
-    _, pages_per_seq = page_indices.shape
-    num_kv_heads = num_combined_kv_heads // 2
-    max_model_len = pages_per_seq * page_size
-    num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-        q_head_num, num_kv_heads, token_num, max_model_len)
-
   if vmem_limit_bytes is None:
     vmem_limit_bytes = 64 * 1024 * 1024
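
Editorial note: with the caller-side lookup gone, the wrapper now forwards num_kv_pages_per_block and num_queries_per_block (possibly None) straight to the v2 Pallas kernel, which, per the commit title, presumably resolves None against the new tuned table in the files not shown here. The sketch below is illustration only: the real table's entries, values, and bucketing are assumptions; only the lookup arguments and the old (16, 128) defaults come from the removed code.

# Toy stand-in for a tuned block-size table, keyed the same way the removed
# call to get_ragged_attention_tuned_block_size was parameterized. Entries and
# bucketing here are made up for illustration.
_TOY_TUNED_TABLE = {
    # (q_head_num, num_kv_heads, max_model_len): (num_kv_pages_per_block, num_queries_per_block)
    (32, 8, 2048): (16, 128),
    (8, 1, 4096): (32, 64),
}


def lookup_block_sizes(q_head_num, num_kv_heads, token_num, max_model_len,
                       default=(16, 128)):
  # A real implementation would likely also bucket token_num; this sketch
  # ignores it and falls back to the old hard-coded defaults.
  del token_num
  return _TOY_TUNED_TABLE.get((q_head_num, num_kv_heads, max_model_len), default)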
