@@ -637,8 +637,8 @@ def _test_ragged_paged_attention(
       sm_scale=1.0,
       sliding_window=None,
       soft_cap=None,
-      num_kv_pages_per_block=16,
-      num_queries_per_block=128,
+      num_kv_pages_per_block=None,
+      num_queries_per_block=None,
       pad_tokens_and_seqs=False,
       use_dynamo=True,
   ):
@@ -751,16 +751,6 @@ def ragged_paged_attention_wrapper(
     num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32)

     from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
-    from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
-    if num_kv_pages_per_block is None:
-      assert num_queries_per_block is None
-      token_num, q_head_num, _ = q.shape
-      _, page_size, num_combined_kv_heads, _ = kv_pages.shape
-      _, pages_per_seq = page_indices.shape
-      num_kv_heads = num_combined_kv_heads // 2
-      max_model_len = pages_per_seq * page_size
-      num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-          q_head_num, num_kv_heads, token_num, max_model_len)
     jax_kernel_output = torch.from_numpy(
         np.array(
             jax_ragged_paged_attention(
@@ -790,8 +780,7 @@ def ragged_paged_attention_wrapper(
       sm_scale=[1.0, 0.5],
       sliding_window=[None, 128],
       soft_cap=[None, 10.0],
-      pad_tokens_and_seqs=[False, True],
-      block_sizes=[(16, 128), (None, None)])
+      pad_tokens_and_seqs=[False, True])
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_ragged_paged_attention_wrapper_with_dynamo(
@@ -803,12 +792,10 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
       sliding_window,
       soft_cap,
       pad_tokens_and_seqs,
-      block_sizes,
   ):
     head_dim = 128
     page_size = 16
     num_pages = 1000
-    num_kv_pages_per_block, num_queries_per_block = block_sizes

     self._test_ragged_paged_attention(
         seq_lens,
@@ -822,8 +809,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
         soft_cap=soft_cap,
         pad_tokens_and_seqs=pad_tokens_and_seqs,
         use_dynamo=True,
-        num_kv_pages_per_block=num_kv_pages_per_block,
-        num_queries_per_block=num_queries_per_block,
     )

   @parameterized.product(
@@ -834,7 +819,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
       sliding_window=[None, 128],
       soft_cap=[None, 10.0],
       pad_tokens_and_seqs=[False, True],
-      block_sizes=[(16, 128), (None, None)],
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
@@ -847,12 +831,10 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
       sliding_window,
       soft_cap,
       pad_tokens_and_seqs,
-      block_sizes,
   ):
     head_dim = 128
     page_size = 16
     num_pages = 1000
-    num_kv_pages_per_block, num_queries_per_block = block_sizes

     self._test_ragged_paged_attention(
         seq_lens,
@@ -866,8 +848,6 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
         soft_cap=soft_cap,
         pad_tokens_and_seqs=pad_tokens_and_seqs,
         use_dynamo=False,
-        num_kv_pages_per_block=num_kv_pages_per_block,
-        num_queries_per_block=num_queries_per_block,
     )

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
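Note: with both block-size parameters now defaulting to `None`, the tuned-block-size fallback that this change deletes from the test wrapper is presumably resolved inside the kernel path instead of by the test. For reference, here is a minimal sketch of that removed fallback as a standalone helper; the `resolve_block_sizes` name is hypothetical, while the import and the `get_ragged_attention_tuned_block_size` call are copied from the deleted lines.

```python
# Sketch only: mirrors the block-size fallback removed from the test wrapper above.
# Whether the kernel now performs this selection internally is an assumption.
from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size


def resolve_block_sizes(q, kv_pages, page_indices,
                        num_kv_pages_per_block=None,
                        num_queries_per_block=None):
  """Returns (num_kv_pages_per_block, num_queries_per_block), tuning them if unset."""
  if num_kv_pages_per_block is None:
    assert num_queries_per_block is None
    token_num, q_head_num, _ = q.shape
    _, page_size, num_combined_kv_heads, _ = kv_pages.shape
    _, pages_per_seq = page_indices.shape
    num_kv_heads = num_combined_kv_heads // 2  # K and V heads are interleaved
    max_model_len = pages_per_seq * page_size
    num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
        q_head_num, num_kv_heads, token_num, max_model_len)
  return num_kv_pages_per_block, num_queries_per_block
```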