# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import unittest.mock as mock

import pytest

    TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
    _get_padded_token_len, _get_req_paddings, _get_token_paddings)

-# Mock torch_xla module since it may not be available in the test environments
-torch_xla_patcher = mock.patch.dict(
-    "sys.modules", {
-        "torch_xla": mock.MagicMock(),
-        "torch_xla.core.xla_model": mock.MagicMock(),
-        "torch_xla.runtime": mock.MagicMock(),
-    })
-torch_xla_patcher.start()

-# Mock the PallasAttentionBackend
-pallas_attention_backend_patcher = mock.patch(
-    "vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
-pallas_attention_backend_patcher.start()
-
-
-@pytest.fixture
-def model_runner():
-    # Patchers have already been started at module level.
+def get_vllm_config():
    scheduler_config = SchedulerConfig(
        max_num_seqs=10,
        max_num_batched_tokens=512,
@@ -60,18 +43,19 @@ def model_runner():
        cache_config=cache_config,
        scheduler_config=scheduler_config,
    )
+    return vllm_config
+
+
+def get_model_runner(vllm_config):
    device = "xla:0"  # Mocking TPU device
-    with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
-            mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
-            mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
-        return TPUModelRunner(vllm_config, device)
+    return TPUModelRunner(vllm_config, device)


-@pytest.fixture(autouse=True, scope="session")
-def cleanup_patches():
-    yield
-    torch_xla_patcher.stop()
-    pallas_attention_backend_patcher.stop()
+@pytest.fixture
+def model_runner():
+    # Patchers have already been started at module level.
+    vllm_config = get_vllm_config()
+    return get_model_runner(vllm_config)


def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
@@ -370,12 +354,14 @@ def test_get_req_paddings():
    assert _get_req_paddings(8, 36) == [8, 16, 32, 36]


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(
+        model_runner):
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
    error_msg = f"{layer_1} must come before the current layer"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
        fwd_context = {
            # initialization below will fail because target layer is invalid;
            # the target layer needs to come before layer 1
@@ -399,13 +385,14 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
    assert fwd_context is not None


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
    invalid_layer = "model.layers.0.cross_attn.attn"
    error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
        fwd_context = {
            layer_0:
            Attention(
@@ -428,12 +415,13 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
    assert fwd_context is not None


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
    error_msg = f"{layer_1} cannot be the same as the current layer"
-    with pytest.raises(ValueError, match=error_msg):
+    vllm_config = model_runner.vllm_config
+    with pytest.raises(ValueError, match=error_msg), \
+            set_current_vllm_config(vllm_config):
        fwd_context = {
            # initialization below will fail because target layer is invalid;
            # the target layer needs to come before layer 1
@@ -457,11 +445,10 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
    assert fwd_context is not None


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_without_kv_sharing(model_runner):
+def test_init_kv_cache_without_kv_sharing():
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
-    vllm_config = model_runner.vllm_config
+    vllm_config = get_vllm_config()
    with set_current_vllm_config(vllm_config):
        fwd_context = {
            layer_0:
@@ -482,33 +469,38 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
    # suppress var not used error
    assert fwd_context is not None
    # Set high context length to test max context length estimation
-    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_config.model_config.max_model_len = 1_000_000
    vllm_ctx = vllm_config.compilation_config.static_forward_context
+    model_runner = get_model_runner(vllm_config)
    kv_cache_spec = model_runner.get_kv_cache_spec()
    assert len(kv_cache_spec) == 2
    assert len(model_runner.shared_kv_cache_layers) == 0

    available_memory = 20 * GiB_bytes
-    # page size for layer 0's kv_cache_spec is 32KB
-    num_expected_blocks = 327680  # 20GB / 32KB / 2 (num layers)
+    # page size for each layer KV can be calculated as
+    # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
+    # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
+    num_expected_blocks = 20480  # 20GB / 512KB / 2 (num layers)
    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                          available_memory)
    assert kv_cache_config.num_blocks == num_expected_blocks
-    assert len(kv_cache_config.tensors) == 2
-    assert kv_cache_config.tensors[layer_0].size == available_memory // 2
-    assert kv_cache_config.tensors[layer_1].size == available_memory // 2
+    assert len(kv_cache_config.kv_cache_tensors) == 2
+    assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
+    assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2

    max_context_len = \
        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
    # max context len with KV sharing should be 2x as large as without
-    assert max_context_len == 1310720
+    # max_context_len = available_memory / (page_size / block_size) / num_caches
+    # max_context_len = 5GB / (512KB / 128) / 2 = 655360
+    assert max_context_len == 655360

    # important: override tensor size to prevent large mem alloc during test
-    # this will only allocate 2 block worth of memory (2 * 32kb)
+    # this will only allocate 2 block worth of memory (2 * 512kb)
    kv_cache_config.num_blocks = 1
-    for layer in kv_cache_config.tensors:
-        kv_cache_config.tensors[layer].size = \
-            kv_cache_spec[layer].page_size_bytes
+    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+        kv_cache_tensor.size = (
+            kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)

    model_runner.initialize_kv_cache(kv_cache_config)
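A quick back-of-the-envelope check of the figures asserted above, as a sketch that assumes the model config implied by the new comments (8 KV heads, head_dim 128, bfloat16 KV cache, block_size 128, two attention layers) rather than anything read from the vLLM API:

    # Sanity-check the block count and context-length estimate used above.
    GiB = 1024**3
    num_kv_heads, head_dim, block_size = 8, 128, 128  # assumed test model config
    dtype_bytes, kv_factor = 2, 2                     # bfloat16; K and V (non-MLA)
    page_size = kv_factor * num_kv_heads * head_dim * dtype_bytes * block_size
    assert page_size == 512 * 1024                    # 512 KiB per block per layer
    num_layers = 2
    assert 20 * GiB // page_size // num_layers == 20480       # expected block count
    bytes_per_token_per_layer = page_size // block_size       # 4 KiB per token
    assert 5 * GiB // bytes_per_token_per_layer // num_layers == 655360
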
@@ -524,11 +516,10 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1


-@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
-def test_init_kv_cache_with_kv_sharing_valid(model_runner):
+def test_init_kv_cache_with_kv_sharing_valid():
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
-    vllm_config = model_runner.vllm_config
+    vllm_config = get_vllm_config()
    with set_current_vllm_config(vllm_config):
        fwd_context = {
            layer_0:
@@ -552,33 +543,34 @@ def test_init_kv_cache_with_kv_sharing_valid(model_runner):
    # Set high context length to test max context length estimation
    vllm_config.model_config.max_model_len = 3_000_000
    vllm_ctx = vllm_config.compilation_config.static_forward_context
+    model_runner = get_model_runner(vllm_config)
    kv_cache_spec = model_runner.get_kv_cache_spec()
    assert len(kv_cache_spec) == 1
    assert layer_0 in kv_cache_spec
    assert model_runner.shared_kv_cache_layers[layer_1] == layer_0

    available_memory = 20 * GiB_bytes
-    # page size for layer 0's kv_cache_spec is 32KB
+    # page size for layer 0's kv_cache_spec is 512KB
    # with KV sharing, we can allocate (available_mem//page_size//1) blocks
    # which is twice as many as without KV sharing
-    num_expected_blocks = 655360  # 20GB / 32KB
+    num_expected_blocks = 2 * 20480  # 20GB / 512KB
    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                          available_memory)
    assert kv_cache_config.num_blocks == num_expected_blocks
-    assert len(kv_cache_config.tensors) == 1
+    assert len(kv_cache_config.kv_cache_tensors) == 1
    # Each layer now has twice the available memory for KV cache
    # compared to no KV sharing
-    assert kv_cache_config.tensors[layer_0].size == available_memory
+    assert kv_cache_config.kv_cache_tensors[0].size == available_memory

    max_context_len = \
        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
    # max context len with KV sharing should be 2x as large as without
-    assert max_context_len == 2 * 1310720
+    assert max_context_len == (2 * 655360)

    # important: override tensor size to prevent large mem alloc during test
-    # this will only allocate 1 block worth of memory (32kb)
+    # this will only allocate 1 block worth of memory (512kb)
    kv_cache_config.num_blocks = 1
-    kv_cache_config.tensors[layer_0].size = \
+    kv_cache_config.kv_cache_tensors[0].size = \
        kv_cache_spec[layer_0].page_size_bytes

    model_runner.initialize_kv_cache(kv_cache_config)
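
The shared-KV figures follow from the same assumed 512 KiB page size: with layer 1 reusing layer 0's cache there is only one KV cache to budget for, so the block count and the estimated context length both double. A minimal sketch of that arithmetic, under the same assumptions as the earlier check:

    # Sanity-check the shared-KV numbers asserted above.
    GiB = 1024**3
    page_size = 512 * 1024      # assumed per-layer block page size
    block_size = 128
    assert 20 * GiB // page_size == 2 * 20480                  # one layer's pages only
    assert 5 * GiB // (page_size // block_size) == 2 * 655360  # doubled context length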