diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index 785d316025e..5f920997934 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
     constexpr bool kIsVariableB = true;
     constexpr bool kIsVariableC = true;
-    constexpr bool kHasZ = true;
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
-        BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
-            using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
-            constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
-            dim3 grid(params.batch, params.dim / kNRows);
-            auto kernel = &selective_scan_fwd_kernel<Ktraits>;
-            if (kSmemSize >= 48 * 1024) {
-                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                    (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-            }
-            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
+        BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
+            BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
+                using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
+                constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+                dim3 grid(params.batch, params.dim / kNRows);
+                auto kernel = &selective_scan_fwd_kernel<Ktraits>;
+                if (kSmemSize >= 48 * 1024) {
+                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+                }
+                kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
         });
     });
 }
@@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
     at::Tensor z, out_z;
     const bool has_z = z_.has_value();
-    TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
-    z = z_.value();
-    TORCH_CHECK(z.scalar_type() == input_type);
-    TORCH_CHECK(z.is_cuda());
-    TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
-    if (varlen){
-        CHECK_SHAPE(z, dim, seqlen);
-    } else {
-        CHECK_SHAPE(z, batch_size, dim, seqlen);
+    if (has_z) {
+        z = z_.value();
+        TORCH_CHECK(z.scalar_type() == input_type);
+        TORCH_CHECK(z.is_cuda());
+        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
+        if (varlen){
+            CHECK_SHAPE(z, dim, seqlen);
+        } else {
+            CHECK_SHAPE(z, batch_size, dim, seqlen);
+        }
+
+        out_z = z;
     }
-    out_z = z;
-
     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = delta;
     TORCH_CHECK(ssm_states.scalar_type() == input_type);
@@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
         selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
     });
 }
-
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index ddc920aeb2d..eca37a09058 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -374,6 +374,7 @@ Specified using `--task generate`.
 | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
 | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-flash-reasoning`, etc. | | | |
 | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
 | `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
 | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index fa10857313a..c10d375683e 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -248,6 +248,10 @@ def check_available_online(
     "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
                                             trust_remote_code=True,
                                             v0_only=True),
+    "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning",  # noqa: E501
+                                            trust_remote_code=True,
+                                            v0_only=True,
+                                            max_model_len=10240),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 76726c0c820..038717a129e 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -95,6 +95,9 @@ def _initialize_kv_caches_v1(self, vllm_config):
                        _initialize_kv_caches_v1), monkeypatch.context() as m):
         if model_info.v0_only:
             m.setenv("VLLM_USE_V1", "0")
+        if model_arch == "Phi4FlashForCausalLM":
+            # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
+            m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
         LLM(
             model_info.default,
             tokenizer=model_info.tokenizer,
diff --git a/tests/test_utils.py b/tests/test_utils.py
index f90715fd751..28acacd2519 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -458,6 +458,31 @@ def test_bind_kv_cache():
     assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
     assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
 
+def test_bind_kv_cache_kv_sharing():
+    from vllm.attention import Attention
+
+    ctx = {
+        'layers.0.self_attn': Attention(32, 128, 0.1),
+        'layers.1.self_attn': Attention(32, 128, 0.1),
+        'layers.2.self_attn': Attention(32, 128, 0.1),
+        'layers.3.self_attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    shared_kv_cache_layers = {
+        'layers.2.self_attn': 'layers.1.self_attn',
+        'layers.3.self_attn': 'layers.0.self_attn'
+    }
+    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0]
+
 def test_bind_kv_cache_non_attention():
     from vllm.attention import Attention
 
diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index fe9738d804c..e4338805f56 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -308,7 +308,8 @@ def __init__(
         kv_sharing_target_layer_name: Optional[str] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
-            raise NotImplementedError("KV sharing is not supported in V0.")
+            raise NotImplementedError("KV sharing is not supported in V0 "
+                                      "BLOCK_SPARSE_FLASH_ATTN Backend.")
         assert blocksparse_params is not None
         assert alibi_slopes is None, ValueError(
             "Alibi not support for 
blocksparse flash attention.") diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py new file mode 100644 index 00000000000..7c35e58967d --- /dev/null +++ b/vllm/attention/backends/differential_flash_attn.py @@ -0,0 +1,1000 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""" An implementation of https://arxiv.org/pdf/2410.05258 """ +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch +from einops import rearrange + +from vllm import _custom_ops as ops +# yapf conflicts with isort for this block +# yapf: disable +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.flash_attn import FlashAttentionBackend +# yapf: enable +from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, + compute_slot_mapping, + compute_slot_mapping_start_idx, + is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, + is_block_tables_empty) +from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) +from vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +logger = init_logger(__name__) + + +class DifferentialFlashAttentionBackend(AttentionBackend): + accept_output_buffer = False + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2" + return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) + + @staticmethod + def get_name() -> str: + return "DIFFERENTIAL_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]: + return DifferentialFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]: + return DifferentialFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]: + return DifferentialFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] 
for kv_cache in kv_caches] + + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class DifferentialFlashAttentionMetadata(AttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional[ + "DifferentialFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "DifferentialFlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
+ encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + # Cross-layer shared attention block tables + cross_layer_shared_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata( + self) -> Optional["DifferentialFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + cross_layer_shared_block_tables = ( + None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[:self.num_prefills]) + + self._cached_prefill_metadata = DifferentialFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata( + self) -> Optional["DifferentialFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + cross_layer_shared_block_tables = ( + None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[self.num_prefills:]) + self._cached_decode_metadata = DifferentialFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class DifferentialFlashAttentionMetadataBuilder( + AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.cross_layer_shared_block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + # TODO: add support for chunked prefill and prefix caching. 
+ assert not chunked_prefill_enabled, \ + "chunked prefill is not supported for now" + assert not prefix_cache_hit, "prefix caching is not supported for now" + + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + cross_layer_shared_block_table = [] + if prefix_cache_hit: + cross_layer_shared_block_table = block_tables[seq_id] + elif block_tables is not None: + if curr_sliding_window_block == 0: + cross_layer_shared_block_table = block_tables[seq_id] + else: + cross_layer_shared_block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.cross_layer_shared_block_tables.append( + cross_layer_shared_block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables(self, num_seqs: int, + block_tables: List[List[int]], + graph_block_tables) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + # max_batch_size, max_blocks = self.runner.graph_block_tables.shape + max_batch_size, max_blocks = graph_block_tables.shape + assert max_batch_size >= num_seqs + + # graph_block_tables = self.runner.graph_block_tables[:num_seqs] + graph_block_tables = graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. 
+ graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + + self.cross_layer_shared_block_tables.extend([] * + cuda_graph_pad_size) + + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables, self.runner.graph_block_tables) + cross_layer_shared_block_tables = \ + self._get_graph_runner_block_tables( + num_seqs, self.cross_layer_shared_block_tables, + self.runner.cross_layer_shared_graph_block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + cross_layer_shared_block_tables = make_tensor_with_pad( + self.cross_layer_shared_block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return DifferentialFlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + 
max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class DifferentialFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + differential_flash_attention_config: Optional[Dict[str, Any]] = None, + ) -> None: + if differential_flash_attention_config is None: + differential_flash_attention_config = {} + self.differential_flash_attention_config = \ + differential_flash_attention_config + self.used_shared_kv_cache = kv_sharing_target_layer_name is not None + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + if use_irope: + logger.warning( + "Using irope in V0 is not supported yet, it will fall back " + "to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + self.vllm_flash_attn_version = get_flash_attn_version( + requires_alibi=self.alibi_slopes is not None) + if is_quantized_kv_cache(self.kv_cache_dtype) and ( + not self.kv_cache_dtype.startswith("fp8") + or not flash_attn_supports_fp8()): + raise NotImplementedError( + f"FlashAttention does not support {self.kv_cache_dtype} " + "kv-cache on this device " + f"(FA supports fp8 = {flash_attn_supports_fp8()}).") + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. 
+ logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + self.attn_type = attn_type + + self.lambda_full = None + self.subln = self.differential_flash_attention_config["subln"] + + def split_heads(self, x): + # split by num_heads, the stripe pattern is friendly to tensor parallel. + x = rearrange(x, "... (H two) D -> ... H two D", two=2) + x1 = x[..., 0, :] + x2 = x[..., 1, :] + return x1.contiguous(), x2.contiguous() + + def split_kv_cache(self, x): + # split by num_heads, the stripe pattern is friendly to tensor parallel. + if x.numel() == 0: + return torch.empty(0), torch.empty(0) + + x1, x2 = x[0], x[1] + return x1, x2 + + def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor, + value: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata): + if kv_cache.numel() > 0 and key is not None and value is not None: + updated_slot_mapping = attn_metadata.slot_mapping + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + def forward_generate_kv_cache( + self, query: torch.Tensor, key: Optional[torch.Tensor], + value: Optional[torch.Tensor], k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor: + + head_size = self.head_size + num_heads = self.num_heads // 2 + num_kv_heads = self.num_kv_heads // 2 + + query = query.view(-1, num_heads, head_size) + if key is not None: + assert value is not None + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + else: + assert value is None + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[ + 0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" + assert value.shape[ + 0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + if key is not None and value is not None: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens, "query shape mismatch" + assert decode_query.shape[ + 0] == num_decode_tokens, "decode query shape mismatch" + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. 
+ if k_cache.numel() == 0 \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0: + # normal attention + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + assert prefill_output.shape == output[: + num_prefill_tokens].shape + output[:num_prefill_tokens] = prefill_output + else: + raise Exception("prefix caching not supported") + + if decode_meta := attn_metadata.decode_metadata: + block_tables_arg = decode_meta.block_tables + try: + output[num_prefill_tokens:] = flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_tables_arg, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ).squeeze(1) + except Exception as e: + logger.error("Error in PagedAttention.forward_decode: %s", + str(e)) + raise e + + # Reshape the output tensor. + return output.view(-1, num_heads, head_size) + + def forward_with_kv_cache_only( + self, + query: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata, + ): + if not attn_metadata.decode_metadata: + block_tables_arg = attn_metadata.cross_layer_shared_block_tables + else: + block_tables_arg = attn_metadata.block_tables + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_tables_arg, + cache_seqlens=attn_metadata.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ).squeeze(1) + return output + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + NOTE: It in-place updates the output tensor. + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). 
+ We use torch's .expand() to avoid duplicating values + """ + if self.lambda_full is None: + self.lambda_init = self.differential_flash_attention_config[ + "lambda_init"] + lambda_q1 = self.differential_flash_attention_config["lambda_q1"] + lambda_k1 = self.differential_flash_attention_config["lambda_k1"] + lambda_q2 = self.differential_flash_attention_config["lambda_q2"] + lambda_k2 = self.differential_flash_attention_config["lambda_k2"] + lambda_1 = torch.exp( + torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q) + lambda_2 = torch.exp( + torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q) + self.lambda_full = lambda_1 - lambda_2 + self.lambda_init + + if not self.used_shared_kv_cache: # need to generate kv-cache + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size) + + q1, q2 = self.split_heads(q) + k1, k2 = self.split_heads(k) + v1, v2 = self.split_heads(v) + + # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501 + # Split by half along the first dimension. + kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" + assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" + + if kv_cache1.numel() != 0: + self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata) + self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata) + + key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) + key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) + else: + key_cache1, value_cache1 = torch.empty(0), torch.empty(0) + key_cache2, value_cache2 = torch.empty(0), torch.empty(0) + attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1, + value_cache1, + attn_metadata) + attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1, + value_cache2, + attn_metadata) + attn11 = attn11.view(q1.shape) + attn12 = attn12.view(q1.shape) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2, + value_cache1, + attn_metadata) + attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2, + value_cache2, + attn_metadata) + attn21 = attn21.view(q2.shape) + attn22 = attn22.view(q2.shape) + attn2 = torch.cat([attn21, attn22], dim=-1) + + attn = attn1 - self.lambda_full * attn2 + # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + # reshape back to 2 * num_head + attn_output = rearrange(attn, + "... H (two D) -> ... 
(H two) D", + two=2) + + else: # re-use the kv cache, full attention + q = q.view(-1, self.num_heads, self.head_size) + q1, q2 = self.split_heads(q) + # kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501 + kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] + key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] + + attn11 = self.forward_with_kv_cache_only(q1, key_cache1, + value_cache1, + attn_metadata) + attn12 = self.forward_with_kv_cache_only(q1, key_cache1, + value_cache2, + attn_metadata) + attn11 = attn11.view(q1.shape) + attn12 = attn12.view(q1.shape) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = self.forward_with_kv_cache_only(q2, key_cache2, + value_cache1, + attn_metadata) + attn22 = self.forward_with_kv_cache_only(q2, key_cache2, + value_cache2, + attn_metadata) + attn21 = attn21.view(q2.shape) + attn22 = attn22.view(q2.shape) + attn2 = torch.cat([attn21, attn22], dim=-1) + + attn = attn1 - self.lambda_full * attn2 + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + # reshape back to 2 * num_head + attn_output = rearrange(attn, + "... H (two D) -> ... (H two) D", + two=2) + attn_output = attn_output.view(-1, self.num_heads * self.head_size) + return attn_output diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index f62a43b441f..40557a4e8f8 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -295,7 +295,8 @@ def __init__( dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "DUAL_CHUNK_FLASH_ATTN backend.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bf8e373802f..20e67eb9b40 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -622,7 +622,8 @@ def __init__( use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "FLASH_ATTN backend.") if blocksparse_params is not None: raise ValueError( "FlashAttention does not support block-sparse attention.") diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5bbe340b143..1f913ad8952 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1006,7 +1006,8 @@ def __init__( use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "FLASHINFER backend.") if use_irope: logger.warning_once( "Using irope in FlashInfer is not supported yet, it will fall" diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index bf778a1e501..b8fdf763a04 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -115,7 +115,8 @@ def __init__( ) -> None: super(AttentionImpl, self).__init__() if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in 
V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "HPU_ATTN backend.") if use_irope: logger.warning_once( "Using irope in HPU is not supported yet, it will fall back " diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 0b7783758dd..4653d5267e1 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -501,7 +501,8 @@ def __init__( use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "ROCM_FLASH backend.") if use_irope: logger.warning_once( "Using irope in ROCm Flash Attention is not supported yet, it " diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b583240c73c..3ef79bb6212 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -394,7 +394,8 @@ def __init__( use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "XFORMERS backend.") if blocksparse_params is not None: raise ValueError( "XFormers does not support block-sparse attention.") diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 3d5746837be..f9c2d4f4983 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -160,10 +160,6 @@ def __init__( self.attn_type = attn_type if kv_sharing_target_layer_name is not None: - if not envs.VLLM_USE_V1: - raise NotImplementedError( - "Cross-layer KV sharing is not supported in V0.") - validate_kv_sharing_target( prefix, kv_sharing_target_layer_name, diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 3d01253447c..e93be9bfb16 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -59,11 +59,12 @@ def forward( hidden_states: torch.Tensor, sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, + prune_hidden_states: bool = True, ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None: + if sampling_metadata is not None and prune_hidden_states: hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py new file mode 100644 index 00000000000..10f8b6552af --- /dev/null +++ b/vllm/model_executor/models/phi4flash.py @@ -0,0 +1,746 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN + +import vllm.envs as envs +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention.selector import _Backend +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from 
vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + SupportsV0Only) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import make_layers, maybe_prefix + +logger = init_logger(__name__) + + +class SwiGLUActivation(nn.Module): + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return x1 * nn.functional.silu(x2) + + +class SambaYMLP(nn.Module): + """Gated Linear Unit. + + Reference: + Language Modeling with Gated Convolutional Networks. + https://arxiv.org/pdf/1612.08083v3.pdf. + + """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.fc1 = nn.Linear(config.hidden_size, + 2 * config.intermediate_size, + bias=False) + self.fc2 = nn.Linear(config.intermediate_size, + config.hidden_size, + bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + y = self.fc1(hidden_states) + gate, y = y.chunk(2, dim=-1) + y = y * self.activation_fn(gate) + return self.fc2(y) + + +def get_virtual_engine(): + forward_context: ForwardContext = get_forward_context() + return forward_context.virtual_engine + + +class SambaYAttention(nn.Module): + + def __init__(self, + config, + layer_idx: Optional[int] = None, + yoco_cross: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = ""): + super().__init__() + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing " + "a `layer_idx` is not recommended and will lead to errors " + "during the forward call if caching is used. 
Please make " + "sure to provide a `layer_idx` when creating this class.") + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.yoco_cross = yoco_cross + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError("hidden_size must be divisible by num_heads " + f"(got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads}).") + + op_size = self.num_heads * self.head_dim + 2 * ( + self.num_key_value_heads * self.head_dim) + self.out_proj = nn.Linear(self.num_heads * self.head_dim, + self.hidden_size, + bias=True) + if yoco_cross: + self.Wqkv = nn.Linear(self.hidden_size, + self.num_heads * self.head_dim, + bias=True) + else: + self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True) + + # disable sliding window for the second half of the model + sliding_window = config.interleaved_sliding_window[layer_idx] + if layer_idx >= config.num_hidden_layers // 2: + assert sliding_window is None, \ + "sliding_window must be none for the second decoder" + else: + assert sliding_window is not None, \ + "sliding_window must be set for the first decoder" + + assert self.num_heads % 2 == 0, 'num_heads should be even' + assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' + + self.lambda_init = self.lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_q2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.subln = nn.RMSNorm(2 * self.head_dim, + eps=1e-5, + elementwise_affine=True) + + params = { + 'differential_flash_attention_config': { + 'lambda_init': self.lambda_init, + 'lambda_q1': self.lambda_q1, + 'lambda_k1': self.lambda_k1, + 'lambda_q2': self.lambda_q2, + 'lambda_k2': self.lambda_k2, + "subln": self.subln, + } + } + + if yoco_cross: + kv_shared_layer_index = config.num_hidden_layers // 2 + 1 + kv_sharing_target_layer_name = \ + f"model.layers.{kv_shared_layer_index}.self_attn.attn" + else: + kv_sharing_target_layer_name = None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.head_dim**-0.5, + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + **params) + assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\ + "DIFFERENTIAL_FLASH_ATTN required" + + def lambda_init_fn(self, depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + def forward( + self, + hidden_states: torch.Tensor, + ): + + if not self.yoco_cross: # need to generate kv-cache + qkv = self.Wqkv(hidden_states) + q, k, v = qkv.split([ + self.hidden_size, self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim + ], + dim=-1) + attn_output = self.attn(q, k, v) + else: # re-use the kv cache, full attention + q = self.Wqkv(hidden_states) + attn_output = self.attn(q, None, None) + attn_output = attn_output.view(-1, self.num_heads * self.head_dim) + return self.out_proj(attn_output) + + +class Phi4Mamba(nn.Module): + + def __init__( + self, + 
d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", # difference + dt_scale=1.0, # difference + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, # Fused kernel options + layer_idx=None, + device=None, + dtype=None, + yoco_cross=False, + yoco_kv=False, + ): + factory_kwargs = {"params_dtype": dtype} # difference + super().__init__() + self.yoco_cross = yoco_cross + self.yoco_kv = yoco_kv + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / + 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + self.swiGluActivation = SwiGLUActivation() + if self.yoco_cross: + self.in_proj = MergedColumnParallelLinear(self.d_model, + [self.d_inner], + bias=bias, + **factory_kwargs) + self.out_proj = RowParallelLinear(self.d_inner, + self.d_model, + bias=bias, + **factory_kwargs) + return + self.conv1d = ColumnParallelLinear( + input_size=d_conv, + output_size=self.d_inner, + bias=conv_bias, + params_dtype=dtype, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + self.d_model, + [self.d_inner] * 2, + bias=bias, + params_dtype=dtype, + ) + + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.d_inner, + self.dt_rank + self.d_state * 2, + bias=False, + params_dtype=dtype, + ) + + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear( + self.dt_rank, + self.d_inner, + bias=True, + skip_bias_add=True, + params_dtype=dtype, + ) + + # # D "skip" parameter + # self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 + self.A = nn.Parameter( + torch.empty( + self.d_inner, + self.d_state, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32)) + + self.out_proj = RowParallelLinear( + self.d_inner, + self.d_model, + bias=bias, + input_is_parallel=True, + params_dtype=dtype, + ) + self.activation = "silu" + + def forward(self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + yoco_key_values=None) -> torch.Tensor: + + if self.yoco_cross: + out = self.in_proj(hidden_states)[0] + out = self.swiGluActivation(yoco_key_values, out) + out = self.out_proj(out) + return out[0], yoco_key_values + + # 1. Gated MLP's linear projection + # projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + projected_states = self.in_proj( + hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.dt_rank, self.d_state, self.d_state], + dim=-1, + ) + + # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + mamba_cache_params.ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + # z, + None if self.yoco_kv else gate, + time_proj_bias, + delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + # z + # gate.transpose(0, 1), + None if self.yoco_kv else gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. 
Final linear projection + if self.yoco_kv: + # gate = gate.transpose(-1,-2).contiguous() + yoco_key_values = scan_outputs.transpose(-2, -1) + scan_outputs = self.swiGluActivation(scan_outputs, gate) + + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + + return contextualized_states, yoco_key_values + + +class SambaYDecoderLayer(nn.Module): + + def __init__( + self, + config, + layer_idx, + cache_config, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.mlp = SambaYMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + self.yoco_mb = False + self.yoco_cross = False + if layer_idx >= config.num_hidden_layers // 2: + self.yoco_mb = True + self.yoco_cross = (layer_idx + >= (config.num_hidden_layers // 2 + 2)) + self.use_mamba = config.mb_per_layer > 0 and \ + layer_idx % config.mb_per_layer == 0 + if self.use_mamba: + factory_kwargs = {"dtype": None} + self.attn = Phi4Mamba(config.hidden_size, + layer_idx=layer_idx, + yoco_cross=self.yoco_cross, + yoco_kv=self.yoco_mb, + **factory_kwargs) + else: + self.attn = SambaYAttention(config, + layer_idx=layer_idx, + yoco_cross=self.yoco_cross, + cache_config=cache_config, + prefix=f"{prefix}.self_attn") + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + ssm_output: Optional[torch.LongTensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if self.use_mamba: + assert mamba_cache_params is not None + else: + assert mamba_cache_params is None + + residual = hidden_states + hidden_states = self.input_layernorm( + hidden_states.to(dtype=self.input_layernorm.weight.dtype)) + + if self.use_mamba: + attn_outputs, ssm_output = self.attn(hidden_states, + attn_metadata, + mamba_cache_params, + yoco_key_values=ssm_output) + residual = residual.to(torch.float32) + else: + attn_outputs = self.attn(hidden_states, ) + hidden_states = residual + attn_outputs + residual = hidden_states + hidden_states = self.post_attention_layernorm( + hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype)) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, ssm_output + + +class SambaYModel(nn.Module): + + def __init__(self, + config, + cache_config=None, + quant_config=None, + lora_config=None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + # Pipeline parallel is not supported since the second half of + # the layers share the kv cache. 
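+ # (Concretely, as set up in SambaYAttention, every attention layer with
+ # layer_idx >= num_hidden_layers // 2 + 2 reads the KV cache written by
+ # layer num_hidden_layers // 2 + 1, so all of these layers have to stay
+ # on the same pipeline rank.)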
+ if get_pp_group().world_size != 1: + raise ValueError("Pipeline Parallel not supported") + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: SambaYDecoderLayer(config, + int(prefix.split('.')[-1]), + cache_config, + prefix=prefix), + prefix=f"{prefix}.layers") + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + mamba_state_idx = 0 + ssm_output = None + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + if i == self.config.num_hidden_layers // 2 + 2: + # profile run: the kv cache is not allocated yet, so stop here + kv_cache_idx = self.config.num_hidden_layers // 2 + 1 + cache_layer = self.layers[kv_cache_idx] + kv_cache = cache_layer.attn.attn.kv_cache + if kv_cache[0].numel() == 0: + break + + # Starting from this layer, we do not need to calculate + # the kv cache since we reuse the kv cache from the last + # layer that computes one (num_hidden_layers // 2 + 1). + # If in the prefill phase, we can prune the hidden states + # to the last token of each sequence to save computation cost. + if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1: + selected_token_indices = torch.cumsum( + attn_metadata.seq_lens_tensor, dim=0) - 1 + hidden_states = hidden_states.index_select( + 0, selected_token_indices) + ssm_output = ssm_output.index_select( + 0, selected_token_indices) + + if layer.use_mamba: + if i < self.config.num_hidden_layers // 2 or \ + not layer.yoco_cross: + mamba_cache = mamba_cache_params.at_layer_idx( + mamba_state_idx) + mamba_state_idx += 1 + else: + mamba_cache = mamba_cache_params.at_layer_idx( + mamba_state_idx - 1) + + hidden_states, ssm_output = layer(hidden_states, + positions, + attn_metadata, + mamba_cache, + ssm_output=ssm_output) + else: + hidden_states, ssm_output = layer( + hidden_states, + positions, + attn_metadata, + None, # mamba_cache_params + ssm_output=ssm_output) + + hidden_states = self.final_layernorm( + hidden_states.to(dtype=self.final_layernorm.weight.dtype)) + return hidden_states + + +class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + quant_config = vllm_config.quant_config + scheduler_config = vllm_config.scheduler_config + self.compilation_config = vllm_config.compilation_config + self.vllm_config = vllm_config + # Prefix caching and chunked prefill are not supported for this model.
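+ # (The Mamba layers keep their per-request conv/SSM state in the
+ # MambaCacheManager rather than in the block-based KV cache, and the YOCO
+ # path in SambaYModel.forward prunes prefill hidden states down to the
+ # last token of each sequence, which assumes the whole prompt is processed
+ # in a single prefill pass; this is presumably why both features are
+ # disabled here.)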
+ assert not cache_config.enable_prefix_caching, \ + "Phi4Flash currently does not support prefix caching" + assert not scheduler_config.chunked_prefill_enabled, \ + "Phi4Flash currently does not support chunked prefill" + super().__init__() + self.config = config + self.model_config = vllm_config.model_config + self.scheduler_config = scheduler_config + self.model = SambaYModel(config, + cache_config=cache_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size), + quant_config=quant_config, + ) + self.embedding_bias = None + # Used to track and store the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logits_as_input=False) + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if self.mamba_cache is None: + num_mamba_layers = self.config.num_hidden_layers \ + // 2 // self.config.mb_per_layer + 1 + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + attn_metadata = get_forward_context().attn_metadata + # input_ids and hidden_states aren't a one-to-one mapping in the + # prefill stage due to the YOCO optimization. + hidden_states = self.model(input_ids, positions, attn_metadata, + mamba_cache_params, intermediate_tensors, + inputs_embeds) + return hidden_states + + def _get_mamba_cache_shape( + self + ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + mamba_expand = self.config.mamba_expand # 2 + mamba_d_conv = self.config.mamba_d_conv # 4 + mamba_d_state = self.config.mamba_d_state # 16 + conv_state_shape = ( + mamba_expand * hidden_size // world_size, + mamba_d_conv - 1, + ) + temporal_state_shape = ( + mamba_expand * hidden_size // world_size, + mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + # If the number of hidden states already matches the number of + # selected tokens, they have already been pruned manually.
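+ # (That manual pruning happens in SambaYModel.forward, which gathers the
+ # last token of each sequence once the YOCO cross-layers are reached, so
+ # the logits processor should not gather those positions a second time.)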
+ prune_hidden_states = hidden_states.size( + 0) != sampling_metadata.selected_token_indices.size(0) + processed_logits = self.logits_processor( + self.lm_head, + hidden_states, + sampling_metadata, + self.embedding_bias, + prune_hidden_states=prune_hidden_states) + return processed_logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ): + weights = {name: weight for name, weight in weights} + adjusted_weights = {} + for name, weight in weights.items(): + if "A_log" in name: + name = name.replace("A_log", "A") + weight = -torch.exp(weight.float()) + if "inner_cross_attn." in name: + name = name.replace("inner_cross_attn.", "") + adjusted_weights[name] = weight + adjusted_weights["lm_head.weight"] = weights[ + "model.embed_tokens.weight"] + loaded_params: set[str] = set() + for name, param in self.named_parameters(): + weight = adjusted_weights.get(name) + if weight is not None and weight.shape != param.shape: + logger.warning("Shape mismatch: %s %s %s", name, weight.shape, + param.shape) + loaded_params.add(name) + missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, + strict=False) + assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}" + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 17d44fa71d5..5f9b145b661 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -110,6 +110,7 @@ "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 00151296a75..878f8f77edf 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -316,6 +316,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using DualChunkFlashAttention backend.") return ("vllm.attention.backends.dual_chunk_flash_attn." "DualChunkFlashAttentionBackend") + elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN: + logger.info("Using DifferentialFlashAttention backend.") + return ("vllm.attention.backends.differential_flash_attn." 
+ "DifferentialFlashAttentionBackend") elif selected_backend == _Backend.FLASH_ATTN: pass elif selected_backend: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d3060685e98..ae675bcc8d2 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -60,6 +60,7 @@ class _Backend(enum.Enum): IPEX = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto() DUAL_CHUNK_FLASH_ATTN = enum.auto() + DIFFERENTIAL_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 48346c7d6e5..495e359aa6d 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2888,8 +2888,9 @@ def get_mp_context(): def bind_kv_cache( - ctx: dict[str, Any], - kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] + ctx: dict[str, Any], + kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] + shared_kv_cache_layers: Optional[dict[str, str]] = None ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -2901,12 +2902,17 @@ def bind_kv_cache( # attention of the same layer (e.g., bart's decoder.layers.1.self_attn # and decoder.layers.1.encoder_attn) is mapped to the same kv cache # tensor + # 5. Some models have attention layers that share kv cache with previous + # layers, this is specified through shared_kv_cache_layers + if shared_kv_cache_layers is None: + shared_kv_cache_layers = {} from vllm.attention import AttentionType from vllm.model_executor.models.utils import extract_layer_index layer_need_kv_cache = [ layer_name for layer_name in ctx if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type - in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) + in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \ + and ctx[layer_name].kv_sharing_target_layer_name is None ] layer_index_sorted = sorted( set( @@ -2919,6 +2925,12 @@ def bind_kv_cache( assert len(forward_ctx.kv_cache) == len(kv_cache) for ve, ve_kv_cache in enumerate(kv_cache): forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] + if shared_kv_cache_layers is not None: + for layer_name, target_layer_name in shared_kv_cache_layers.items(): + assert extract_layer_index(target_layer_name) < \ + extract_layer_index(layer_name), \ + "v0 doesn't support interleaving kv sharing" + ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9d936f3dbf0..ab926c2d33b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1112,6 +1112,10 @@ def __init__( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) + self.cross_layer_shared_graph_block_tables = np.zeros( + (self.max_batchsize_to_capture, self.get_max_block_per_batch()), + dtype=np.int32) + # Attention-free but stateful models like Mamba need a placeholder attn # backend, as the attention metadata is needed to manage internal state. 
# However we must bypass attention selection altogether for some models diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 21e684a3fb5..b2926dbd185 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -9,7 +9,8 @@ import torch.distributed import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.attention.layer import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, @@ -345,8 +346,29 @@ def _init_cache_engine(self): self.cache_engine[ve].gpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) ] + + # Layer pairings for cross-layer KV sharing. + # If an Attention layer `layer_name` is in the keys of this dict, it + # means this layer will perform attention using the keys and values + # from the KV cache of `shared_kv_cache_layers[layer_name]`. + shared_kv_cache_layers: dict[str, str] = {} + + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + + for layer_name, attn_module in attn_layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + shared_kv_cache_layers[layer_name] = kv_tgt_layer + bind_kv_cache(self.compilation_config.static_forward_context, - self.gpu_cache) + self.gpu_cache, shared_kv_cache_layers) def _warm_up_model(self) -> None: # warm up sizes that are not in cudagraph capture sizes,