Skip to content

Commit 488d8a9

Browse files
authored
[V1] [Hybrid] Add new test to verify that hybrid views into KVCacheTensor are compatible (#21300)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent af376ca commit 488d8a9

File tree

1 file changed

+149
-1
lines changed

1 file changed

+149
-1
lines changed

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33

44
import random
55

6+
import numpy as np
67
import pytest
78
import torch
89

910
from vllm.attention import Attention
1011
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
1112
SchedulerConfig, VllmConfig, set_current_vllm_config)
13+
from vllm.distributed.parallel_state import (init_distributed_environment,
14+
initialize_model_parallel)
15+
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
1216
from vllm.platforms import current_platform
1317
from vllm.sampling_params import SamplingParams
14-
from vllm.utils import GiB_bytes
18+
from vllm.utils import GiB_bytes, update_environment_variables
1519
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
1620
get_kv_cache_config)
1721
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
@@ -686,3 +690,147 @@ def test_init_kv_cache_with_kv_sharing_valid():
686690
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
687691
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
688692
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
693+
694+
695+
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
696+
'''
697+
The GPU model runner creates different views into the
698+
KVCacheTensors for the attention and mamba layers
699+
(via _reshape_kv_cache_tensors function). This test verifies
700+
that the views are compatible: writing a mamba block
701+
will not corrupt an attention block and vice-versa
702+
'''
703+
704+
current_platform.seed_everything(42)
705+
706+
update_environment_variables({
707+
'RANK': "0",
708+
'LOCAL_RANK': "0",
709+
'WORLD_SIZE': "1",
710+
'MASTER_ADDR': 'localhost',
711+
'MASTER_PORT': '12345',
712+
})
713+
init_distributed_environment()
714+
initialize_model_parallel(tensor_model_parallel_size=1)
715+
torch.set_default_dtype(torch.float16)
716+
717+
scheduler_config = SchedulerConfig(
718+
max_num_seqs=10,
719+
max_num_batched_tokens=512,
720+
max_model_len=512,
721+
)
722+
model_config = ModelConfig(
723+
model="ibm-granite/granite-4.0-tiny-preview",
724+
dtype="float16",
725+
)
726+
cache_config = CacheConfig(
727+
block_size=BLOCK_SIZE,
728+
gpu_memory_utilization=0.9,
729+
swap_space=0,
730+
cache_dtype="auto",
731+
)
732+
parallel_config = ParallelConfig()
733+
vllm_config = VllmConfig(
734+
model_config=model_config,
735+
cache_config=cache_config,
736+
scheduler_config=scheduler_config,
737+
parallel_config=parallel_config,
738+
)
739+
740+
layer_0 = "model.layers.0.self_attn.attn"
741+
layer_1 = "model.layers.1.self_attn.attn"
742+
layer_2 = "model.layers.2.mixer"
743+
layer_3 = "model.layers.3.mixer"
744+
layer_4 = "model.layers.4.mixer"
745+
layer_5 = "model.layers.5.mixer"
746+
747+
with set_current_vllm_config(vllm_config):
748+
hf_config = vllm_config.model_config.hf_config
749+
fwd_context = {}
750+
for key in [layer_0, layer_1]:
751+
fwd_context[key] = Attention(
752+
num_heads=model_config.get_num_attention_heads(
753+
parallel_config),
754+
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
755+
head_size=model_config.get_head_size(),
756+
scale=1.0,
757+
prefix=key,
758+
)
759+
for key in [layer_2, layer_3, layer_4, layer_5]:
760+
fwd_context[key] = MambaMixer2(
761+
hidden_size = hf_config.hidden_size,
762+
ssm_state_size = hf_config.mamba_d_state,
763+
conv_kernel_size = hf_config.mamba_d_conv,
764+
intermediate_size = hf_config.mamba_expand *\
765+
hf_config.hidden_size,
766+
use_conv_bias = hf_config.mamba_conv_bias,
767+
use_bias = hf_config.mamba_proj_bias,
768+
n_groups=hf_config.mamba_n_groups,
769+
num_heads=hf_config.mamba_n_heads,
770+
head_dim=hf_config.mamba_d_head,
771+
rms_norm_eps=hf_config.rms_norm_eps,
772+
activation=hf_config.hidden_act,
773+
prefix=key,
774+
)
775+
# suppress var not used error
776+
assert fwd_context is not None
777+
vllm_ctx = vllm_config.compilation_config.static_forward_context
778+
779+
with monkeypatch.context() as m:
780+
781+
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
782+
783+
runner = GPUModelRunner(vllm_config, DEVICE)
784+
kv_cache_spec = runner.get_kv_cache_spec()
785+
786+
available_memory = 5 * GiB_bytes
787+
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
788+
available_memory)
789+
runner.initialize_kv_cache(kv_cache_config)
790+
791+
# random partition of blocks
792+
# blocks0 will be assigned to attention layers
793+
# blocks1 will be assigned to mamba layers
794+
num_blocks = kv_cache_config.num_blocks
795+
ind = np.arange(num_blocks)
796+
np.random.shuffle(ind)
797+
blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):]
798+
799+
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
800+
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
801+
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
802+
803+
# assert we are using FlashInfer
804+
assert attn_shape[0] == num_blocks
805+
806+
attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]),
807+
device=DEVICE,
808+
fill_value=3.33)
809+
conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]),
810+
device=DEVICE,
811+
fill_value=6.66)
812+
ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]),
813+
device=DEVICE,
814+
fill_value=9.99)
815+
816+
# fill all attention blocks with constant
817+
for layer in [layer_0, layer_1]:
818+
vllm_ctx[layer].kv_cache[0][
819+
blocks0, :] = attn_blocks_constant.detach().clone()
820+
821+
# fill all mamba blocks with constant
822+
for layer in [layer_2, layer_3, layer_4, layer_5]:
823+
vllm_ctx[layer].kv_cache[0][0][
824+
blocks1, :] = conv_blocks_constant.detach().clone()
825+
vllm_ctx[layer].kv_cache[0][1][
826+
blocks1, :] = ssm_blocks_constant.detach().clone()
827+
828+
# verify attention and mamba contents are correct
829+
for layer in [layer_0, layer_1]:
830+
assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :],
831+
attn_blocks_constant)
832+
for layer in [layer_2, layer_3, layer_4, layer_5]:
833+
assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :],
834+
conv_blocks_constant)
835+
assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :],
836+
ssm_blocks_constant)

0 commit comments

Comments
 (0)