 
import random
 
+import numpy as np
import pytest
import torch
 
from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, VllmConfig, set_current_vllm_config)
+from vllm.distributed.parallel_state import (init_distributed_environment,
+                                             initialize_model_parallel)
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
-from vllm.utils import GiB_bytes
+from vllm.utils import GiB_bytes, update_environment_variables
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
                                         get_kv_cache_config)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
@@ -686,3 +690,147 @@ def test_init_kv_cache_with_kv_sharing_valid():
    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
+    '''
+    The GPU model runner creates different views into the
+    KVCacheTensors for the attention and mamba layers
+    (via the _reshape_kv_cache_tensors function). This test verifies
+    that the views are compatible: writing to a mamba block
+    does not corrupt an attention block and vice versa.
+    '''
+
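+    # BLOCK_SIZE and DEVICE are assumed to be module-level constants defined
+    # earlier in this test file (they are not part of this diff).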
+    current_platform.seed_everything(42)
+
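+    # MambaMixer2 is built from tensor-parallel linear layers, so even this
+    # single-GPU test needs a one-rank distributed environment before the
+    # layers can be constructed.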
+    update_environment_variables({
+        'RANK': "0",
+        'LOCAL_RANK': "0",
+        'WORLD_SIZE': "1",
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': '12345',
+    })
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=1)
+    torch.set_default_dtype(torch.float16)
+
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+    )
+    model_config = ModelConfig(
+        model="ibm-granite/granite-4.0-tiny-preview",
+        dtype="float16",
+    )
+    cache_config = CacheConfig(
+        block_size=BLOCK_SIZE,
+        gpu_memory_utilization=0.9,
+        swap_space=0,
+        cache_dtype="auto",
+    )
+    parallel_config = ParallelConfig()
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=cache_config,
+        scheduler_config=scheduler_config,
+        parallel_config=parallel_config,
+    )
+
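+    # Layer prefixes for a small hybrid model: two attention layers and four
+    # mamba (SSM) layers.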
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    layer_2 = "model.layers.2.mixer"
+    layer_3 = "model.layers.3.mixer"
+    layer_4 = "model.layers.4.mixer"
+    layer_5 = "model.layers.5.mixer"
+
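+    # Constructing Attention and MambaMixer2 under set_current_vllm_config
+    # registers each layer in the config's static forward context under its
+    # prefix, which is how the test retrieves their kv_cache views below.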
+    with set_current_vllm_config(vllm_config):
+        hf_config = vllm_config.model_config.hf_config
+        fwd_context = {}
+        for key in [layer_0, layer_1]:
+            fwd_context[key] = Attention(
+                num_heads=model_config.get_num_attention_heads(
+                    parallel_config),
+                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+                head_size=model_config.get_head_size(),
+                scale=1.0,
+                prefix=key,
+            )
+        for key in [layer_2, layer_3, layer_4, layer_5]:
+            fwd_context[key] = MambaMixer2(
+                hidden_size=hf_config.hidden_size,
+                ssm_state_size=hf_config.mamba_d_state,
+                conv_kernel_size=hf_config.mamba_d_conv,
+                intermediate_size=hf_config.mamba_expand *
+                hf_config.hidden_size,
+                use_conv_bias=hf_config.mamba_conv_bias,
+                use_bias=hf_config.mamba_proj_bias,
+                n_groups=hf_config.mamba_n_groups,
+                num_heads=hf_config.mamba_n_heads,
+                head_dim=hf_config.mamba_d_head,
+                rms_norm_eps=hf_config.rms_norm_eps,
+                activation=hf_config.hidden_act,
+                prefix=key,
+            )
+        # suppress var not used error
+        assert fwd_context is not None
+    vllm_ctx = vllm_config.compilation_config.static_forward_context
+
+    with monkeypatch.context() as m:
+
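+        # Force the FlashInfer backend; its attention KV cache layout puts
+        # num_blocks in the leading dimension, which the assert below checks.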
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+
+        runner = GPUModelRunner(vllm_config, DEVICE)
+        kv_cache_spec = runner.get_kv_cache_spec()
+
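+        # Size the KV cache for a fixed 5 GiB budget and allocate the actual
+        # cache tensors.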
+        available_memory = 5 * GiB_bytes
+        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                              available_memory)
+        runner.initialize_kv_cache(kv_cache_config)
+
+        # random partition of blocks
+        # blocks0 will be assigned to attention layers
+        # blocks1 will be assigned to mamba layers
+        num_blocks = kv_cache_config.num_blocks
+        ind = np.arange(num_blocks)
+        np.random.shuffle(ind)
+        blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):]
+
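+        # Each layer's kv_cache[0] is its view into the shared KVCacheTensor;
+        # for the mamba layers it is a (conv_state, ssm_state) pair.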
+        attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
+        conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
+        ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
+
+        # assert we are using FlashInfer
+        assert attn_shape[0] == num_blocks
+
+        attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]),
+                                          device=DEVICE,
+                                          fill_value=3.33)
+        conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]),
+                                          device=DEVICE,
+                                          fill_value=6.66)
+        ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]),
+                                         device=DEVICE,
+                                         fill_value=9.99)
+
+        # fill all attention blocks with constant
+        for layer in [layer_0, layer_1]:
+            vllm_ctx[layer].kv_cache[0][
+                blocks0, :] = attn_blocks_constant.detach().clone()
+
+        # fill all mamba blocks with constant
+        for layer in [layer_2, layer_3, layer_4, layer_5]:
+            vllm_ctx[layer].kv_cache[0][0][
+                blocks1, :] = conv_blocks_constant.detach().clone()
+            vllm_ctx[layer].kv_cache[0][1][
+                blocks1, :] = ssm_blocks_constant.detach().clone()
+
+        # verify attention and mamba contents are correct
+        for layer in [layer_0, layer_1]:
+            assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :],
+                               attn_blocks_constant)
+        for layer in [layer_2, layer_3, layer_4, layer_5]:
+            assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :],
+                               conv_blocks_constant)
+            assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :],
+                               ssm_blocks_constant)