
Commit 2675e68

Add sm_scale in ragged attention kernel (#8771)
1 parent 2e4f073 commit 2675e68

File tree

3 files changed: +36 -14 lines

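The commit threads a softmax scaling factor, sm_scale, through the whole ragged paged attention stack: the eager (non-kernel) reference path, the torch.ops.xla custom-op registration, and the Pallas TPU kernel. The default of 1.0 keeps the previous behaviour, and the updated tests pass the conventional head_dim**-0.5. The sketch below only illustrates what the parameter means (the helper name toy_scaled_attention and the shapes are made up for illustration, not part of this commit): the scale multiplies the q·k logits before the softmax.

import torch

def toy_scaled_attention(q, k, v, sm_scale):
  # q: [q_len, head_dim]; k, v: [kv_len, head_dim]
  logits = torch.einsum('qd,kd->qk', q, k).float() * sm_scale
  weights = torch.softmax(logits, dim=-1)
  return torch.einsum('qk,kd->qd', weights, v.float()).to(q.dtype)

head_dim = 128
out = toy_scaled_attention(
    torch.randn(4, head_dim), torch.randn(16, head_dim),
    torch.randn(16, head_dim), sm_scale=head_dim**-0.5)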

test/test_pallas.py

Lines changed: 15 additions & 5 deletions
@@ -746,6 +746,7 @@ def _verify_ragged_paged_attention_with_dynamo(
       num_kv_pages_per_block,
       num_queries_per_block,
       pad_num_q_tokens=False,
+      sm_scale=1.0,
   ):
     num_seqs = len(seq_lens)
     q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
@@ -768,7 +769,8 @@ def _verify_ragged_paged_attention_with_dynamo(
     def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
                                        page_indices, cu_q_lens, num_seqs,
                                        num_kv_pages_per_block,
-                                       num_queries_per_block, use_kernel):
+                                       num_queries_per_block, use_kernel,
+                                       sm_scale):
       return torch.ops.xla.ragged_paged_attention(
           q,
           k_pages,
@@ -780,6 +782,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
           num_kv_pages_per_block,
           num_queries_per_block,
           use_kernel=use_kernel,
+          sm_scale=sm_scale,
       )

     compiled_paged_attention = torch.compile(
@@ -796,6 +799,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True,
+        sm_scale=sm_scale,
     )

     nonkernel_output = compiled_paged_attention(
@@ -809,6 +813,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=False,
+        sm_scale=sm_scale,
     )

     kernel_output_cpu = kernel_output.cpu()
@@ -836,6 +841,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
             num_seqs=num_seqs,
             num_kv_pages_per_block=num_kv_pages_per_block,
             num_queries_per_block=num_queries_per_block,
+            sm_scale=sm_scale,
         )[1]))
     jax_kernel_output_cpu = jax_kernel_output.cpu()

@@ -845,21 +851,21 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
           torch.allclose(
               kernel_output_cpu[:actual_num_q_tokens],
               nonkernel_output_cpu[:actual_num_q_tokens],
-              atol=2e-1,
+              atol=2e-2,
               rtol=1e-2))
       self.assertTrue(
           torch.allclose(
               kernel_output_cpu[:actual_num_q_tokens],
               jax_kernel_output_cpu[:actual_num_q_tokens],
-              atol=2e-1,
+              atol=2e-2,
               rtol=1e-2))
     else:
       self.assertTrue(
           torch.allclose(
-              kernel_output_cpu, nonkernel_output_cpu, atol=2e-1, rtol=1e-2))
+              kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2))
       self.assertTrue(
           torch.allclose(
-              kernel_output_cpu, jax_kernel_output_cpu, atol=2e-1, rtol=1e-2))
+              kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2))

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
@@ -882,6 +888,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
     dtype = torch.float32
     page_size = 16
     num_pages = 32768
+    sm_scale = head_dim**-0.5

     self._verify_ragged_paged_attention_with_dynamo(
         seq_lens,
@@ -892,6 +899,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
         dtype,
         num_kv_pages_per_block=128,
         num_queries_per_block=8,
+        sm_scale=sm_scale,
     )

   @parameterized.product(
@@ -910,6 +918,7 @@ def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
     dtype = torch.float32
     page_size = 16
     num_pages = 32768
+    sm_scale = head_dim**-0.5

     self._verify_ragged_paged_attention_with_dynamo(
         seq_lens,
@@ -921,6 +930,7 @@ def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
         num_kv_pages_per_block=128,
         num_queries_per_block=num_queries_per_block,
         pad_num_q_tokens=True,
+        sm_scale=sm_scale,
     )

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,

torch_xla/experimental/custom_kernel.py

Lines changed: 14 additions & 9 deletions
@@ -708,6 +708,7 @@ def _ragged_paged_attention_nonkernel(
     page_indices,  # i32[num_tokens, pages_per_sequence]
     cu_q_lens,  # i32[num_tokens + 1]
     num_seqs,  # int
+    sm_scale,  # float
 ):
   _, num_q_heads, head_dim = queries.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -751,6 +752,7 @@ def _ragged_paged_attention_nonkernel(
     attn = torch.einsum("qhd,khd->hqk", q,
                         k)  # [num_query_heads, cur_q_len, kv_len]
     attn = attn.float()
+    attn = attn * sm_scale
     empty_mask = torch.ones(cur_q_len, cur_kv_len, device=attn.device)
     mask = torch.triu(empty_mask, diagonal=cur_kv_len - cur_q_len + 1).bool()
     attn.masked_fill_(mask, float("-inf"))
@@ -784,6 +786,7 @@ def ragged_paged_attention(
     num_kv_pages_per_block,
     num_queries_per_block,
     use_kernel=True,
+    sm_scale=1.0,
     # TODO(jevinjiang, xiowei): add attn_logits_soft_cap.
     # attn_logits_soft_cap: float | None = None,
 ):  # [batch_size, query_len, num_heads, head_dim]:
@@ -797,6 +800,7 @@ def ragged_paged_attention(
         page_indices,
         cu_q_lens,
         num_seqs,
+        sm_scale,
     )

   # Import JAX within the function such that we don't need to call the jax_import_guard()
@@ -813,11 +817,13 @@ def ragged_paged_attention(
       num_seqs=num_seqs,
       num_kv_pages_per_block=num_kv_pages_per_block,
       num_queries_per_block=num_queries_per_block,
+      sm_scale=sm_scale,
       static_argnames=[
           "num_kv_pages_per_block",
           "num_queries_per_block",
           "mask_value",
           "num_seqs",
+          "sm_scale",
       ],
   )

@@ -1541,28 +1547,27 @@ def multi_queries_paged_attention_non_xla(q: torch.Tensor,


 XLA_LIB.define(
-    "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel) -> Tensor",
+    "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel, float sm_scale) -> Tensor",
 )


 @impl(XLA_LIB, "ragged_paged_attention", "XLA")
-def ragged_paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
-                               v_pages: torch.Tensor, kv_lens: torch.Tensor,
-                               page_indices: torch.Tensor,
-                               cu_q_lens: torch.Tensor, num_seqs: int,
-                               num_kv_pages_per_block: int,
-                               num_queries_per_block: int, use_kernel: bool):
+def ragged_paged_attention_xla(
+    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
+    kv_lens: torch.Tensor, page_indices: torch.Tensor, cu_q_lens: torch.Tensor,
+    num_seqs: int, num_kv_pages_per_block: int, num_queries_per_block: int,
+    use_kernel: bool, sm_scale: float):
   return ragged_paged_attention(q, k_pages, v_pages, kv_lens, page_indices,
                                 cu_q_lens, num_seqs, num_kv_pages_per_block,
-                                num_queries_per_block, use_kernel)
+                                num_queries_per_block, use_kernel, sm_scale)


 @impl(XLA_LIB, "ragged_paged_attention", "CompositeExplicitAutograd")
 def ragged_paged_attention_non_xla(
     q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
     kv_lens: torch.Tensor, page_indices: torch.Tensor, cu_q_lens: torch.Tensor,
     num_seqs: int, num_kv_pages_per_block: int, num_queries_per_block: int,
-    use_kernel: bool):
+    use_kernel: bool, sm_scale: float):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
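In the eager reference path above, the new line multiplies the float32 attention logits by sm_scale before the causal mask and the softmax. Below is a standalone sketch of that ordering with made-up shapes; it mirrors the scale-then-mask-then-softmax sequence but is not the torch_xla implementation itself:

import torch

num_q_heads, cur_q_len, cur_kv_len, head_dim = 4, 3, 7, 64
sm_scale = head_dim**-0.5

q = torch.randn(cur_q_len, num_q_heads, head_dim)
k = torch.randn(cur_kv_len, num_q_heads, head_dim)
v = torch.randn(cur_kv_len, num_q_heads, head_dim)

# Scale the logits first, in float32.
attn = torch.einsum("qhd,khd->hqk", q, k).float() * sm_scale

# Causal mask: queries are the last cur_q_len tokens of the kv window.
empty_mask = torch.ones(cur_q_len, cur_kv_len)
mask = torch.triu(empty_mask, diagonal=cur_kv_len - cur_q_len + 1).bool()
attn.masked_fill_(mask, float("-inf"))

out = torch.einsum("hqk,khd->qhd", torch.softmax(attn, dim=-1), v.float())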

torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py

Lines changed: 7 additions & 0 deletions
@@ -355,6 +355,7 @@ def _flash_attention(
     page_size: int,
     head_dim: int,
     num_q_heads_per_kv_head: int,
+    sm_scale: float,
 ):
   assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_block,
                          head_dim)
@@ -405,6 +406,7 @@ def init_scratch_ref():  # pylint: disable=unused-variable
         'qd,td->qt', q, k,
         preferred_element_type=jnp.float32)  # [block_q, block_k]
     assert s.shape == (num_queries_per_block, kv_blk_size)
+    s = s * sm_scale

     # Modify the mask accordingly: first form the mask. Then move the mask up/down to the right place.
     cur_seq_idx = seq_ids[logical_q_blk_idx]
@@ -597,6 +599,7 @@ def paged_flash_attention_kernel(
     num_seqs: int,
     num_kv_pages_per_block: int,
     mask_value: float,
+    sm_scale: float,
 ):
   kv_head_idx, logical_q_blk_idx, kv_blk_idx = (
       pl.program_id(0),
@@ -704,6 +707,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
         page_size=page_size,
         head_dim=head_dim,
         num_q_heads_per_kv_head=num_q_heads_per_kv_head,
+        sm_scale=sm_scale,
     )
     step_ref[0] = step + 1
   # end of get_kv_and_run_flash_attention
@@ -724,6 +728,7 @@ def _round_up_to_multiple_of_tm(x, tm):
         "num_queries_per_block",
         "mask_value",
         "num_seqs",
+        "sm_scale",
     ],
 )
 def ragged_paged_attention(
@@ -738,6 +743,7 @@ def ragged_paged_attention(
     mask_value: float = DEFAULT_MASK_VALUE,
     num_kv_pages_per_block: int = 128,
     num_queries_per_block: int = 128,
+    sm_scale: float = 1.0,
 ) -> jax.Array:
   """Paged attention kernel with ragged input.

@@ -940,6 +946,7 @@ def next_kv_blk_page_indices_index_map(kv_head_idx, logical_q_blk_idx,
         num_seqs=num_seqs,
         num_kv_pages_per_block=num_kv_pages_per_block,
         mask_value=mask_value,
+        sm_scale=sm_scale,
     ),
     grid_spec=pltpu.PrefetchScalarGridSpec(
         num_scalar_prefetch=6,
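Inside the Pallas kernel, the scale is applied to the per-block q·k product (s = s * sm_scale) right after it is accumulated in float32 and before the mask and the online-softmax bookkeeping. The plain-JAX sketch below mimics that single step on CPU; the block sizes and the simple causal mask are assumptions for illustration, since the real kernel derives its mask from sequence ids and fills masked entries with mask_value rather than -inf:

import jax
import jax.numpy as jnp
from jax import random

num_queries_per_block, kv_blk_size, head_dim = 8, 16, 128
sm_scale = head_dim**-0.5

kq, kk, kv = random.split(random.PRNGKey(0), 3)
q = random.normal(kq, (num_queries_per_block, head_dim))
k = random.normal(kk, (kv_blk_size, head_dim))
v = random.normal(kv, (kv_blk_size, head_dim))

# q.k logits for one (query block, kv block) pair, accumulated in float32.
s = jnp.einsum('qd,td->qt', q, k, preferred_element_type=jnp.float32)
s = s * sm_scale  # the scaling step this commit adds

# Toy causal mask: query i may see kv positions up to i + offset.
offset = kv_blk_size - num_queries_per_block
causal = jnp.tril(jnp.ones((num_queries_per_block, kv_blk_size), bool), k=offset)
s = jnp.where(causal, s, -jnp.inf)

out = jax.nn.softmax(s, axis=-1) @ v  # [num_queries_per_block, head_dim]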
