Skip to content

Commit ada7b41

Browse files
yaochengjihuydhn
authored and committed
[TPU] kv cache update kernel supports dynamic grid (vllm-project#20235)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent 26367f2 commit ada7b41

File tree

4 files changed

+42
-17
lines changed

4 files changed

+42
-17
lines changed

tests/v1/tpu/test_kv_cache_update_kernel.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
3232
new_kv_xla = new_kv_cpu.to(torch_xla.device())
3333
slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
3434
dtype=np.int32)
35+
num_kv_update_slices = len(slice_lens)
3536
kv_cache_start_indices = np.array([
3637
page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
3738
page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
@@ -52,12 +53,15 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
5253
device="cpu",
5354
dtype=torch.int32)
5455
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
56+
num_kv_update_slices_xla = torch.tensor([num_kv_update_slices],
57+
device=torch_xla.device(),
58+
dtype=torch.int32)
5559
torch_xla.sync()
5660

5761
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
5862
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
59-
new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
60-
num_slices_per_block)
63+
new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla,
64+
page_size, num_slices_per_block)
6165
kv_cache_xla.copy_(new_kv_cache_xla)
6266
torch_xla.sync()
6367

vllm/attention/ops/pallas_kv_cache_update.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
from jax.experimental import pallas as pl
88
from jax.experimental.pallas import tpu as pltpu
99

10+
from vllm.utils import cdiv
11+
1012

1113
def _kv_cache_update_kernel(
1214
# Prefetch
13-
slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start,
14-
# slice_len)
15+
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
16+
# new_kv_start, slice_len)
1517
# Input
1618
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
1719
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
@@ -70,6 +72,7 @@ def kv_cache_update(
7072
Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
7173
kv_cache: jax.
7274
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
75+
num_kv_update_slices: jax.Array, # [1]
7376
*,
7477
page_size: int = 32,
7578
num_slices_per_block: int = 8,
@@ -107,7 +110,7 @@ def kv_cache_update(
107110
num_scalar_prefetch=len(scalar_prefetches),
108111
in_specs=in_specs,
109112
out_specs=out_specs,
110-
grid=(slices.shape[1] // num_slices_per_block, ),
113+
grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
111114
scratch_shapes=scratch_shapes,
112115
),
113116
out_shape=out_shape,

vllm/v1/attention/backends/pallas.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class PallasMetadata:
111111
context_lens: torch.Tensor
112112
query_start_loc: torch.Tensor
113113
num_seqs: torch.Tensor
114+
num_kv_update_slices: torch.Tensor
114115
num_slices_per_kv_cache_update_block: int
115116

116117

@@ -219,7 +220,8 @@ def forward(
219220
slot_mapping = attn_metadata.slot_mapping
220221
write_to_kv_cache(
221222
key, value, kv_cache, slot_mapping,
222-
attn_metadata.num_slices_per_kv_cache_update_block)
223+
attn_metadata.num_slices_per_kv_cache_update_block,
224+
attn_metadata.num_kv_update_slices)
223225

224226
output = torch.ops.xla.ragged_paged_attention(
225227
query,
@@ -252,6 +254,7 @@ def write_to_kv_cache(
252254
kv_cache: torch.Tensor,
253255
slot_mapping: torch.Tensor,
254256
num_slices_per_kv_cache_update_block: int,
257+
num_kv_update_slices: torch.Tensor,
255258
) -> None:
256259
""" Write the key and values to the KV cache.
257260
@@ -271,40 +274,47 @@ def write_to_kv_cache(
271274

272275
kv_cache = kv_cache.flatten(0, 1)
273276
new_kv_cache = torch.ops.xla.kv_cache_update_op(
274-
kv, slot_mapping, kv_cache, page_size,
277+
kv, slot_mapping, kv_cache, num_kv_update_slices, page_size,
275278
num_slices_per_kv_cache_update_block)
276279
# NOTE: the in-place copy will be optimized away by XLA compiler.
277280
kv_cache.copy_(new_kv_cache)
278281

279282

280283
@requires_jax
281284
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
282-
kv_cache: torch.Tensor, page_size: int,
285+
kv_cache: torch.Tensor,
286+
num_kv_update_slices: torch.Tensor, page_size: int,
283287
num_slices_per_block: int):
284288
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
285-
new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
286-
"page_size": page_size,
287-
"num_slices_per_block": num_slices_per_block
288-
})
289+
new_kv_cache = xb.call_jax(
290+
kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
291+
"page_size": page_size,
292+
"num_slices_per_block": num_slices_per_block
293+
})
289294
return new_kv_cache
290295

291296

292297
XLA_LIB.define(
293-
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
294-
"int page_size, int num_slices_per_block) -> Tensor", )
298+
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
299+
"Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
300+
"-> Tensor", )
295301

296302

297303
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
298304
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
299-
kv_cache: torch.Tensor, page_size: int,
305+
kv_cache: torch.Tensor,
306+
num_kv_update_slices: torch.Tensor, page_size: int,
300307
num_slices_per_block: int) -> torch.Tensor:
301308
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
302-
page_size, num_slices_per_block)
309+
num_kv_update_slices, page_size,
310+
num_slices_per_block)
303311
return new_kv_cache
304312

305313

306314
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
307315
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
308-
kv_cache: torch.Tensor, page_size: int,
316+
kv_cache: torch.Tensor,
317+
num_kv_update_slices: torch.Tensor,
318+
page_size: int,
309319
num_slices_per_block: int) -> torch.Tensor:
310320
return kv_cache

vllm/v1/worker/tpu_model_runner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,8 +713,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
713713
self.device)
714714
block_tables = block_tables.to(self.device)
715715

716+
# Calculate the slot mapping
716717
slot_mapping_metadata = self._get_slot_mapping_metadata(
717718
num_reqs, num_scheduled_tokens_per_req)
719+
num_kv_update_slices = slot_mapping_metadata.shape[0]
718720
padded_num_slices = _get_padded_num_kv_cache_update_slices(
719721
padded_total_num_scheduled_tokens, self.max_num_reqs,
720722
self.block_size)
@@ -745,6 +747,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
745747
num_seqs=torch.tensor([num_reqs],
746748
dtype=torch.int32,
747749
device=self.device),
750+
num_kv_update_slices=torch.tensor([num_kv_update_slices],
751+
dtype=torch.int32,
752+
device=self.device),
748753
num_slices_per_kv_cache_update_block=
749754
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
750755
)
@@ -1174,6 +1179,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
11741179
dtype=torch.int32).to(self.device)
11751180
padded_num_slices = _get_padded_num_kv_cache_update_slices(
11761181
num_tokens, self.max_num_reqs, self.block_size)
1182+
num_kv_update_slices = torch.tensor([padded_num_slices],
1183+
dtype=torch.int32).to(self.device)
11771184
slot_mapping = torch.zeros((3, padded_num_slices),
11781185
dtype=torch.int32).to(self.device)
11791186
block_tables = torch.zeros((num_reqs, num_blocks),
@@ -1193,6 +1200,7 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
11931200
context_lens=context_lens,
11941201
query_start_loc=query_start_loc,
11951202
num_seqs=num_seqs,
1203+
num_kv_update_slices=num_kv_update_slices,
11961204
num_slices_per_kv_cache_update_block=
11971205
NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
11981206
)

0 commit comments

Comments (0)