
Commit 7d44c46

[TPU]Fix KV cache sharing tests (#19371)
1 parent 31f58be commit 7d44c46

1 file changed

tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 52 additions & 60 deletions
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import unittest.mock as mock
 
 import pytest
 
@@ -17,24 +16,8 @@
     TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
     _get_padded_token_len, _get_req_paddings, _get_token_paddings)
 
-# Mock torch_xla module since it may not be available in the test environments
-torch_xla_patcher = mock.patch.dict(
-    "sys.modules", {
-        "torch_xla": mock.MagicMock(),
-        "torch_xla.core.xla_model": mock.MagicMock(),
-        "torch_xla.runtime": mock.MagicMock(),
-    })
-torch_xla_patcher.start()
 
-# Mock the PallasAttentionBackend
-pallas_attention_backend_patcher = mock.patch(
-    "vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
-pallas_attention_backend_patcher.start()
-
-
-@pytest.fixture
-def model_runner():
-    # Patchers have already been started at module level.
+def get_vllm_config():
     scheduler_config = SchedulerConfig(
         max_num_seqs=10,
         max_num_batched_tokens=512,
@@ -60,18 +43,19 @@ def model_runner():
         cache_config=cache_config,
         scheduler_config=scheduler_config,
     )
+    return vllm_config
+
+
+def get_model_runner(vllm_config):
     device = "xla:0"  # Mocking TPU device
-    with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
-        mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
-        mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
-        return TPUModelRunner(vllm_config, device)
+    return TPUModelRunner(vllm_config, device)
 
 
-@pytest.fixture(autouse=True, scope="session")
-def cleanup_patches():
-    yield
-    torch_xla_patcher.stop()
-    pallas_attention_backend_patcher.stop()
+@pytest.fixture
+def model_runner():
+    # Patchers have already been started at module level.
+    vllm_config = get_vllm_config()
+    return get_model_runner(vllm_config)
 
 
 def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
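
A note on the refactor above: the module-level torch_xla and Pallas mock patchers are removed, and setup is split into get_vllm_config() and get_model_runner() so a test can adjust the config before the runner is constructed. A minimal sketch of how the two helpers are meant to compose, assuming it sits in this test module next to those helpers (the test name is hypothetical; the 1_000_000 value is just the illustration used later in the diff):

def test_helpers_compose_in_order():
    # Build and tweak the config first, then construct the runner so it
    # picks up the adjusted values (e.g. a larger max_model_len).
    vllm_config = get_vllm_config()
    vllm_config.model_config.max_model_len = 1_000_000
    model_runner = get_model_runner(vllm_config)
    assert model_runner.vllm_config.model_config.max_model_len == 1_000_000
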
@@ -370,12 +354,14 @@ def test_get_req_paddings():
     assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(
+        model_runner):
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
         fwd_context = {
             # initialization below will fail because target layer is invalid;
             # the target layer needs to come before layer 1
@@ -399,13 +385,14 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
         assert fwd_context is not None
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
     error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
         fwd_context = {
             layer_0:
             Attention(
@@ -428,12 +415,13 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
         assert fwd_context is not None
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
         fwd_context = {
             # initialization below will fail because target layer is invalid;
             # the target layer needs to come before layer 1
@@ -457,11 +445,10 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
         assert fwd_context is not None
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_without_kv_sharing(model_runner):
+def test_init_kv_cache_without_kv_sharing():
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
-    vllm_config = model_runner.vllm_config
+    vllm_config = get_vllm_config()
     with set_current_vllm_config(vllm_config):
         fwd_context = {
             layer_0:
@@ -482,33 +469,38 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
         # suppress var not used error
         assert fwd_context is not None
     # Set high context length to test max context length estimation
-    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_config.model_config.max_model_len = 1_000_000
     vllm_ctx = vllm_config.compilation_config.static_forward_context
+    model_runner = get_model_runner(vllm_config)
     kv_cache_spec = model_runner.get_kv_cache_spec()
     assert len(kv_cache_spec) == 2
     assert len(model_runner.shared_kv_cache_layers) == 0
 
     available_memory = 20 * GiB_bytes
-    # page size for layer 0's kv_cache_spec is 32KB
-    num_expected_blocks = 327680  # 20GB / 32KB / 2 (num layers)
+    # page size for each layer KV can be calculated as
+    # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
+    # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
+    num_expected_blocks = 20480  # 20GB / 512KB / 2 (num layers)
     kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                           available_memory)
     assert kv_cache_config.num_blocks == num_expected_blocks
-    assert len(kv_cache_config.tensors) == 2
-    assert kv_cache_config.tensors[layer_0].size == available_memory // 2
-    assert kv_cache_config.tensors[layer_1].size == available_memory // 2
+    assert len(kv_cache_config.kv_cache_tensors) == 2
+    assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
+    assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
 
     max_context_len =\
         estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
     # max context len with KV sharing should be 2x as large as without
-    assert max_context_len == 1310720
+    # max_context_len = available_memory / (page_size / block_size) / num_caches
+    # max_context_len = 5GB / (512KB / 128) / 2 = 655360
+    assert max_context_len == 655360
 
     # important: override tensor size to prevent large mem alloc during test
-    # this will only allocate 2 block worth of memory (2 * 32kb)
+    # this will only allocate 2 block worth of memory (2 * 512kb)
     kv_cache_config.num_blocks = 1
-    for layer in kv_cache_config.tensors:
-        kv_cache_config.tensors[layer].size =\
-            kv_cache_spec[layer].page_size_bytes
+    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+        kv_cache_tensor.size = (
+            kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)
 
     model_runner.initialize_kv_cache(kv_cache_config)
 
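The expected block count and context length asserted above follow directly from the page-size arithmetic in the new comments. A quick sanity check of those numbers in plain Python, with no vLLM imports (the head count, head size, dtype width, and block size are the values quoted in the diff's comments):

KiB, GiB = 1024, 1024**3

# per-layer page size: 2 (K and V, non-MLA) * 8 heads * 128 head_dim
# * 2 bytes (bfloat16) * 128 tokens per block = 512 KiB
page_size_bytes = 2 * 8 * 128 * 2 * 128
assert page_size_bytes == 512 * KiB

# 20 GiB of KV memory split across 2 independent layer caches
assert 20 * GiB // page_size_bytes // 2 == 20480

# 5 GiB budget: page_size / block_size bytes per token per layer,
# halved again because the two layers hold separate caches
assert 5 * GiB // (page_size_bytes // 128) // 2 == 655360
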
@@ -524,11 +516,10 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
     assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
 
 
-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_valid(model_runner):
+def test_init_kv_cache_with_kv_sharing_valid():
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
-    vllm_config = model_runner.vllm_config
+    vllm_config = get_vllm_config()
     with set_current_vllm_config(vllm_config):
         fwd_context = {
             layer_0:
@@ -552,33 +543,34 @@ def test_init_kv_cache_with_kv_sharing_valid(model_runner):
     # Set high context length to test max context length estimation
     vllm_config.model_config.max_model_len = 3_000_000
     vllm_ctx = vllm_config.compilation_config.static_forward_context
+    model_runner = get_model_runner(vllm_config)
     kv_cache_spec = model_runner.get_kv_cache_spec()
     assert len(kv_cache_spec) == 1
     assert layer_0 in kv_cache_spec
     assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
 
     available_memory = 20 * GiB_bytes
-    # page size for layer 0's kv_cache_spec is 32KB
+    # page size for layer 0's kv_cache_spec is 512KB
     # with KV sharing, we can allocate (available_mem//page_size//1) blocks
     # which is twice as many as without KV sharing
-    num_expected_blocks = 655360  # 20GB / 32KB
+    num_expected_blocks = 2 * 20480  # 20GB / 512KB
     kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                           available_memory)
     assert kv_cache_config.num_blocks == num_expected_blocks
-    assert len(kv_cache_config.tensors) == 1
+    assert len(kv_cache_config.kv_cache_tensors) == 1
     # Each layer now has twice the available memory for KV cache
     # compared to no KV sharing
-    assert kv_cache_config.tensors[layer_0].size == available_memory
+    assert kv_cache_config.kv_cache_tensors[0].size == available_memory
 
     max_context_len =\
         estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
     # max context len with KV sharing should be 2x as large as without
-    assert max_context_len == 2 * 1310720
+    assert max_context_len == (2 * 655360)
 
     # important: override tensor size to prevent large mem alloc during test
-    # this will only allocate 1 block worth of memory (32kb)
+    # this will only allocate 1 block worth of memory (512kb)
     kv_cache_config.num_blocks = 1
-    kv_cache_config.tensors[layer_0].size =\
+    kv_cache_config.kv_cache_tensors[0].size =\
         kv_cache_spec[layer_0].page_size_bytes
 
     model_runner.initialize_kv_cache(kv_cache_config)
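
For the KV-sharing case, layer 1 reuses layer 0's cache, so only one physical KV cache is allocated and both the block count and the estimated context length double relative to the previous test, which is what the 2 * 20480 and 2 * 655360 assertions encode. Continuing the same back-of-the-envelope check in plain Python, with the values quoted in the diff comments:

GiB = 1024**3
page_size_bytes = 2 * 8 * 128 * 2 * 128  # 512 KiB, as above

# a single shared cache gets the full 20 GiB
assert 20 * GiB // page_size_bytes == 2 * 20480  # 40960 blocks

# and the 5 GiB context-length budget is no longer split across two caches
assert 5 * GiB // (page_size_bytes // 128) == 2 * 655360  # 1310720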
