Commit 04e1642

[TPU] add kv cache update kernel (#19928)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent b69781f commit 04e1642

File tree

6 files changed: +342 -38 lines changed


.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 2 additions & 0 deletions
@@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
 run_and_track_test 15 "test_spmd_model_weight_loading.py" \
     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
+run_and_track_test 16 "test_kv_cache_update_kernel.py" \
+    "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
 
 # After all tests have been attempted, exit with the overall status.
 if [ "$overall_script_exit_code" -ne 0 ]; then

tests/v1/tpu/test_kv_cache_update_kernel.py

Lines changed: 71 additions & 0 deletions (new file)
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import pytest
import torch
import torch_xla

import vllm.v1.attention.backends.pallas  # noqa: F401
from vllm.platforms import current_platform


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
@pytest.mark.parametrize("page_size", [32, 33])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
                                head_dim: int, num_slices_per_block: int):
    page_num = 1000
    padded_num_tokens = 128
    kv_cache_cpu = torch.zeros(
        (page_num * page_size, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
    new_kv_cpu = torch.randn(
        (padded_num_tokens, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    new_kv_xla = new_kv_cpu.to(torch_xla.device())
    slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
                          dtype=np.int32)
    kv_cache_start_indices = np.array([
        page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
        page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
    ],
                                      dtype=np.int32)
    new_kv_cache_indices = np.concatenate(
        [np.array([0], dtype=np.int32),
         np.cumsum(slice_lens[:-1])])
    slot_mapping = np.stack(
        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
    padded_size = (slot_mapping.shape[0] + num_slices_per_block -
                   1) // num_slices_per_block * num_slices_per_block
    slot_mapping = np.pad(slot_mapping,
                          [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
                          constant_values=0)
    slot_mapping = np.transpose(slot_mapping)
    slot_mapping_cpu = torch.tensor(slot_mapping,
                                    device="cpu",
                                    dtype=torch.int32)
    slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
    torch_xla.sync()

    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
        num_slices_per_block)
    kv_cache_xla.copy_(new_kv_cache_xla)
    torch_xla.sync()

    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
                          slice_lens):
        kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]

    assert torch.allclose(kv_cache_xla.cpu(),
                          kv_cache_cpu,
                          atol=1e-4,
                          rtol=1e-4)
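
To make the slot-mapping layout in this test concrete: for the page_size=32 parameterizations, the seven slices resolve to kv_cache start offsets [57, 64, 96, 134, 167, 200, 483], and new_kv_cache_indices is the exclusive prefix sum of slice_lens, giving new-KV start offsets [0, 7, 39, 71, 72, 73, 74] for lengths [7, 32, 32, 1, 1, 1, 9]. The test stacks these per-slice triples into a [7, 3] array, zero-pads it to [8, 3] (8 is a multiple of both tested num_slices_per_block values), and transposes it into the [3, num_slices] layout the kernel consumes, so the padded columns become zero-length no-op copies. Note that the first slice ends exactly on a page boundary and the full-page slices exercise the largest slice length the kernel is written for (slice_len <= page_size).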

tests/v1/tpu/test_pallas.py

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,7 @@ class FakeAttentionLayer:
     key = torch.zeros(num_tokens, num_kv_heads * head_size)
     value = torch.zeros(num_tokens, num_kv_heads * head_size)
     kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
-    slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
+    slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
     max_num_reqs = 8
     max_num_blocks_per_req = 8
     block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
@@ -65,6 +65,7 @@ class FakeAttentionLayer:
         context_lens=context_lens,
         query_start_loc=query_start_loc,
         num_seqs=num_seqs,
+        num_slices_per_kv_cache_update_block=8,
     )
 
     with patch("torch.ops.xla.ragged_paged_attention"

vllm/attention/ops/pallas_kv_cache_update.py

Lines changed: 117 additions & 0 deletions (new file)
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [3, num_slices], list of (kv_cache_start, new_kv_start,
    # slice_len)
    # Input
    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
    # head_dim]
    # Output
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
    # head_dim]
    sem,
):
    async_copies = []
    block_idx = pl.program_id(0)
    num_slices_per_block = scratch.shape[0]

    # Copy from new_kv_hbm_ref to scratch
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        new_kv_start = slices_ref[1, offset_i]
        length = slices_ref[2, offset_i]
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)

    for async_copy in async_copies:
        async_copy.wait()

    # Copy from scratch to kv_cache_hbm_ref
    async_copies.clear()
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        kv_cache_start = slices_ref[0, offset_i]
        length = slices_ref[2, offset_i]
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    for async_copy in async_copies:
        async_copy.wait()


@functools.partial(
    jax.jit,
    static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
    slices: jax.Array,  # [3, num_slices], list of (kv_cache_start,
    # new_kv_start, slice_len)
    kv_cache: jax.Array,  # [total_num_pages * page_size,
    # num_combined_kv_heads, head_dim]
    *,
    page_size: int = 32,
    num_slices_per_block: int = 8,
):
    assert slices.shape[1] % num_slices_per_block == 0
    _, num_combined_kv_heads, head_dim = new_kv.shape
    assert kv_cache.shape[1] == num_combined_kv_heads
    assert kv_cache.shape[2] == head_dim
    assert head_dim % 128 == 0
    # TODO: Add a dynamic check to make sure that all the slice lengths are
    # smaller than or equal to page_size

    in_specs = [
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ]

    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

    scalar_prefetches = [slices]
    scratch = pltpu.VMEM(
        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
        new_kv.dtype,
    )

    scratch_shapes = [
        scratch,
        pltpu.SemaphoreType.DMA,
    ]

    kernel = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            grid=(slices.shape[1] // num_slices_per_block, ),
            scratch_shapes=scratch_shapes,
        ),
        out_shape=out_shape,
        input_output_aliases={len(scalar_prefetches) + 1: 0},
    )

    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
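
As a minimal sketch (not part of the commit) of how this JAX entry point can be exercised on its own, assuming a TPU runtime and illustrative shapes: the single populated column below writes one full page, and the remaining zero-valued columns are zero-length padding, mirroring what the unit test does.

# Illustrative only; shapes and values are assumptions for the sketch.
import jax
import jax.numpy as jnp
import numpy as np

from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update

page_size, num_slices_per_block = 32, 8
num_pages, num_combined_kv_heads, head_dim = 16, 2, 128

kv_cache = jnp.zeros((num_pages * page_size, num_combined_kv_heads, head_dim),
                     dtype=jnp.bfloat16)
new_kv = jax.random.normal(
    jax.random.PRNGKey(0), (page_size, num_combined_kv_heads, head_dim),
    dtype=jnp.bfloat16)

# Rows are (kv_cache_start, new_kv_start, slice_len); unused columns stay
# zero (zero-length copies) so the column count is a multiple of
# num_slices_per_block.
slices = np.zeros((3, num_slices_per_block), dtype=np.int32)
slices[:, 0] = [page_size * 3, 0, page_size]  # copy one full page into page 3

updated_cache = kv_cache_update(new_kv, jnp.asarray(slices), kv_cache,
                                page_size=page_size,
                                num_slices_per_block=num_slices_per_block)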

vllm/v1/attention/backends/pallas.py

Lines changed: 50 additions & 5 deletions
@@ -5,8 +5,12 @@
 from typing import Any, Optional
 
 import torch
-# Required to register custom ops.
+import torch_xla.core.xla_builder as xb
 import torch_xla.experimental.custom_kernel  # noqa: F401
+# Required to register custom ops.
+from torch.library import impl
+from torch_xla._internal.jax_workarounds import requires_jax
+from torch_xla.experimental.custom_kernel import XLA_LIB
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -107,6 +111,7 @@ class PallasMetadata:
     context_lens: torch.Tensor
     query_start_loc: torch.Tensor
     num_seqs: torch.Tensor
+    num_slices_per_kv_cache_update_block: int
 
 
 class PallasAttentionBackendImpl(AttentionImpl):
@@ -212,7 +217,9 @@ def forward(
             # Write input keys and values to the KV cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
-            write_to_kv_cache(key, value, kv_cache, slot_mapping)
+            write_to_kv_cache(
+                key, value, kv_cache, slot_mapping,
+                attn_metadata.num_slices_per_kv_cache_update_block)
 
         output = torch.ops.xla.ragged_paged_attention(
             query,
@@ -244,16 +251,17 @@ def write_to_kv_cache(
     value: torch.Tensor,
     kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
+    num_slices_per_kv_cache_update_block: int,
 ) -> None:
     """ Write the key and values to the KV cache.
 
     Args:
         key: shape = [num_tokens, num_kv_heads * head_size]
         value: shape = [num_tokens, num_kv_heads * head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
-
+        num_slices_per_kv_cache_update_block: int
     """
-    _, _, num_combined_kv_heads, head_size = kv_cache.shape
+    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
                      TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
@@ -262,4 +270,41 @@ def write_to_kv_cache(
     torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
 
     kv_cache = kv_cache.flatten(0, 1)
-    kv_cache.index_copy_(0, slot_mapping, kv)
+    new_kv_cache = torch.ops.xla.kv_cache_update_op(
+        kv, slot_mapping, kv_cache, page_size,
+        num_slices_per_kv_cache_update_block)
+    # NOTE: the in-place copy will be optimized away by XLA compiler.
+    kv_cache.copy_(new_kv_cache)
+
+
+@requires_jax
+def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                            kv_cache: torch.Tensor, page_size: int,
+                            num_slices_per_block: int):
+    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+    new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
+        "page_size": page_size,
+        "num_slices_per_block": num_slices_per_block
+    })
+    return new_kv_cache
+
+
+XLA_LIB.define(
+    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
+    "int page_size, int num_slices_per_block) -> Tensor", )
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "XLA")
+def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                           kv_cache: torch.Tensor, page_size: int,
+                           num_slices_per_block: int) -> torch.Tensor:
+    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                           page_size, num_slices_per_block)
+    return new_kv_cache
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                               kv_cache: torch.Tensor, page_size: int,
+                               num_slices_per_block: int) -> torch.Tensor:
+    return kv_cache
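
Condensed from the unit test above, a hedged sketch of the torch-level call path these registrations enable on a TPU; shapes here are illustrative. The XLA registration lowers to the Pallas kernel through xb.call_jax, while the CompositeExplicitAutograd registration simply returns kv_cache unchanged, which keeps shape propagation working when the op is traced off an XLA device.

# Illustrative only; shapes are assumptions for the sketch.
import torch
import torch_xla

import vllm.v1.attention.backends.pallas  # noqa: F401  (registers the op)

page_size, num_slices_per_block = 32, 8
kv_cache = torch.zeros((16 * page_size, 2, 128),
                       dtype=torch.bfloat16).to(torch_xla.device())
new_kv = torch.randn((page_size, 2, 128),
                     dtype=torch.bfloat16).to(torch_xla.device())

# Columns are (kv_cache_start, new_kv_start, slice_len); unused columns stay
# zero, i.e. zero-length copies, so the count is a multiple of
# num_slices_per_block.
slot_mapping = torch.zeros((3, num_slices_per_block), dtype=torch.int32)
slot_mapping[:, 0] = torch.tensor([page_size * 3, 0, page_size],
                                  dtype=torch.int32)
slot_mapping = slot_mapping.to(torch_xla.device())

# Donate the cache buffer so XLA can update it in place, then fold the op's
# result back into the donated buffer.
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
kv_cache.copy_(torch.ops.xla.kv_cache_update_op(
    new_kv, slot_mapping, kv_cache, page_size, num_slices_per_block))
torch_xla.sync()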
