From b7f8e0c0acdcfca316a7ab5643b6e3bbd6d13e24 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Wed, 4 Jun 2025 22:36:59 +0000 Subject: [PATCH 01/24] initial commit Signed-off-by: Congcong Chen --- benchmarks/benchmark_prefix_caching.py | 37 +- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 49 +- vllm/attention/backends/abstract.py | 1 + vllm/attention/backends/flash_attn.py | 50 +- vllm/config.py | 8 + .../model_executor/layers/logits_processor.py | 3 +- vllm/model_executor/models/phi3samba.py | 1019 +++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/worker/model_runner.py | 4 + 9 files changed, 1127 insertions(+), 45 deletions(-) create mode 100644 vllm/model_executor/models/phi3samba.py diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index b5e2613de1c..0b7d23a9bd7 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -45,13 +45,24 @@ except ImportError: from backend_request_func import get_tokenizer -PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 - +# PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. 
Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 +# PROMPT = "Question: Who is bill gates?\n\nAnswer:" +# content = """China officially the People's Republic of China (PRC), is a country in East Asia. With a population exceeding 1.4 billion, it is the second-most populous country after India, representing 17.4% of the world population. China spans the equivalent of five time zones and borders fourteen countries by land[k] across an area of nearly 9.6 million square kilometers (3,700,000 sq mi), making it the third-largest country by total land area.[l] The country is divided into 33 province-level divisions: 22 provinces,[m] five autonomous regions, four municipalities, and two semi-autonomous special administrative regions. Beijing is the country's capital, while Shanghai is its most populous city by urban area and largest financial center. China is considered one of the cradles of civilization: the first human inhabitants in the region arrived during the Paleolithic. By the late 2nd millennium BCE, the earliest dynastic states had emerged in the Yellow River basin. The 8th–3rd centuries BCE saw a breakdown in the authority of the Zhou dynasty, accompanied by the emergence of administrative and military techniques, literature, philosophy, and historiography. In 221 BCE, China was unified under an emperor, ushering in more than two millennia of imperial dynasties including the Qin, Han, Tang, Yuan, Ming, and Qing. With the invention of gunpowder and paper, the establishment of the Silk Road, and the building of the Great Wall, Chinese culture flourished and has heavily influenced both its neighbors and lands further afield. However, China began to cede parts of the country in the late 19th century to various European powers by a series of unequal treaties. After decades of Qing China on the decline, the 1911 Revolution overthrew the Qing dynasty and the monarchy and the Republic of China (ROC) was established the following year. The country under the nascent Beiyang government was unstable and ultimately fragmented during the Warlord Era, which was ended upon the Northern Expedition conducted by the Kuomintang (KMT) to reunify the country. The Chinese Civil War began in 1927, when KMT forces purged members of the rival Chinese Communist Party (CCP), who proceeded to engage in sporadic fighting against the KMT-led Nationalist government. Following the country's invasion by the Empire of Japan in 1937, the CCP and KMT formed the Second United Front to fight the Japanese. The Second Sino-Japanese War eventually ended in a Chinese victory; however, the CCP and the KMT resumed their civil war as soon as the war ended. In 1949, the resurgent Communists established control over most of the country, proclaiming the People's Republic of China and forcing the Nationalist government to retreat to the island of Taiwan. The country was split, with both sides claiming to be the sole legitimate government of China. 
Following the implementation of land reforms, further attempts by the PRC to realize communism failed: the Great Leap Forward was largely responsible for the Great Chinese Famine that ended with millions of Chinese people having died, and the subsequent Cultural Revolution was a period of social turmoil and persecution characterized by Maoist populism. Following the Sino-Soviet split, the Shanghai Communiqué in 1972 would precipitate the normalization of relations with the United States. Economic reforms that began in 1978 moved the country away from a socialist planned economy towards an increasingly capitalist market economy, spurring significant economic growth. A movement for increased democracy and liberalization stalled after the Tiananmen Square protests and massacre in 1989.""" +# PROMPT = f'{content}' +PROMPT = "Question: Tell me about Seatttle?\n\nAnswer:" def test_prefix(llm=None, sampling_params=None, prompts=None): start_time = time.time() - llm.generate(prompts, sampling_params=sampling_params) + # llm.generate(prompts, sampling_params=sampling_params) + outputs = llm.generate(prompts, sampling_params=sampling_params) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") end_time = time.time() print(f"cost time {end_time - start_time}") @@ -136,16 +147,16 @@ def sample_requests_from_random( min_len, max_len = input_length_range for i in range(num_requests): - unique_part_token_ids = sample_tokens( - tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len) - ) - prompt_token_ids = prefix_token_ids + unique_part_token_ids - prompt = tokenizer.decode(prompt_token_ids) - prompt_len = len(prompt_token_ids) - assert min_len <= prompt_len <= max_len, ( - f"prompt_len {prompt_len} out of range {min_len}:{max_len}" - ) - requests.append(Request(prompt, prompt_len, fixed_output_len)) + # unique_part_token_ids = sample_tokens( + # tokenizer, + # random.randint(min_len - prefix_len, max_len - prefix_len)) + # prompt_token_ids = prefix_token_ids + unique_part_token_ids + # prompt = tokenizer.decode(prompt_token_ids) + # prompt_len = len(prompt_token_ids) + # assert (min_len <= prompt_len <= max_len + # ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + + requests.append(Request(PROMPT, 10, fixed_output_len)) return requests diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 785d316025e..5f920997934 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size constexpr bool kIsVariableB = true; constexpr bool kIsVariableC = true; - constexpr bool kHasZ = true; BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - dim3 grid(params.batch, params.dim / kNRows); - auto kernel = &selective_scan_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - 
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
+      BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
+        using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
+        constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+        dim3 grid(params.batch, params.dim / kNRows);
+        auto kernel = &selective_scan_fwd_kernel<Ktraits>;
+        if (kSmemSize >= 48 * 1024) {
+          C10_CUDA_CHECK(cudaFuncSetAttribute(
+              kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+        }
+        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      });
     });
   });
 }
@@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
   at::Tensor z, out_z;
   const bool has_z = z_.has_value();
-  TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
-  z = z_.value();
-  TORCH_CHECK(z.scalar_type() == input_type);
-  TORCH_CHECK(z.is_cuda());
-  TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
-  if (varlen){
-    CHECK_SHAPE(z, dim, seqlen);
-  } else {
-    CHECK_SHAPE(z, batch_size, dim, seqlen);
+  if (has_z) {
+    z = z_.value();
+    TORCH_CHECK(z.scalar_type() == input_type);
+    TORCH_CHECK(z.is_cuda());
+    TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
+    if (varlen){
+      CHECK_SHAPE(z, dim, seqlen);
+    } else {
+      CHECK_SHAPE(z, batch_size, dim, seqlen);
+    }
+
+    out_z = z;
   }
-  out_z = z;
-
   // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
   at::Tensor out = delta;
   TORCH_CHECK(ssm_states.scalar_type() == input_type);
@@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
     selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
   });
 }
-
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 05c098a58a0..48af715edc7 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -32,6 +32,7 @@ class AttentionType:
     ENCODER_ONLY = "encoder_only"
     # Attention between dec. Q and enc.
K/V for encoder-decoder ENCODER_DECODER = "encoder_decoder" + DECODER_DECODER = "decoder_decoder" # Attention layer that reuse kv cache class AttentionBackend(ABC): diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bf8e373802f..410377f9a97 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -75,7 +75,8 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + # return (2, num_blocks, block_size, num_kv_heads, head_size) + return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) @staticmethod def swap_blocks( @@ -185,6 +186,9 @@ class FlashAttentionMetadata(AttentionMetadata): cross_slot_mapping: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None + # Cross-layer shared attention block tables + cross_layer_shared_block_tables: Optional[torch.Tensor] = None + @property def is_all_encoder_attn_metadata_set(self): ''' @@ -229,7 +233,9 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: self.context_lens_tensor[:self.num_prefills]) block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) - + cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[:self.num_prefills]) + self._cached_prefill_metadata = FlashAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, @@ -248,6 +254,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, use_cuda_graph=False, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, @@ -275,7 +282,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: self.seq_lens_tensor[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) - + cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[self.num_prefills:]) self._cached_decode_metadata = FlashAttentionMetadata( num_prefills=0, num_prefill_tokens=0, @@ -299,6 +307,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, use_cuda_graph=self.use_cuda_graph, # Begin encoder & cross attn fields below... 
encoder_seq_lens=self.encoder_seq_lens, @@ -397,6 +406,7 @@ def prepare(self): self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] + self.cross_layer_shared_block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] self.multimodal_placeholder_maps: Dict[ str, @@ -457,6 +467,17 @@ def _add_seq_group( -curr_sliding_window_block:] self.block_tables.append(block_table) + cross_layer_shared_block_table = [] + if prefix_cache_hit: + cross_layer_shared_block_table = block_tables[seq_id] + elif block_tables is not None: + if curr_sliding_window_block == 0: + cross_layer_shared_block_table = block_tables[seq_id] + else: + cross_layer_shared_block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.cross_layer_shared_block_tables.append(cross_layer_shared_block_table) + # Compute slot mapping. is_profile_run = is_block_tables_empty(block_tables) start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, @@ -468,13 +489,16 @@ def _add_seq_group( def _get_graph_runner_block_tables( self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: + block_tables: List[List[int]], + graph_block_tables) -> torch.Tensor: # The shape of graph_block_tables is # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape + # max_batch_size, max_blocks = self.runner.graph_block_tables.shape + max_batch_size, max_blocks = graph_block_tables.shape assert max_batch_size >= num_seqs - graph_block_tables = self.runner.graph_block_tables[:num_seqs] + # graph_block_tables = self.runner.graph_block_tables[:num_seqs] + graph_block_tables = graph_block_tables[:num_seqs] for i, block_table in enumerate(block_tables): if block_table: num_blocks = len(block_table) @@ -529,9 +553,14 @@ def build(self, seq_lens: List[int], query_lens: List[int], if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size) + + self.cross_layer_shared_block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) + num_seqs, self.block_tables, self.runner.graph_block_tables) + cross_layer_shared_block_tables = self._get_graph_runner_block_tables( + num_seqs, self.cross_layer_shared_block_tables, self.runner.cross_layer_shared_graph_block_tables) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -539,6 +568,12 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int, device=device, ) + cross_layer_shared_block_tables = make_tensor_with_pad( + self.cross_layer_shared_block_tables, + pad=0, + dtype=torch.int, + device=device, + ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert device is not None @@ -576,6 +611,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, use_cuda_graph=use_captured_graph, ) diff --git a/vllm/config.py b/vllm/config.py index b1f7f9e57a7..3eda6b85881 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -989,6 +989,14 @@ def _verify_cuda_graph(self) -> None: "to eager mode.", self.hf_config.model_type) self.enforce_eager = True + RECOMMENDED_MODEL_SUPPORTS_CUDA_GRAPH = ['phi3samba'] + if (self.hf_config.model_type in 
RECOMMENDED_MODEL_SUPPORTS_CUDA_GRAPH + and not self.enforce_eager and self.max_seq_len_to_capture < self.max_model_len): + logger.warning( + "%s model performs best with the CUDA graph explicitly enabled. Set `--max-seq-len-to-capture <#>` " + "when starting vLLM.", self.hf_config.model_type) + + def _verify_bnb_config(self) -> None: """ The current version of bitsandbytes (0.46.1) with 8-bit models does not diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 3d01253447c..e93be9bfb16 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -59,11 +59,12 @@ def forward( hidden_states: torch.Tensor, sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, + prune_hidden_states: bool = True, ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None: + if sampling_metadata is not None and prune_hidden_states: hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) diff --git a/vllm/model_executor/models/phi3samba.py b/vllm/model_executor/models/phi3samba.py new file mode 100644 index 00000000000..7ca88ee865f --- /dev/null +++ b/vllm/model_executor/models/phi3samba.py @@ -0,0 +1,1019 @@ +from typing import List, Optional, Tuple, Union, Iterable, Dict +import math +import copy + +import torch +import torch.nn as nn + +from einops import rearrange +from transformers.activations import ACT2FN +from typing import Iterable, List, Optional, Set, Tuple, Union + +from vllm.config import CacheConfig, VllmConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + RowParallelLinear, + ColumnParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.models.interfaces import (HasInnerState, + IsHybrid, SupportsV0Only) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) + +from vllm.logger import init_logger +from .utils import (maybe_prefix, make_layers) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + +logger = init_logger(__name__) + + +class SwiGLUActivation(nn.Module): + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + # print(f"x1 shape: {x1.shape}, x2 
shape: {x2.shape}") + return x1 * nn.functional.silu(x2) + +class SambaMLP(nn.Module): + """Gated Linear Unit. + + Reference: + Language Modeling with Gated Convolutional Networks. + https://arxiv.org/pdf/1612.08083v3.pdf. + + """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.fc1 = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + y = self.fc1(hidden_states) + gate, y = y.chunk(2, dim=-1) + y = y * self.activation_fn(gate) + return self.fc2(y) + + +class SambaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, + config, + layer_idx: Optional[int] = None, + yoco_cross: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.yoco_cross = yoco_cross + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) + self.out_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) + if yoco_cross: + self.Wqkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + else: + self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True) + + assert self.config.attention_dropout == 0.0, 'Attention dropout is not supported for now' + + # disable sliding window for the second half of the model + sliding_window = config.interleaved_sliding_window[layer_idx] + if layer_idx >= config.num_hidden_layers // 2 or layer_idx % 2 == 0: + assert sliding_window == None, "sliding_window is not none" + + assert self.num_heads % 2 == 0, 'num_heads should be even' + assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' + + self.attn = Attention( + self.num_heads//2, + self.head_dim, + self.head_dim**-0.5, + num_kv_heads=self.num_key_value_heads//2, + cache_config=cache_config, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + attn_type=AttentionType.DECODER_DECODER if self.yoco_cross else AttentionType.DECODER + ) + self.subln = nn.RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True) + + self.lambda_init = self.lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + + def lambda_init_fn(self, depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + + def split_heads(self, x): + # split by num_heads, the stripe pattern is friendly to tensor parallel. + x = rearrange(x, "... (H two) D -> ... H two D", two=2) + x1 = x[..., 0, :] + x2 = x[..., 1, :] + return x1.contiguous(), x2.contiguous() + + def split_kv_cache(self, x): + # split by num_heads, the stripe pattern is friendly to tensor parallel. 
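+        # `x` is stacked along its leading dimension: for the full per-layer cache
+        # of shape (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size),
+        # x[0] / x[1] are the two head-halves used by the differential attention;
+        # for a single half, x[0] / x[1] are its key and value planes. An empty
+        # tensor (profile run) simply yields two empty halves.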
+ if x.numel() == 0: + return torch.empty(0), torch.empty(0) + + x1, x2 = x[0], x[1] + return x1, x2 + + def forward_decode( + self, + query: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + if not attn_metadata.decode_metadata: + block_tables_arg = attn_metadata.cross_layer_shared_block_tables + else: + block_tables_arg = attn_metadata.block_tables + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_tables_arg, + cache_seqlens=attn_metadata.seq_lens_tensor, + softmax_scale=self.attn.impl.scale, + causal=True, + window_size=self.attn.impl.sliding_window, + alibi_slopes=self.attn.impl.alibi_slopes, + softcap=self.attn.impl.logits_soft_cap, + ).squeeze(1) + return output + + def populate_kv_cache(self, + key, + value, + kv_cache, + attn_metadata): + if (kv_cache.numel() > 0): + if (key is not None) and (value is not None): + updated_slot_mapping = attn_metadata.slot_mapping + # previous_key_cache_sum = key_cache.sum() + # previous_value_cache_sum = value_cache.sum() + + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), + self.attn.impl.kv_cache_dtype, + self._k_scale, + self._v_scale, + ) + # assert key_cache.sum() - previous_key_cache_sum == key.sum(), "key_cache sum mismatch" + # assert value_cache.sum() - previous_value_cache_sum == value.sum(), "value_cache sum mismatch" + # if key_cache.sum() - previous_key_cache_sum != key.sum(): + # print("key_cache sum mismatch") + # if value_cache.sum() - previous_value_cache_sum != value.sum(): + # print("value_cache sum mismatch") + + def forward_customized( + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: AttentionMetadata + ) -> torch.Tensor: + + head_size = self.head_dim + num_heads = self.num_heads // 2 + num_kv_heads = self.num_key_value_heads // 2 + + query = query.view(-1, num_heads, head_size) + if key is not None: + assert value is not None + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + else: + assert value is None + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + if key is not None and value is not None: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens, "query shape mismatch" + assert decode_query.shape[0] == num_decode_tokens, "decode query shape mismatch" + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. 
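+            # `k_cache` is empty during the initial memory-profiling run, and
+            # `block_tables` is empty when no previously cached blocks are reused,
+            # so plain varlen flash attention over the freshly computed K/V is
+            # sufficient here; prefill over cached prefix blocks is rejected below.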
+ if k_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # normal attention + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.attn.impl.scale, + causal=True, + window_size=self.attn.impl.sliding_window, + alibi_slopes=self.attn.impl.alibi_slopes, + softcap=self.attn.impl.logits_soft_cap, + ) + assert prefill_output.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = prefill_output + else: + raise Exception("prefix caching not supported") + + if decode_meta := attn_metadata.decode_metadata: + block_tables_arg = decode_meta.block_tables + try: + output[num_prefill_tokens:] = flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_tables_arg, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.attn.impl.scale, + causal=True, + window_size=self.attn.impl.sliding_window, + alibi_slopes=self.attn.impl.alibi_slopes, + softcap=self.attn.impl.logits_soft_cap, + ).squeeze(1) + except Exception as e: + logger.error( + f"Error in PagedAttention.forward_decode: {str(e)}") + raise e + + # Reshape the output tensor. + return output.view(-1, num_heads, head_size) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + + if not self.yoco_cross: # need to generate kv-cache + qkv = self.Wqkv(hidden_states) + q, k, v = qkv.split([self.hidden_size, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], dim=-1) + # q, k = self.rotary_emb(positions, q, k) + # reshape + q = q.view(-1, self.num_heads, self.head_dim) + k = k.view(-1, self.num_key_value_heads, self.head_dim) + v = v.view(-1, self.num_key_value_heads, self.head_dim) + + q1, q2 = self.split_heads(q) + k1, k2 = self.split_heads(k) + v1, v2 = self.split_heads(v) + + # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size) + # Split by half along the first dimension. 
+ kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" + assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" + + if kv_cache1.numel() != 0: + self.populate_kv_cache(k1, v1, kv_cache1, attn_metadata) + self.populate_kv_cache(k2, v2, kv_cache2, attn_metadata) + + key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) + key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) + else: + key_cache1, value_cache1 = torch.empty(0), torch.empty(0) + key_cache2, value_cache2 = torch.empty(0), torch.empty(0) + attn11 = self.forward_customized(q1, k1, v1, key_cache1, value_cache1, attn_metadata) + attn12 = self.forward_customized(q1, k1, v2, key_cache1, value_cache2, attn_metadata) + attn11 = attn11.view(q1.shape) + attn12 = attn12.view(q1.shape) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = self.forward_customized(q2, k2, v1, key_cache2, value_cache1, attn_metadata) + attn22 = self.forward_customized(q2, k2, v2, key_cache2, value_cache2, attn_metadata) + attn21 = attn21.view(q2.shape) + attn22 = attn22.view(q2.shape) + attn2 = torch.cat([attn21, attn22], dim=-1) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + attn = attn1 - lambda_full * attn2 + # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + # reshape back to 2 * num_head + attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2) + + else: # re-use the kv cache, full attention + q = self.Wqkv(hidden_states) + q = q.view(-1, self.num_heads, self.head_dim) + q1, q2 = self.split_heads(q) + # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size) + kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] + key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] + + attn11 = self.forward_decode(q1, key_cache1, value_cache1, attn_metadata) + attn12 = self.forward_decode(q1, key_cache1, value_cache2, attn_metadata) + attn11 = attn11.view(q1.shape) + attn12 = attn12.view(q1.shape) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = self.forward_decode(q2, key_cache2, value_cache1, attn_metadata) + attn22 = self.forward_decode(q2, key_cache2, value_cache2, attn_metadata) + attn21 = attn21.view(q2.shape) + attn22 = attn22.view(q2.shape) + attn2 = torch.cat([attn21, attn22], dim=-1) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + attn = attn1 - lambda_full * attn2 + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + # reshape back to 2 * num_head + attn_output = rearrange(attn, "... H (two D) -> ... 
(H two) D", two=2) + attn_output = attn_output.view(-1, self.num_heads * self.head_dim) + return self.out_proj(attn_output) + + +class Phi3Mamba(nn.Module): + def __init__( + self, + d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", # difference + dt_scale=1.0, # difference + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, # Fused kernel options + layer_idx=None, + device=None, + dtype=None, + yoco_cross=False, + yoco_kv=False, + ): + factory_kwargs = {"params_dtype": dtype} # difference + super().__init__() + self.yoco_cross = yoco_cross + self.yoco_kv = yoco_kv + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + self.swiGluActivation = SwiGLUActivation() + if self.yoco_cross: + self.in_proj = MergedColumnParallelLinear(self.d_model, [self.d_inner], bias=bias, **factory_kwargs) + self.out_proj = RowParallelLinear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + return + # self.conv1d = nn.Conv1d( + # in_channels=self.d_inner, + # out_channels=self.d_inner, + # bias=conv_bias, + # kernel_size=d_conv, + # groups=self.d_inner, + # padding=d_conv - 1, + # **factory_kwargs, + # ) + + self.conv1d = ColumnParallelLinear( + input_size=d_conv, + output_size=self.d_inner, + bias=conv_bias, + params_dtype=dtype, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + self.in_proj = MergedColumnParallelLinear(self.d_model, + [self.d_inner] * 2, + bias=bias, + params_dtype=dtype, + ) + + # self.x_proj = nn.Linear( + # self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + # ) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.d_inner, + self.dt_rank + self.d_state * 2, + bias=False, + params_dtype=dtype, + ) + + # self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. 
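+        # (`skip_bias_add=True` keeps the bias out of the matmul; the forward pass
+        #  below reads it back as `time_proj_bias` and hands it to the scan kernel.)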
+ self.dt_proj = ColumnParallelLinear(self.dt_rank, + self.d_inner, + bias=True, + skip_bias_add=True, + params_dtype=dtype, + ) + + # # S4D real initialization + # A = repeat( + # torch.arange(1, self.d_state + 1, dtype=torch.float32), + # "n -> d n", + # d=self.d_inner, + # ).contiguous() + # A_log = torch.log(A) # Keep A_log in fp32 + # self.A_log = nn.Parameter(A_log) + + # # D "skip" parameter + # self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 + self.A = nn.Parameter( + torch.empty( + self.d_inner, + self.d_state, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32)) + + # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.out_proj = RowParallelLinear( + self.d_inner, + self.d_model, + bias=bias, + input_is_parallel=True, + params_dtype=dtype, + ) + print(f"-------- layer_idx {layer_idx}") + self.activation = "silu" + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + yoco_key_values = None + ) -> torch.Tensor: + + if self.yoco_cross: + out = self.in_proj(hidden_states)[0] + out = self.swiGluActivation(yoco_key_values, out) + out = self.out_proj(out) + return out[0], yoco_key_values + + # 1. Gated MLP's linear projection + # projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + projected_states = self.in_proj(hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.dt_rank, self.d_state, self.d_state], + dim=-1, + ) + + # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. 
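For readers who have not seen the fused Mamba kernels before: per channel, `selective_scan_fn` / `selective_state_update` compute roughly the recurrence sketched below. This is an illustrative, unbatched reference in plain PyTorch, not part of the patch; the function name and tensor layouts are simplified assumptions and only loosely follow this file.

```python
import torch
import torch.nn.functional as F

def selective_scan_reference(u, delta, A, B, C, D, z=None, delta_bias=None):
    """u, delta: (dim, L); A: (dim, d_state); B, C: (L, d_state); D: (dim,)."""
    if delta_bias is not None:
        delta = delta + delta_bias[:, None]
    delta = F.softplus(delta)                     # delta_softplus=True in the kernels
    dim, L = u.shape
    ssm_state = u.new_zeros(dim, A.shape[1])      # state carried across decode steps
    ys = []
    for t in range(L):
        dA = torch.exp(delta[:, t, None] * A)     # discretized state transition
        dBu = delta[:, t, None] * B[t] * u[:, t, None]
        ssm_state = dA * ssm_state + dBu          # recurrent state update
        ys.append((ssm_state * C[t]).sum(-1) + D * u[:, t])
    y = torch.stack(ys, dim=-1)
    if z is not None:                             # optional SiLU gating (the `gate` tensor)
        y = y * F.silu(z)
    return y, ssm_state
```

The fused kernels additionally handle the batched/varlen layout, cache indexing, and the initial-state logic in a single pass.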
+ + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + mamba_cache_params.ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + # z, + None if self.yoco_kv else gate, + time_proj_bias, + delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + # z + # gate.transpose(0, 1), + None if self.yoco_kv else gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. Final linear projection + if self.yoco_kv: + # gate = gate.transpose(-1,-2).contiguous() + yoco_key_values = scan_outputs.transpose(-2, -1) + scan_outputs = self.swiGluActivation(scan_outputs, gate) + + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + + return contextualized_states, yoco_key_values + + +class SambaDecoderLayer(nn.Module): + + def __init__(self, + config, + layer_idx, + cache_config, + prefix: str = "",) -> None: + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.mlp = SambaMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.yoco_mb = False + self.yoco_kv = False + self.yoco_cross = False + assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for samba + yoco' + if layer_idx >= config.num_hidden_layers//2: + self.yoco_mb = True + self.yoco_kv = (layer_idx >= (config.num_hidden_layers//2 +1)) + self.yoco_cross = (layer_idx >= (config.num_hidden_layers//2 +2)) + self.use_mamba = config.mb_per_layer > 0 and layer_idx % config.mb_per_layer == 0 + if self.use_mamba: + factory_kwargs = {"dtype": None} + self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx, + yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs) + else: + self.attn = SambaAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn") + + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + ssm_output: Optional[torch.LongTensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if self.use_mamba: + assert kv_cache is None and mamba_cache_params is not None + else: + assert kv_cache is not None and mamba_cache_params is None + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states.to(dtype=self.input_layernorm.weight.dtype)) + + if self.use_mamba: + attn_outputs, ssm_output = self.attn( + hidden_states, + attn_metadata, + mamba_cache_params, + yoco_key_values = ssm_output + ) + 
residual = residual.to(torch.float32) + else: + attn_outputs = self.attn( + hidden_states, + positions, + kv_cache, + attn_metadata, + ) + try: + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + except Exception as e: + print('>>> exception: ', e) + print('>>>', hidden_states.shape) + print('>>>', self.layer_idx) + print('>>>', residual.shape) + print('>>>', self.resid_attn_dropout) + print('>>>', attn_outputs) + raise + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype)) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + return hidden_states, ssm_output + +def get_kv_cache(layer_name): + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + return kv_cache + +class SambaModel(nn.Module): + + def __init__( + self, + config, + cache_config = None, + quant_config = None, + lora_config = None, + prefix: str = "" + ) -> None: + super().__init__() + + self.config = config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + # Pipeline parallel is not supported since the second half of the layers share the kv cache. + if get_pp_group().world_size != 1: + raise ValueError("Pipeline Parallel not supported") + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: SambaDecoderLayer(config, + int(prefix.split('.')[-1]), + cache_config, + prefix=prefix), + prefix=f"{prefix}.layers") + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + kv_cache_idx = 0 + mamba_state_idx = 0 + ssm_output = None + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + if i == self.config.num_hidden_layers // 2 + 2: + # profile run + cache_layer = self.layers[kv_cache_idx] + kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name) + if kv_cache.numel() == 0: + break + + # Starting from this layer, we do not need to cuculate the kv cache since we reuse + # the kv cache from last layer. If in prefill phase, we can prune truncate + # hidden state to save computation cost. 
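+                #   Example: for two prefill sequences of lengths [3, 2] the packed
+                #   rows are [s0t0, s0t1, s0t2, s1t0, s1t1]; cumsum([3, 2]) - 1 =
+                #   [2, 4] keeps only the last token of each sequence, which is all
+                #   the remaining KV-cache-reusing layers need for next-token logits.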
+ if attn_metadata.prefill_metadata: + selected_token_indices = torch.cumsum(attn_metadata.seq_lens_tensor, dim=0) - 1 + hidden_states = hidden_states.index_select(0, selected_token_indices) + ssm_output = ssm_output.index_select(0, selected_token_indices) + + + # start_env = torch.cuda.Event(enable_timing=True) + # end_env = torch.cuda.Event(enable_timing=True) + # start_env.record() + if layer.use_mamba: + if i < self.config.num_hidden_layers // 2: + mamba_cache = mamba_cache_params.at_layer_idx(mamba_state_idx) + mamba_state_idx += 1 + elif not layer.yoco_cross: + mamba_cache = mamba_cache_params.at_layer_idx(mamba_state_idx) + mamba_state_idx += 1 + else: + mamba_cache = mamba_cache_params.at_layer_idx(mamba_state_idx-1) + + hidden_states, ssm_output = layer( + hidden_states, + positions, + None, # kv_cache + attn_metadata, + mamba_cache, + ssm_output = ssm_output + ) + else: + if i < self.config.num_hidden_layers // 2: + # sliding window attention + cache_layer = self.layers[i] + kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name) + kv_cache_idx = i + elif not layer.yoco_cross: + # full attention that generates kv cache + cache_layer = self.layers[i] + kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name) + kv_cache_idx = i + else: + # full attention that reuses kv cache + cache_layer = self.layers[kv_cache_idx] + kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name) + + hidden_states, ssm_output = layer( + hidden_states, + positions, + kv_cache, + attn_metadata, + None, # mamba_cache_params + ssm_output = ssm_output + ) + # end_env.record() + # torch.cuda.synchronize() + # print('>>> layer', i, 'time', start_env.elapsed_time(end_env)) + + hidden_states = self.final_layernorm(hidden_states.to(dtype=self.final_layernorm.weight.dtype)) + return hidden_states + + +class SambaForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + quant_config = vllm_config.quant_config + scheduler_config = vllm_config.scheduler_config + self.compilation_config = vllm_config.compilation_config + self.vllm_config = vllm_config + # Prefix caching is not supported since there are mamba layers in this + # mode. + assert not cache_config.enable_prefix_caching, \ + "Samba currently does not support prefix caching" + + super().__init__() + self.config = config + self.model_config = vllm_config.model_config + self.scheduler_config = scheduler_config + self.model = SambaModel( + config, + cache_config=cache_config, + prefix=maybe_prefix(prefix, "model") + ) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + ) + self.embedding_bias = None + # Used to track and store by the Mamba cache between steps. 
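+        # Allocated lazily on the first forward() call. The first half of the
+        # network has one Mamba layer every `mb_per_layer` layers, and the second
+        # (YOCO) half adds one more self-decoder Mamba state that the remaining
+        # cross-decoder layers reuse, hence
+        # num_hidden_layers // 2 // mb_per_layer + 1 cache entries.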
+ self.mamba_cache: Optional[MambaCacheManager] = None + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logits_as_input=False) + # self.sampler = Sampler() + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if self.mamba_cache is None: + num_mamba_layers = self.config.num_hidden_layers // 2 // self.config.mb_per_layer + 1 + self.mamba_cache = MambaCacheManager( + self.vllm_config, + self.lm_head.weight.dtype, num_mamba_layers, *self._get_mamba_cache_shape() + ) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + attn_metadata = get_forward_context().attn_metadata + hidden_states = self.model(input_ids, positions, + attn_metadata, mamba_cache_params, + intermediate_tensors, inputs_embeds) + return hidden_states + + def _get_mamba_cache_shape(self) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + mamba_expand = self.config.mamba_expand # 2 + mamba_d_conv = self.config.mamba_d_conv # 4 + mamba_d_state = self.config.mamba_d_state # 16 + conv_state_shape = ( + mamba_expand * hidden_size // world_size, + mamba_d_conv - 1, + ) + temporal_state_shape = ( + mamba_expand * hidden_size // world_size, + mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + # If the shape is the same, it means that we have already prune hidden states manually. + prune_hidden_states = hidden_states.size(0) != sampling_metadata.selected_token_indices.size(0) + processed_logits = self.logits_processor( + self.lm_head, + hidden_states, + sampling_metadata, + self.embedding_bias, + prune_hidden_states=prune_hidden_states + ) + return processed_logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + weights = {name: weight for name, weight in weights} + print(f"--------- num of keys: {len(weights.keys())}") + adjusted_weights = {} + for name, weight in weights.items(): + if "A_log" in name: + name = name.replace("A_log", "A") + weight = -torch.exp(weight.float()) + if "inner_cross_attn." 
in name: + name = name.replace("inner_cross_attn.", "") + adjusted_weights[name] = weight + adjusted_weights["lm_head.weight"] = weights["model.embed_tokens.weight"] + for name, loaded_weight in adjusted_weights.items(): + print(name, loaded_weight.shape) + + params_dict = dict(self.named_parameters()) + + print(f"{adjusted_weights.keys() - params_dict.keys()} not in model") + print(f"{params_dict.keys() - adjusted_weights.keys()} not in weights") + + loaded_params: Set[str] = set() + + for name, param in self.named_parameters(): + weight = adjusted_weights.get(name, None) + if weight is not None and weight.shape != param.shape: + print(f"Shape mismatch: {name} {weight.shape} {param.shape}") + loaded_params.add(name) + missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, strict=False) + print(f"--------------- missing keys {missing_keys}") + print("--------------- unexpected keys ---------------") + for key in unexpected_keys: + print(key) + if not key.endswith("bias"): + print("------- not bias -------") + # assert missing_keys == ['embedding_bias', 'lm_head.weight',], f"Missing keys: {missing_keys}" + # assert unexpected_keys == ['lm_head.bias',], f"Unexpected keys: {unexpected_keys}" + # self.lm_head.weight.data.copy_(adjusted_weights['model.embed_tokens.weight']) + # self.embedding_bias.data.copy_(adjusted_weights['lm_head.bias']) + # self.embedding_bias = None + return loaded_params \ No newline at end of file diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 17d44fa71d5..b70e10875b2 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -110,6 +110,7 @@ "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "SambaForCausalLM": ("phi3samba", "SambaForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9d936f3dbf0..2db44b4f22e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1112,6 +1112,10 @@ def __init__( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) + self.cross_layer_shared_graph_block_tables = np.zeros( + (self.max_batchsize_to_capture, self.get_max_block_per_batch()), + dtype=np.int32) + # Attention-free but stateful models like Mamba need a placeholder attn # backend, as the attention metadata is needed to manage internal state. 
# However we must bypass attention selection altogether for some models From 961e638ace655230f718980d0726c357d6744588 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Thu, 5 Jun 2025 21:55:25 +0000 Subject: [PATCH 02/24] Add a new backend Signed-off-by: Congcong Chen --- .../backends/differential_flash_attn.py | 1010 +++++++++++++++++ vllm/attention/backends/flash_attn.py | 4 +- vllm/platforms/cuda.py | 4 + vllm/platforms/interface.py | 1 + 4 files changed, 1017 insertions(+), 2 deletions(-) create mode 100644 vllm/attention/backends/differential_flash_attn.py diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py new file mode 100644 index 00000000000..fbcd275cb23 --- /dev/null +++ b/vllm/attention/backends/differential_flash_attn.py @@ -0,0 +1,1010 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with FlashAttention.""" +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm import _custom_ops as ops +# yapf conflicts with isort for this block +# yapf: disable +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +# yapf: enable +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) +from vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) +from vllm.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, + FlashAttentionMetadataBuilder) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +logger = init_logger(__name__) + + +class DifferentialFlashAttentionBackend(FlashAttentionBackend): + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + # return (2, num_blocks, block_size, num_kv_heads, head_size) + return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) + + @staticmethod + def get_name() -> str: + return "DIFFERENTIAL_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]: + return DifferentialFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]: + return DifferentialFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]: + return DifferentialFlashAttentionMetadataBuilder + + +@dataclass +class DifferentialFlashAttentionMetadata(AttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. 
If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["DifferentialFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["DifferentialFlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
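# A minimal, self-contained sketch (hypothetical lengths) of how the cumulative
# *_start_loc tensors documented above are derived from per-sequence lengths,
# matching the [4, 6] -> [0, 4, 10] example:
from itertools import accumulate

import torch

query_lens = [4, 6]
query_start_loc = torch.tensor(list(accumulate(query_lens, initial=0)),
                               dtype=torch.int32)
assert query_start_loc.tolist() == [0, 4, 10]
# Tokens of sequence i occupy the flat slice
# [query_start_loc[i], query_start_loc[i + 1]) of the packed batch.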
+ encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + # Cross-layer shared attention block tables + cross_layer_shared_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata(self) -> Optional["DifferentialFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[:self.num_prefills]) + + self._cached_prefill_metadata = DifferentialFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["DifferentialFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[self.num_prefills:]) + self._cached_decode_metadata = DifferentialFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class DifferentialFlashAttentionMetadataBuilder( + AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.cross_layer_shared_block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + cross_layer_shared_block_table = [] + if prefix_cache_hit: + cross_layer_shared_block_table = block_tables[seq_id] + elif block_tables is not None: + if curr_sliding_window_block == 0: + cross_layer_shared_block_table = block_tables[seq_id] + else: + cross_layer_shared_block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.cross_layer_shared_block_tables.append(cross_layer_shared_block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]], + graph_block_tables) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + # max_batch_size, max_blocks = self.runner.graph_block_tables.shape + max_batch_size, max_blocks = graph_block_tables.shape + assert max_batch_size >= num_seqs + + # graph_block_tables = self.runner.graph_block_tables[:num_seqs] + graph_block_tables = graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. 
+ query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + + self.cross_layer_shared_block_tables.extend([] * cuda_graph_pad_size) + + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables, self.runner.graph_block_tables) + cross_layer_shared_block_tables = self._get_graph_runner_block_tables( + num_seqs, self.cross_layer_shared_block_tables, self.runner.cross_layer_shared_graph_block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + cross_layer_shared_block_tables = make_tensor_with_pad( + self.cross_layer_shared_block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return DifferentialFlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class DifferentialFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout 
is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + if use_irope: + logger.warning( + "Using irope in V0 is not supported yet, it will fall back " + "to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + self.vllm_flash_attn_version = get_flash_attn_version( + requires_alibi=self.alibi_slopes is not None) + if is_quantized_kv_cache(self.kv_cache_dtype) and ( + not self.kv_cache_dtype.startswith("fp8") + or not flash_attn_supports_fp8()): + raise NotImplementedError( + f"FlashAttention does not support {self.kv_cache_dtype} " + "kv-cache on this device " + f"(FA supports fp8 = {flash_attn_supports_fp8()}).") + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. 
+ + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + NOTE: It in-place updates the output tensor. + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + + # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. + if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: + assert ( + layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( + "key/v_scale is only supported in FlashAttention 3 with " + "base dtype bfloat16") + + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes + logits_soft_cap: Optional[float] = self.logits_soft_cap + fp8_attention = kv_cache_dtype.startswith("fp8") + + if fp8_attention and not flash_attn_supports_fp8(): + raise NotImplementedError( + "FlashAttention does not support FP8 kv-cache on this device.") + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the + # KV cache is already populated with the cross-attention + # tensor. Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. 
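# Self-contained sketch of the slot-mapping convention the cache write below
# relies on, assuming the usual paged-KV layout where
# slot = physical_block_id * block_size + in-block offset
# (block table and positions here are hypothetical):
block_size = 16
block_table = [7, 3, 9]  # physical blocks backing one sequence

def slot_for(position: int) -> int:
    return block_table[position // block_size] * block_size + position % block_size

assert slot_for(0) == 7 * 16       # token 0  -> block 7, offset 0
assert slot_for(17) == 3 * 16 + 1  # token 17 -> block 3, offset 1
# reshape_and_cache_flash then scatters each token's K/V into its slot.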
+ torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if fp8_attention: + kv_cache = kv_cache.view(torch.float8_e4m3fn) + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = value_cache.view(torch.float8_e4m3fn) + + if fp8_attention: + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] + decode_output = output[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + prefill_output = output[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + if fp8_attention: + num_kv_tokens, num_kv_heads, head_size = key.shape + + key, _ = ops.scaled_fp8_quant( + key.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._k_scale) + key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) + + value, _ = ops.scaled_fp8_quant( + value.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._v_scale) + value = value.reshape( + (num_kv_tokens, num_kv_heads, head_size)) + + descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) + flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + softmax_scale=softmax_scale, + causal=_get_causal_option(attn_type), + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=prefill_output, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + else: + # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") + assert prefill_meta.seq_lens is not None + assert prefill_meta.query_start_loc is not None + max_seq_len = max(prefill_meta.seq_lens) + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens_tensor, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + out=prefill_output, + fa_version=self.vllm_flash_attn_version, + 
q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Use flash_attn_varlen_func kernel for speculative decoding + # because different queries might have different lengths. + + assert decode_meta.max_decode_query_len is not None + # use only for actual varlen decoding + if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1" + ) + assert decode_meta.query_start_loc is not None + descale_shape = (decode_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + seqused_k=decode_meta.seq_lens_tensor, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + out=decode_output, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + else: + # Use flash_attn_with_kvcache for normal decoding. + ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) + flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=decode_output.unsqueeze(1), + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return output + + +def _get_query_key_seq_metadata( + attn_metadata, + is_prompt: bool, + attn_type: str, +) -> tuple: + """ + Returns sequence metadata for key and query based on the specified + attention type and whether input is a prompt. + + This function computes the starting locations and maximum sequence lengths + for key and query sequences for different attention types. + + Args: + attn_metadata: The attention metadata object + is_prompt (bool): A flag indicating if the input is a prompt + attn_type (AttentionType): The type of attention being used. + + Returns: + tuple: A tuple containing four integers: + - Starting location for the query sequence. + - Maximum sequence length for the query sequence. + - Starting location for the key sequence. + - Maximum sequence length for the key sequence. + + Raises: + AttributeError: If an invalid attention type is provided. + """ + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER: + # This is cross attention between the where the key + # is the precomputed encoder attention and query + # is the input sequence. 
+ # Choose query max length based on whether it is prompt + # or not. + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER: + # For encoder attention both the query and the key are same i.e the + # encoder sequence. + return (attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY: + assert is_prompt, "Should not have decode for encoder only model." + return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 410377f9a97..533408d5abd 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -75,8 +75,8 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - # return (2, num_blocks, block_size, num_kv_heads, head_size) - return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) + return (2, num_blocks, block_size, num_kv_heads, head_size) + # return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) @staticmethod def swap_blocks( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 00151296a75..878f8f77edf 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -316,6 +316,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using DualChunkFlashAttention backend.") return ("vllm.attention.backends.dual_chunk_flash_attn." "DualChunkFlashAttentionBackend") + elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN: + logger.info("Using DifferentialFlashAttention backend.") + return ("vllm.attention.backends.differential_flash_attn." 
+ "DifferentialFlashAttentionBackend") elif selected_backend == _Backend.FLASH_ATTN: pass elif selected_backend: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d3060685e98..ae675bcc8d2 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -60,6 +60,7 @@ class _Backend(enum.Enum): IPEX = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto() DUAL_CHUNK_FLASH_ATTN = enum.auto() + DIFFERENTIAL_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() From c264085b283a45fade2fbefd82de193e0fe0bff7 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Sun, 8 Jun 2025 01:29:45 +0000 Subject: [PATCH 03/24] Use differential flash attn backend Signed-off-by: Congcong Chen --- .../backends/differential_flash_attn.py | 550 ++++++++---------- vllm/model_executor/models/phi3samba.py | 182 +++--- 2 files changed, 344 insertions(+), 388 deletions(-) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index fbcd275cb23..a925cdcb3c7 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch +import torch.nn as nn from vllm import _custom_ops as ops # yapf conflicts with isort for this block @@ -34,6 +35,7 @@ FlashAttentionImpl, FlashAttentionMetadata, FlashAttentionMetadataBuilder) +from einops import rearrange if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -43,7 +45,7 @@ class DifferentialFlashAttentionBackend(FlashAttentionBackend): - + accept_output_buffer = False @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -627,7 +629,11 @@ def __init__( attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, + differential_flash_attention_config: Optional[Dict[str, Any]] = None, ) -> None: + self.differential_flash_attention_config = differential_flash_attention_config + self.used_shared_kv_cache = self.differential_flash_attention_config.get( + "used_shared_kv_cache", False) if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") if blocksparse_params is not None: @@ -671,340 +677,270 @@ def __init__( f"Supported head sizes are: {support_head_sizes}.") self.attn_type = attn_type - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. + self.lambda_full = None + # self.subln = nn.RMSNorm(2 * self.head_size, eps=1e-5, elementwise_affine=True) + self.subln = self.differential_flash_attention_config["subln"] - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - output: shape = [num_tokens, num_heads, head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - NOTE: It in-place updates the output tensor. - NOTE: FP8 quantization, flash-attn expect the size of - {q,k,v}_descale to be (num_sequences, num_kv_heads). 
- We use torch's .expand() to avoid duplicating values - """ - assert output is not None, "Output tensor must be provided." - - # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. - if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: - assert ( - layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( - "key/v_scale is only supported in FlashAttention 3 with " - "base dtype bfloat16") - - attn_type = self.attn_type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") - - kv_cache_dtype: str = self.kv_cache_dtype - softmax_scale: float = self.scale - window_size = self.sliding_window - alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes - logits_soft_cap: Optional[float] = self.logits_soft_cap - fp8_attention = kv_cache_dtype.startswith("fp8") - - if fp8_attention and not flash_attn_supports_fp8(): - raise NotImplementedError( - "FlashAttention does not support FP8 kv-cache on this device.") - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - # We skip updating the KV cache under two conditions: - # a. When the Attention Type is ENCODER. In this phase, we compute - # only the encoder attention without updating the cache. - # b. When both Key and Value are None. This occurs during - # cross-attention computation in the decoding phase, where the - # KV cache is already populated with the cross-attention - # tensor. Thus, we skip cache updates during this time. - if (attn_type != AttentionType.ENCODER) and (key is not None) and ( - value is not None): - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping + def split_heads(self, x): + # split by num_heads, the stripe pattern is friendly to tensor parallel. + x = rearrange(x, "... (H two) D -> ... H two D", two=2) + x1 = x[..., 0, :] + x2 = x[..., 1, :] + return x1.contiguous(), x2.contiguous() + + def split_kv_cache(self, x): + # split by num_heads, the stripe pattern is friendly to tensor parallel. + if x.numel() == 0: + return torch.empty(0), torch.empty(0) + + x1, x2 = x[0], x[1] + return x1, x2 + + def populate_kv_cache(self, + layer: AttentionLayer, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata): + if (kv_cache.numel() > 0): + if (key is not None) and (value is not None): + updated_slot_mapping = attn_metadata.slot_mapping + # previous_key_cache_sum = key_cache.sum() + # previous_value_cache_sum = value_cache.sum() - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory - # profiling run. 
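# Self-contained sketch of the "stripe" head split used by split_heads above:
# heads are interleaved rather than cut into contiguous halves, so heads
# 0, 2, 4, ... land in the first group and 1, 3, 5, ... in the second, which
# is what makes the layout tensor-parallel friendly (shapes are hypothetical).
import torch
from einops import rearrange

num_tokens, num_heads, head_size = 2, 6, 4
x = torch.arange(num_tokens * num_heads * head_size,
                 dtype=torch.float32).view(num_tokens, num_heads, head_size)

y = rearrange(x, "... (H two) D -> ... H two D", two=2)
x1, x2 = y[..., 0, :].contiguous(), y[..., 1, :].contiguous()
assert torch.equal(x1[0, 0], x[0, 0])  # group 1 starts with original head 0
assert torch.equal(x2[0, 0], x[0, 1])  # group 2 starts with original head 1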
torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, kv_cache[0], kv_cache[1], - updated_slot_mapping.flatten(), # type: ignore[union-attr] - kv_cache_dtype, + updated_slot_mapping.flatten(), + self.kv_cache_dtype, layer._k_scale, layer._v_scale, ) + # assert key_cache.sum() - previous_key_cache_sum == key.sum(), "key_cache sum mismatch" + # assert value_cache.sum() - previous_value_cache_sum == value.sum(), "value_cache sum mismatch" + # if key_cache.sum() - previous_key_cache_sum != key.sum(): + # print("key_cache sum mismatch") + # if value_cache.sum() - previous_value_cache_sum != value.sum(): + # print("value_cache sum mismatch") + + def forward_generate_kv_cache( + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: AttentionMetadata + ) -> torch.Tensor: + + head_size = self.head_size + num_heads = self.num_heads // 2 + num_kv_heads = self.num_kv_heads // 2 - if fp8_attention: - kv_cache = kv_cache.view(torch.float8_e4m3fn) - key_cache = key_cache.view(torch.float8_e4m3fn) - value_cache = value_cache.view(torch.float8_e4m3fn) - - if fp8_attention: - num_tokens, num_heads, head_size = query.shape - query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) - - (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) = \ - get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - decode_query = query[num_prefill_query_tokens:] - decode_output = output[num_prefill_query_tokens:] + query = query.view(-1, num_heads, head_size) + if key is not None: + assert value is not None + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + else: + assert value is None + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] # QKV for prefill. - query = query[:num_prefill_query_tokens] - prefill_output = output[:num_prefill_query_tokens] - assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens + query = query[:num_prefill_tokens] + if key is not None and value is not None: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens, "query shape mismatch" + assert decode_query.shape[0] == num_decode_tokens, "decode query shape mismatch" if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): + if k_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. 
- q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ - _get_query_key_seq_metadata(prefill_meta, True, attn_type) - - key = key[:num_prefill_kv_tokens] - value = value[:num_prefill_kv_tokens] - - if fp8_attention: - num_kv_tokens, num_kv_heads, head_size = key.shape - - key, _ = ops.scaled_fp8_quant( - key.reshape((num_kv_tokens, - num_kv_heads * head_size)).contiguous(), - layer._k_scale) - key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) - - value, _ = ops.scaled_fp8_quant( - value.reshape((num_kv_tokens, - num_kv_heads * head_size)).contiguous(), - layer._v_scale) - value = value.reshape( - (num_kv_tokens, num_kv_heads, head_size)) - - descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) - flash_attn_varlen_func( + prefill_output = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=q_seq_start_loc, - cu_seqlens_k=k_seq_start_loc, - max_seqlen_q=q_seq_len, - max_seqlen_k=k_seq_len, - softmax_scale=softmax_scale, - causal=_get_causal_option(attn_type), - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - else: - # prefix-enabled attention - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support prefix caching") - assert prefill_meta.seq_lens is not None - assert prefill_meta.query_start_loc is not None - max_seq_len = max(prefill_meta.seq_lens) - descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, - key.shape[1]) - flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - seqused_k=prefill_meta.seq_lens_tensor, - max_seqlen_k=max_seq_len, - softmax_scale=softmax_scale, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, ) + assert prefill_output.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = prefill_output + else: + raise Exception("prefix caching not supported") if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - # Use flash_attn_varlen_func kernel for speculative decoding - # because different queries might have different lengths. 
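# Self-contained sketch of what the varlen prefill call above computes:
# sequences of different lengths are packed into one flat
# (total_tokens, heads, dim) tensor delimited by the cumulative seq_start_loc,
# and per-sequence causal attention over that packing is equivalent to the
# loop below (plain PyTorch stands in for the fused kernel; sliding-window
# and softcap options are ignored; shapes are hypothetical).
import torch
import torch.nn.functional as F

heads, dim = 2, 8
seq_lens = [4, 6]
cu_seqlens = [0, 4, 10]
q, k, v = (torch.randn(sum(seq_lens), heads, dim) for _ in range(3))

out = torch.empty_like(q)
for i in range(len(seq_lens)):
    s, e = cu_seqlens[i], cu_seqlens[i + 1]
    qi, ki, vi = (t[s:e].transpose(0, 1) for t in (q, k, v))  # (heads, len, dim)
    out[s:e] = F.scaled_dot_product_attention(
        qi, ki, vi, is_causal=True).transpose(0, 1)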
- - assert decode_meta.max_decode_query_len is not None - # use only for actual varlen decoding - if decode_meta.max_decode_query_len > 1: - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support max_decode_query_len > 1" - ) - assert decode_meta.query_start_loc is not None - descale_shape = (decode_meta.query_start_loc.shape[0] - 1, - key.shape[1]) - flash_attn_varlen_func( - q=decode_query, - k=key_cache, - v=value_cache, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_q=decode_meta.max_decode_query_len, - seqused_k=decode_meta.seq_lens_tensor, - max_seqlen_k=decode_meta.max_decode_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - block_table=decode_meta.block_tables, - out=decode_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - else: - # Use flash_attn_with_kvcache for normal decoding. - ( - seq_lens_arg, - _, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) - flash_attn_with_kvcache( + block_tables_arg = decode_meta.block_tables + try: + output[num_prefill_tokens:] = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, + k_cache=k_cache, + v_cache=v_cache, block_table=block_tables_arg, - cache_seqlens=seq_lens_arg, - softmax_scale=softmax_scale, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=decode_output.unsqueeze(1), - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ).squeeze(1) + except Exception as e: + logger.error( + f"Error in PagedAttention.forward_decode: {str(e)}") + raise e + + # Reshape the output tensor. + return output.view(-1, num_heads, head_size) + + + def forward_with_kv_cache_only( + self, + query: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + if not attn_metadata.decode_metadata: + block_tables_arg = attn_metadata.cross_layer_shared_block_tables + else: + block_tables_arg = attn_metadata.block_tables + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_tables_arg, + cache_seqlens=attn_metadata.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ).squeeze(1) return output + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. -def _get_query_key_seq_metadata( - attn_metadata, - is_prompt: bool, - attn_type: str, -) -> tuple: - """ - Returns sequence metadata for key and query based on the specified - attention type and whether input is a prompt. 
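# Self-contained sketch of the differential-attention combination that the
# forward() continued below performs: lambda_full is derived from the four
# lambda vectors, the two grouped attention outputs are subtracted, RMS-normed
# and rescaled by (1 - lambda_init), then the head groups are interleaved back.
# Values and shapes are hypothetical and dtype/config plumbing is omitted.
import torch
import torch.nn as nn
from einops import rearrange

num_tokens, half_heads, head_dim = 4, 2, 8
lambda_init = 0.8
lambda_q1, lambda_k1, lambda_q2, lambda_k2 = (torch.randn(head_dim)
                                              for _ in range(4))
# Stand-ins for the two grouped attention outputs, each already concatenated
# to 2 * head_dim per half-head.
attn1 = torch.randn(num_tokens, half_heads, 2 * head_dim)
attn2 = torch.randn(num_tokens, half_heads, 2 * head_dim)

lambda_full = (torch.exp(torch.sum(lambda_q1 * lambda_k1))
               - torch.exp(torch.sum(lambda_q2 * lambda_k2)) + lambda_init)
subln = nn.RMSNorm(2 * head_dim, eps=1e-5)  # mirrors the patch's subln
attn = subln(attn1 - lambda_full * attn2) * (1 - lambda_init)
attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2)
assert attn_output.shape == (num_tokens, 2 * half_heads, head_dim)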
- - This function computes the starting locations and maximum sequence lengths - for key and query sequences for different attention types. - - Args: - attn_metadata: The attention metadata object - is_prompt (bool): A flag indicating if the input is a prompt - attn_type (AttentionType): The type of attention being used. - - Returns: - tuple: A tuple containing four integers: - - Starting location for the query sequence. - - Maximum sequence length for the query sequence. - - Starting location for the key sequence. - - Maximum sequence length for the key sequence. - - Raises: - AttributeError: If an invalid attention type is provided. - """ - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_start_loc, max_seq_len, - attn_metadata.seq_start_loc, max_seq_len) - - elif attn_type == AttentionType.ENCODER_DECODER: - # This is cross attention between the where the key - # is the precomputed encoder attention and query - # is the input sequence. - # Choose query max length based on whether it is prompt - # or not. - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_start_loc, max_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER: - # For encoder attention both the query and the key are same i.e the - # encoder sequence. - return (attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER_ONLY: - assert is_prompt, "Should not have decode for encoder only model." - return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, - attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -def _get_causal_option(attn_type: str) -> bool: - """ - Determine whether the given attention type is suitable for causal - attention mechanisms. + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + NOTE: It in-place updates the output tensor. + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). 
+ We use torch's .expand() to avoid duplicating values + """ + if self.lambda_full is None: + self.lambda_init = self.differential_flash_attention_config["lambda_init"] + lambda_q1 = self.differential_flash_attention_config["lambda_q1"] + lambda_k1 = self.differential_flash_attention_config["lambda_k1"] + lambda_q2 = self.differential_flash_attention_config["lambda_q2"] + lambda_k2 = self.differential_flash_attention_config["lambda_k2"] + lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q) + lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q) + self.lambda_full = lambda_1 - lambda_2 + self.lambda_init + - Args: - attn_type (AttentionType): The type of attention being evaluated + if not self.used_shared_kv_cache: # need to generate kv-cache + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size) - Returns: - bool: Returns `True` if the attention type is suitable for causal - attention (i.e., not encoder, encoder-only, or encoder-decoder), - otherwise returns `False`. - """ - return not (attn_type == AttentionType.ENCODER - or attn_type == AttentionType.ENCODER_ONLY - or attn_type == AttentionType.ENCODER_DECODER) + q1, q2 = self.split_heads(q) + k1, k2 = self.split_heads(k) + v1, v2 = self.split_heads(v) + + # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size) + # Split by half along the first dimension. + kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" + assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" + + if kv_cache1.numel() != 0: + self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata) + self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata) + + key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) + key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) + else: + key_cache1, value_cache1 = torch.empty(0), torch.empty(0) + key_cache2, value_cache2 = torch.empty(0), torch.empty(0) + attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1, value_cache1, attn_metadata) + attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1, value_cache2, attn_metadata) + attn11 = attn11.view(q1.shape) + attn12 = attn12.view(q1.shape) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2, value_cache1, attn_metadata) + attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2, value_cache2, attn_metadata) + attn21 = attn21.view(q2.shape) + attn22 = attn22.view(q2.shape) + attn2 = torch.cat([attn21, attn22], dim=-1) + + attn = attn1 - self.lambda_full * attn2 + # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + # reshape back to 2 * num_head + attn_output = rearrange(attn, "... H (two D) -> ... 
(H two) D", two=2) + + else: # re-use the kv cache, full attention + q = q.view(-1, self.num_heads, self.head_size) + q1, q2 = self.split_heads(q) + # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size) + kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] + key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] + + attn11 = self.forward_with_kv_cache_only(q1, key_cache1, value_cache1, attn_metadata) + attn12 = self.forward_with_kv_cache_only(q1, key_cache1, value_cache2, attn_metadata) + attn11 = attn11.view(q1.shape) + attn12 = attn12.view(q1.shape) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = self.forward_with_kv_cache_only(q2, key_cache2, value_cache1, attn_metadata) + attn22 = self.forward_with_kv_cache_only(q2, key_cache2, value_cache2, attn_metadata) + attn21 = attn21.view(q2.shape) + attn22 = attn22.view(q2.shape) + attn2 = torch.cat([attn21, attn22], dim=-1) + + attn = attn1 - self.lambda_full * attn2 + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + # reshape back to 2 * num_head + attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2) + attn_output = attn_output.view(-1, self.num_heads * self.head_size) + return attn_output \ No newline at end of file diff --git a/vllm/model_executor/models/phi3samba.py b/vllm/model_executor/models/phi3samba.py index 7ca88ee865f..5b583848fcf 100644 --- a/vllm/model_executor/models/phi3samba.py +++ b/vllm/model_executor/models/phi3samba.py @@ -129,24 +129,37 @@ def __init__(self, assert self.num_heads % 2 == 0, 'num_heads should be even' assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' + + self.lambda_init = self.lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) + self.subln = nn.RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True) + params = {'differential_flash_attention_config': + { + 'used_shared_kv_cache': self.yoco_cross, + 'lambda_init': self.lambda_init, + 'lambda_q1': self.lambda_q1, + 'lambda_k1': self.lambda_k1, + 'lambda_q2': self.lambda_q2, + 'lambda_k2': self.lambda_k2, + "subln": self.subln, + } + } + self.attn = Attention( - self.num_heads//2, + self.num_heads, self.head_dim, self.head_dim**-0.5, - num_kv_heads=self.num_key_value_heads//2, + num_kv_heads=self.num_key_value_heads, cache_config=cache_config, per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", - attn_type=AttentionType.DECODER_DECODER if self.yoco_cross else AttentionType.DECODER + attn_type=AttentionType.DECODER_DECODER if self.yoco_cross else AttentionType.DECODER, + **params ) - self.subln = nn.RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True) - - self.lambda_init = self.lambda_init_fn(layer_idx) - self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) - self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) - self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) - self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, 
dtype=torch.float32).normal_(mean=0,std=0.1)) self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) @@ -320,82 +333,89 @@ def forward( if not self.yoco_cross: # need to generate kv-cache qkv = self.Wqkv(hidden_states) q, k, v = qkv.split([self.hidden_size, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], dim=-1) - # q, k = self.rotary_emb(positions, q, k) - # reshape - q = q.view(-1, self.num_heads, self.head_dim) - k = k.view(-1, self.num_key_value_heads, self.head_dim) - v = v.view(-1, self.num_key_value_heads, self.head_dim) - - q1, q2 = self.split_heads(q) - k1, k2 = self.split_heads(k) - v1, v2 = self.split_heads(v) - - # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size) - # Split by half along the first dimension. - kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" - assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" + reference_attn_output = self.attn(q, k, v) + # # q, k = self.rotary_emb(positions, q, k) + # # reshape + # q = q.view(-1, self.num_heads, self.head_dim) + # k = k.view(-1, self.num_key_value_heads, self.head_dim) + # v = v.view(-1, self.num_key_value_heads, self.head_dim) + + # q1, q2 = self.split_heads(q) + # k1, k2 = self.split_heads(k) + # v1, v2 = self.split_heads(v) + + # # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size) + # # Split by half along the first dimension. + # kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + # assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" + # assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" - if kv_cache1.numel() != 0: - self.populate_kv_cache(k1, v1, kv_cache1, attn_metadata) - self.populate_kv_cache(k2, v2, kv_cache2, attn_metadata) + # if kv_cache1.numel() != 0: + # self.populate_kv_cache(k1, v1, kv_cache1, attn_metadata) + # self.populate_kv_cache(k2, v2, kv_cache2, attn_metadata) - key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) - key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) - else: - key_cache1, value_cache1 = torch.empty(0), torch.empty(0) - key_cache2, value_cache2 = torch.empty(0), torch.empty(0) - attn11 = self.forward_customized(q1, k1, v1, key_cache1, value_cache1, attn_metadata) - attn12 = self.forward_customized(q1, k1, v2, key_cache1, value_cache2, attn_metadata) - attn11 = attn11.view(q1.shape) - attn12 = attn12.view(q1.shape) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = self.forward_customized(q2, k2, v1, key_cache2, value_cache1, attn_metadata) - attn22 = self.forward_customized(q2, k2, v2, key_cache2, value_cache2, attn_metadata) - attn21 = attn21.view(q2.shape) - attn22 = attn22.view(q2.shape) - attn2 = torch.cat([attn21, attn22], dim=-1) - - lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) - lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) - lambda_full = lambda_1 - lambda_2 + self.lambda_init - attn = attn1 - lambda_full * attn2 - # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - # reshape back to 2 * num_head - attn_output = rearrange(attn, "... H (two D) -> ... 
(H two) D", two=2) - + # key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) + # key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) + # else: + # key_cache1, value_cache1 = torch.empty(0), torch.empty(0) + # key_cache2, value_cache2 = torch.empty(0), torch.empty(0) + # attn11 = self.forward_customized(q1, k1, v1, key_cache1, value_cache1, attn_metadata) + # attn12 = self.forward_customized(q1, k1, v2, key_cache1, value_cache2, attn_metadata) + # attn11 = attn11.view(q1.shape) + # attn12 = attn12.view(q1.shape) + # attn1 = torch.cat([attn11, attn12], dim=-1) + + # attn21 = self.forward_customized(q2, k2, v1, key_cache2, value_cache1, attn_metadata) + # attn22 = self.forward_customized(q2, k2, v2, key_cache2, value_cache2, attn_metadata) + # attn21 = attn21.view(q2.shape) + # attn22 = attn22.view(q2.shape) + # attn2 = torch.cat([attn21, attn22], dim=-1) + + # lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) + # lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) + # lambda_full = lambda_1 - lambda_2 + self.lambda_init + + # attn = attn1 - lambda_full * attn2 + # # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) + # attn = self.subln(attn) + # attn = attn * (1 - self.lambda_init) + # # reshape back to 2 * num_head + # attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2) + attn_output = self.attn(q, k, v) else: # re-use the kv cache, full attention q = self.Wqkv(hidden_states) - q = q.view(-1, self.num_heads, self.head_dim) - q1, q2 = self.split_heads(q) - # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size) - kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] - key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] + # q = q.view(-1, self.num_heads, self.head_dim) + # q1, q2 = self.split_heads(q) + # # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size) + # kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + # key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] + # key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] - attn11 = self.forward_decode(q1, key_cache1, value_cache1, attn_metadata) - attn12 = self.forward_decode(q1, key_cache1, value_cache2, attn_metadata) - attn11 = attn11.view(q1.shape) - attn12 = attn12.view(q1.shape) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = self.forward_decode(q2, key_cache2, value_cache1, attn_metadata) - attn22 = self.forward_decode(q2, key_cache2, value_cache2, attn_metadata) - attn21 = attn21.view(q2.shape) - attn22 = attn22.view(q2.shape) - attn2 = torch.cat([attn21, attn22], dim=-1) - - lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) - lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) - lambda_full = lambda_1 - lambda_2 + self.lambda_init - attn = attn1 - lambda_full * attn2 - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - # reshape back to 2 * num_head - attn_output = rearrange(attn, "... H (two D) -> ... 
(H two) D", two=2) + # attn11 = self.forward_decode(q1, key_cache1, value_cache1, attn_metadata) + # attn12 = self.forward_decode(q1, key_cache1, value_cache2, attn_metadata) + # attn11 = attn11.view(q1.shape) + # attn12 = attn12.view(q1.shape) + # attn1 = torch.cat([attn11, attn12], dim=-1) + + # attn21 = self.forward_decode(q2, key_cache2, value_cache1, attn_metadata) + # attn22 = self.forward_decode(q2, key_cache2, value_cache2, attn_metadata) + # attn21 = attn21.view(q2.shape) + # attn22 = attn22.view(q2.shape) + # attn2 = torch.cat([attn21, attn22], dim=-1) + + # lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) + # lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) + # lambda_full = lambda_1 - lambda_2 + self.lambda_init + # attn = attn1 - lambda_full * attn2 + # attn = self.subln(attn) + # attn = attn * (1 - self.lambda_init) + # # reshape back to 2 * num_head + # attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2) + + + if self.attn.kv_cache[0].numel() == 0: + self.attn.kv_cache = [kv_cache] + attn_output = self.attn(q, None, None) attn_output = attn_output.view(-1, self.num_heads * self.head_dim) return self.out_proj(attn_output) From 344190f319d7cea4dc8214a0964804fa696d22f3 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Sun, 8 Jun 2025 06:31:39 +0000 Subject: [PATCH 04/24] clean up code Signed-off-by: Congcong Chen --- vllm/model_executor/models/phi3samba.py | 232 ------------------------ 1 file changed, 232 deletions(-) diff --git a/vllm/model_executor/models/phi3samba.py b/vllm/model_executor/models/phi3samba.py index 5b583848fcf..d0f6f12582e 100644 --- a/vllm/model_executor/models/phi3samba.py +++ b/vllm/model_executor/models/phi3samba.py @@ -167,161 +167,6 @@ def __init__(self, def lambda_init_fn(self, depth): return 0.8 - 0.6 * math.exp(-0.3 * depth) - - def split_heads(self, x): - # split by num_heads, the stripe pattern is friendly to tensor parallel. - x = rearrange(x, "... (H two) D -> ... H two D", two=2) - x1 = x[..., 0, :] - x2 = x[..., 1, :] - return x1.contiguous(), x2.contiguous() - - def split_kv_cache(self, x): - # split by num_heads, the stripe pattern is friendly to tensor parallel. 
- if x.numel() == 0: - return torch.empty(0), torch.empty(0) - - x1, x2 = x[0], x[1] - return x1, x2 - - def forward_decode( - self, - query: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ): - if not attn_metadata.decode_metadata: - block_tables_arg = attn_metadata.cross_layer_shared_block_tables - else: - block_tables_arg = attn_metadata.block_tables - - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=k_cache, - v_cache=v_cache, - block_table=block_tables_arg, - cache_seqlens=attn_metadata.seq_lens_tensor, - softmax_scale=self.attn.impl.scale, - causal=True, - window_size=self.attn.impl.sliding_window, - alibi_slopes=self.attn.impl.alibi_slopes, - softcap=self.attn.impl.logits_soft_cap, - ).squeeze(1) - return output - - def populate_kv_cache(self, - key, - value, - kv_cache, - attn_metadata): - if (kv_cache.numel() > 0): - if (key is not None) and (value is not None): - updated_slot_mapping = attn_metadata.slot_mapping - # previous_key_cache_sum = key_cache.sum() - # previous_value_cache_sum = value_cache.sum() - - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), - self.attn.impl.kv_cache_dtype, - self._k_scale, - self._v_scale, - ) - # assert key_cache.sum() - previous_key_cache_sum == key.sum(), "key_cache sum mismatch" - # assert value_cache.sum() - previous_value_cache_sum == value.sum(), "value_cache sum mismatch" - # if key_cache.sum() - previous_key_cache_sum != key.sum(): - # print("key_cache sum mismatch") - # if value_cache.sum() - previous_value_cache_sum != value.sum(): - # print("value_cache sum mismatch") - - def forward_customized( - self, - query: torch.Tensor, - key: Optional[torch.Tensor], - value: Optional[torch.Tensor], - k_cache: torch.Tensor, - v_cache: torch.Tensor, - attn_metadata: AttentionMetadata - ) -> torch.Tensor: - - head_size = self.head_dim - num_heads = self.num_heads // 2 - num_kv_heads = self.num_key_value_heads // 2 - - query = query.view(-1, num_heads, head_size) - if key is not None: - assert value is not None - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - else: - assert value is None - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - if key is not None and value is not None: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens, "query shape mismatch" - assert decode_query.shape[0] == num_decode_tokens, "decode query shape mismatch" - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. 
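The populate_kv_cache and forward_decode helpers removed in this hunk write to and read from vLLM's paged KV cache through reshape_and_cache_flash and a per-token slot mapping. As a rough mental model only, since the real op is a fused CUDA kernel and this sketch ignores cache dtype conversion and the k/v scales, the write amounts to a scatter by slot index:

import torch

def write_kv(kv_cache, key, value, slot_mapping):
    # kv_cache:     [2, num_blocks, block_size, num_kv_heads, head_size]
    # key / value:  [num_tokens, num_kv_heads, head_size]
    # slot_mapping: [num_tokens], slot = block_id * block_size + block_offset
    block_size = kv_cache.shape[2]
    block_ids = slot_mapping // block_size
    offsets = slot_mapping % block_size
    kv_cache[0, block_ids, offsets] = key    # key half of the cache
    kv_cache[1, block_ids, offsets] = value  # value half of the cache

# toy usage: a 2-block cache, three tokens, slot 17 -> block 1, offset 1
# cache = torch.zeros(2, 2, 16, 4, 32)
# write_kv(cache, torch.randn(3, 4, 32), torch.randn(3, 4, 32), torch.tensor([0, 1, 17]))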
- if k_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: - # normal attention - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.attn.impl.scale, - causal=True, - window_size=self.attn.impl.sliding_window, - alibi_slopes=self.attn.impl.alibi_slopes, - softcap=self.attn.impl.logits_soft_cap, - ) - assert prefill_output.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = prefill_output - else: - raise Exception("prefix caching not supported") - - if decode_meta := attn_metadata.decode_metadata: - block_tables_arg = decode_meta.block_tables - try: - output[num_prefill_tokens:] = flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=k_cache, - v_cache=v_cache, - block_table=block_tables_arg, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.attn.impl.scale, - causal=True, - window_size=self.attn.impl.sliding_window, - alibi_slopes=self.attn.impl.alibi_slopes, - softcap=self.attn.impl.logits_soft_cap, - ).squeeze(1) - except Exception as e: - logger.error( - f"Error in PagedAttention.forward_decode: {str(e)}") - raise e - - # Reshape the output tensor. - return output.view(-1, num_heads, head_size) - def forward( self, hidden_states: torch.Tensor, @@ -333,86 +178,9 @@ def forward( if not self.yoco_cross: # need to generate kv-cache qkv = self.Wqkv(hidden_states) q, k, v = qkv.split([self.hidden_size, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], dim=-1) - reference_attn_output = self.attn(q, k, v) - # # q, k = self.rotary_emb(positions, q, k) - # # reshape - # q = q.view(-1, self.num_heads, self.head_dim) - # k = k.view(-1, self.num_key_value_heads, self.head_dim) - # v = v.view(-1, self.num_key_value_heads, self.head_dim) - - # q1, q2 = self.split_heads(q) - # k1, k2 = self.split_heads(k) - # v1, v2 = self.split_heads(v) - - # # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size) - # # Split by half along the first dimension. 
- # kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - # assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" - # assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" - - # if kv_cache1.numel() != 0: - # self.populate_kv_cache(k1, v1, kv_cache1, attn_metadata) - # self.populate_kv_cache(k2, v2, kv_cache2, attn_metadata) - - # key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) - # key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) - # else: - # key_cache1, value_cache1 = torch.empty(0), torch.empty(0) - # key_cache2, value_cache2 = torch.empty(0), torch.empty(0) - # attn11 = self.forward_customized(q1, k1, v1, key_cache1, value_cache1, attn_metadata) - # attn12 = self.forward_customized(q1, k1, v2, key_cache1, value_cache2, attn_metadata) - # attn11 = attn11.view(q1.shape) - # attn12 = attn12.view(q1.shape) - # attn1 = torch.cat([attn11, attn12], dim=-1) - - # attn21 = self.forward_customized(q2, k2, v1, key_cache2, value_cache1, attn_metadata) - # attn22 = self.forward_customized(q2, k2, v2, key_cache2, value_cache2, attn_metadata) - # attn21 = attn21.view(q2.shape) - # attn22 = attn22.view(q2.shape) - # attn2 = torch.cat([attn21, attn22], dim=-1) - - # lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) - # lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) - # lambda_full = lambda_1 - lambda_2 + self.lambda_init - - # attn = attn1 - lambda_full * attn2 - # # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) - # attn = self.subln(attn) - # attn = attn * (1 - self.lambda_init) - # # reshape back to 2 * num_head - # attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2) attn_output = self.attn(q, k, v) else: # re-use the kv cache, full attention q = self.Wqkv(hidden_states) - # q = q.view(-1, self.num_heads, self.head_dim) - # q1, q2 = self.split_heads(q) - # # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size) - # kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - # key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] - # key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] - - # attn11 = self.forward_decode(q1, key_cache1, value_cache1, attn_metadata) - # attn12 = self.forward_decode(q1, key_cache1, value_cache2, attn_metadata) - # attn11 = attn11.view(q1.shape) - # attn12 = attn12.view(q1.shape) - # attn1 = torch.cat([attn11, attn12], dim=-1) - - # attn21 = self.forward_decode(q2, key_cache2, value_cache1, attn_metadata) - # attn22 = self.forward_decode(q2, key_cache2, value_cache2, attn_metadata) - # attn21 = attn21.view(q2.shape) - # attn22 = attn22.view(q2.shape) - # attn2 = torch.cat([attn21, attn22], dim=-1) - - # lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) - # lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) - # lambda_full = lambda_1 - lambda_2 + self.lambda_init - # attn = attn1 - lambda_full * attn2 - # attn = self.subln(attn) - # attn = attn * (1 - self.lambda_init) - # # reshape back to 2 * num_head - # attn_output = rearrange(attn, "... H (two D) -> ... 
(H two) D", two=2) - - if self.attn.kv_cache[0].numel() == 0: self.attn.kv_cache = [kv_cache] attn_output = self.attn(q, None, None) From 3f89641779138ff5c0d028753114424806788b52 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Thu, 12 Jun 2025 05:29:55 +0000 Subject: [PATCH 05/24] clean up code Signed-off-by: Congcong Chen --- vllm/attention/backends/abstract.py | 3 +- vllm/model_executor/models/phi3samba.py | 109 ++++-------------------- 2 files changed, 19 insertions(+), 93 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 48af715edc7..428bddb0e12 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -32,7 +32,8 @@ class AttentionType: ENCODER_ONLY = "encoder_only" # Attention between dec. Q and enc. K/V for encoder-decoder ENCODER_DECODER = "encoder_decoder" - DECODER_DECODER = "decoder_decoder" # Attention layer that reuse kv cache + # Attention layer that reuse kv cache + DECODER_DECODER = "decoder_decoder" class AttentionBackend(ABC): diff --git a/vllm/model_executor/models/phi3samba.py b/vllm/model_executor/models/phi3samba.py index d0f6f12582e..d508be81fb8 100644 --- a/vllm/model_executor/models/phi3samba.py +++ b/vllm/model_executor/models/phi3samba.py @@ -1,6 +1,5 @@ -from typing import List, Optional, Tuple, Union, Iterable, Dict +from typing import List, Optional, Tuple, Union, Iterable import math -import copy import torch import torch.nn as nn @@ -17,7 +16,7 @@ RowParallelLinear, ColumnParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -30,10 +29,7 @@ causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache) +from vllm.attention.backends.abstract import (AttentionMetadata, AttentionType) from vllm.logger import init_logger from .utils import (maybe_prefix, make_layers) @@ -52,6 +48,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: # print(f"x1 shape: {x1.shape}, x2 shape: {x2.shape}") return x1 * nn.functional.silu(x2) + class SambaMLP(nn.Module): """Gated Linear Unit. @@ -77,9 +74,11 @@ def forward(self, hidden_states): return self.fc2(y) -class SambaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" +def get_virtual_engine(): + forward_context: ForwardContext = get_forward_context() + return forward_context.virtual_engine +class SambaAttention(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None, @@ -87,24 +86,16 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, prefix: str = ""): super().__init__() - self.config = config - self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." 
) - - self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True self.yoco_cross = yoco_cross if (self.head_dim * self.num_heads) != self.hidden_size: @@ -120,8 +111,6 @@ def __init__(self, else: self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True) - assert self.config.attention_dropout == 0.0, 'Attention dropout is not supported for now' - # disable sliding window for the second half of the model sliding_window = config.interleaved_sliding_window[layer_idx] if layer_idx >= config.num_hidden_layers // 2 or layer_idx % 2 == 0: @@ -161,9 +150,6 @@ def __init__(self, **params ) - self._k_scale = torch.tensor(1.0, dtype=torch.float32) - self._v_scale = torch.tensor(1.0, dtype=torch.float32) - def lambda_init_fn(self, depth): return 0.8 - 0.6 * math.exp(-0.3 * depth) @@ -181,8 +167,9 @@ def forward( attn_output = self.attn(q, k, v) else: # re-use the kv cache, full attention q = self.Wqkv(hidden_states) - if self.attn.kv_cache[0].numel() == 0: - self.attn.kv_cache = [kv_cache] + virtual_engine = get_virtual_engine() + if self.attn.kv_cache[virtual_engine].numel() == 0: + self.attn.kv_cache[virtual_engine] = kv_cache attn_output = self.attn(q, None, None) attn_output = attn_output.view(-1, self.num_heads * self.head_dim) return self.out_proj(attn_output) @@ -227,16 +214,6 @@ def __init__( self.in_proj = MergedColumnParallelLinear(self.d_model, [self.d_inner], bias=bias, **factory_kwargs) self.out_proj = RowParallelLinear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) return - # self.conv1d = nn.Conv1d( - # in_channels=self.d_inner, - # out_channels=self.d_inner, - # bias=conv_bias, - # kernel_size=d_conv, - # groups=self.d_inner, - # padding=d_conv - 1, - # **factory_kwargs, - # ) - self.conv1d = ColumnParallelLinear( input_size=d_conv, output_size=self.d_inner, @@ -249,16 +226,12 @@ def __init__( # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - # self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) self.in_proj = MergedColumnParallelLinear(self.d_model, [self.d_inner] * 2, bias=bias, params_dtype=dtype, ) - # self.x_proj = nn.Linear( - # self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - # ) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( self.d_inner, @@ -267,7 +240,6 @@ def __init__( params_dtype=dtype, ) - # self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, # as the bias is added in the selective scan kernel. 
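On the comment above about applying dt_proj without its bias: the selective-scan kernel takes the bias as a separate argument and fuses softplus(dt + bias) internally (its delta_bias / delta_softplus options), so adding the bias in the layer as well would double-count it. A minimal numerical sketch of that equivalence, with illustrative toy shapes:

import torch
import torch.nn.functional as F

d_inner, dt_rank, tokens = 8, 4, 3
dt_proj = torch.nn.Linear(dt_rank, d_inner, bias=True)
dt_in = torch.randn(tokens, dt_rank)

# what the layer hands to the kernel: the projection without its bias
dt_no_bias = F.linear(dt_in, dt_proj.weight)

# what the kernel effectively computes with delta_bias=dt_proj.bias, delta_softplus=True
delta = F.softplus(dt_no_bias + dt_proj.bias)

# reference: the full Linear followed by softplus gives the same discretized time step
assert torch.allclose(delta, F.softplus(dt_proj(dt_in)), atol=1e-6)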
@@ -297,7 +269,6 @@ def __init__( )) self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32)) - # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) self.out_proj = RowParallelLinear( self.d_inner, self.d_model, @@ -305,7 +276,6 @@ def __init__( input_is_parallel=True, params_dtype=dtype, ) - print(f"-------- layer_idx {layer_idx}") self.activation = "silu" def forward( @@ -451,9 +421,6 @@ def __init__(self, yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs) else: self.attn = SambaAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn") - - self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) - self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( @@ -488,21 +455,11 @@ def forward( kv_cache, attn_metadata, ) - try: - hidden_states = residual + self.resid_attn_dropout(attn_outputs) - except Exception as e: - print('>>> exception: ', e) - print('>>>', hidden_states.shape) - print('>>>', self.layer_idx) - print('>>>', residual.shape) - print('>>>', self.resid_attn_dropout) - print('>>>', attn_outputs) - raise - + hidden_states = residual + attn_outputs residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype)) hidden_states = self.mlp(hidden_states) - hidden_states = residual + self.resid_mlp_dropout(hidden_states) + hidden_states = residual + hidden_states return hidden_states, ssm_output @@ -523,19 +480,14 @@ def __init__( prefix: str = "" ) -> None: super().__init__() - self.config = config - - self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.embed_dropout = nn.Dropout(config.embd_pdrop) + # Pipeline parallel is not supported since the second half of the layers share the kv cache. 
if get_pp_group().world_size != 1: raise ValueError("Pipeline Parallel not supported") @@ -591,10 +543,6 @@ def forward( hidden_states = hidden_states.index_select(0, selected_token_indices) ssm_output = ssm_output.index_select(0, selected_token_indices) - - # start_env = torch.cuda.Event(enable_timing=True) - # end_env = torch.cuda.Event(enable_timing=True) - # start_env.record() if layer.use_mamba: if i < self.config.num_hidden_layers // 2: mamba_cache = mamba_cache_params.at_layer_idx(mamba_state_idx) @@ -637,9 +585,6 @@ def forward( None, # mamba_cache_params ssm_output = ssm_output ) - # end_env.record() - # torch.cuda.synchronize() - # print('>>> layer', i, 'time', start_env.elapsed_time(end_env)) hidden_states = self.final_layernorm(hidden_states.to(dtype=self.final_layernorm.weight.dtype)) return hidden_states @@ -690,7 +635,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logits_as_input=False) - # self.sampler = Sampler() self.sampler = get_sampler() def forward( @@ -767,7 +711,6 @@ def load_weights( weights: Iterable[Tuple[str, torch.Tensor]], ): weights = {name: weight for name, weight in weights} - print(f"--------- num of keys: {len(weights.keys())}") adjusted_weights = {} for name, weight in weights.items(): if "A_log" in name: @@ -777,31 +720,13 @@ def load_weights( name = name.replace("inner_cross_attn.", "") adjusted_weights[name] = weight adjusted_weights["lm_head.weight"] = weights["model.embed_tokens.weight"] - for name, loaded_weight in adjusted_weights.items(): - print(name, loaded_weight.shape) - - params_dict = dict(self.named_parameters()) - - print(f"{adjusted_weights.keys() - params_dict.keys()} not in model") - print(f"{params_dict.keys() - adjusted_weights.keys()} not in weights") - loaded_params: Set[str] = set() - for name, param in self.named_parameters(): weight = adjusted_weights.get(name, None) if weight is not None and weight.shape != param.shape: - print(f"Shape mismatch: {name} {weight.shape} {param.shape}") + logger.warning(f"Shape mismatch: {name} {weight.shape} {param.shape}") loaded_params.add(name) missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, strict=False) - print(f"--------------- missing keys {missing_keys}") - print("--------------- unexpected keys ---------------") - for key in unexpected_keys: - print(key) - if not key.endswith("bias"): - print("------- not bias -------") - # assert missing_keys == ['embedding_bias', 'lm_head.weight',], f"Missing keys: {missing_keys}" - # assert unexpected_keys == ['lm_head.bias',], f"Unexpected keys: {unexpected_keys}" - # self.lm_head.weight.data.copy_(adjusted_weights['model.embed_tokens.weight']) - # self.embedding_bias.data.copy_(adjusted_weights['lm_head.bias']) - # self.embedding_bias = None + assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}" + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" return loaded_params \ No newline at end of file From 5a0041422982dd3ae5f8e845bd1b439dd868da09 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Thu, 12 Jun 2025 05:48:25 +0000 Subject: [PATCH 06/24] clean up Signed-off-by: Congcong Chen --- vllm/attention/backends/flash_attn.py | 48 ++++----------------------- 1 file changed, 6 insertions(+), 42 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 533408d5abd..bf8e373802f 100755 --- a/vllm/attention/backends/flash_attn.py +++ 
b/vllm/attention/backends/flash_attn.py @@ -76,7 +76,6 @@ def get_kv_cache_shape( if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) - # return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) @staticmethod def swap_blocks( @@ -186,9 +185,6 @@ class FlashAttentionMetadata(AttentionMetadata): cross_slot_mapping: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None - # Cross-layer shared attention block tables - cross_layer_shared_block_tables: Optional[torch.Tensor] = None - @property def is_all_encoder_attn_metadata_set(self): ''' @@ -233,9 +229,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: self.context_lens_tensor[:self.num_prefills]) block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) - cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[:self.num_prefills]) - + self._cached_prefill_metadata = FlashAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, @@ -254,7 +248,6 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, use_cuda_graph=False, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, @@ -282,8 +275,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: self.seq_lens_tensor[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) - cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[self.num_prefills:]) + self._cached_decode_metadata = FlashAttentionMetadata( num_prefills=0, num_prefill_tokens=0, @@ -307,7 +299,6 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, use_cuda_graph=self.use_cuda_graph, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, @@ -406,7 +397,6 @@ def prepare(self): self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] - self.cross_layer_shared_block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] self.multimodal_placeholder_maps: Dict[ str, @@ -467,17 +457,6 @@ def _add_seq_group( -curr_sliding_window_block:] self.block_tables.append(block_table) - cross_layer_shared_block_table = [] - if prefix_cache_hit: - cross_layer_shared_block_table = block_tables[seq_id] - elif block_tables is not None: - if curr_sliding_window_block == 0: - cross_layer_shared_block_table = block_tables[seq_id] - else: - cross_layer_shared_block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.cross_layer_shared_block_tables.append(cross_layer_shared_block_table) - # Compute slot mapping. 
is_profile_run = is_block_tables_empty(block_tables) start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, @@ -489,16 +468,13 @@ def _add_seq_group( def _get_graph_runner_block_tables( self, num_seqs: int, - block_tables: List[List[int]], - graph_block_tables) -> torch.Tensor: + block_tables: List[List[int]]) -> torch.Tensor: # The shape of graph_block_tables is # [max batch size, max context len // block size]. - # max_batch_size, max_blocks = self.runner.graph_block_tables.shape - max_batch_size, max_blocks = graph_block_tables.shape + max_batch_size, max_blocks = self.runner.graph_block_tables.shape assert max_batch_size >= num_seqs - # graph_block_tables = self.runner.graph_block_tables[:num_seqs] - graph_block_tables = graph_block_tables[:num_seqs] + graph_block_tables = self.runner.graph_block_tables[:num_seqs] for i, block_table in enumerate(block_tables): if block_table: num_blocks = len(block_table) @@ -553,14 +529,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size) - - self.cross_layer_shared_block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables, self.runner.graph_block_tables) - cross_layer_shared_block_tables = self._get_graph_runner_block_tables( - num_seqs, self.cross_layer_shared_block_tables, self.runner.cross_layer_shared_graph_block_tables) + num_seqs, self.block_tables) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -568,12 +539,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int, device=device, ) - cross_layer_shared_block_tables = make_tensor_with_pad( - self.cross_layer_shared_block_tables, - pad=0, - dtype=torch.int, - device=device, - ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert device is not None @@ -611,7 +576,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, use_cuda_graph=use_captured_graph, ) From e95162a8aba9e7d55c4bb7fe476a04146db0035f Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Thu, 12 Jun 2025 05:50:27 +0000 Subject: [PATCH 07/24] update Signed-off-by: Congcong Chen --- vllm/attention/backends/differential_flash_attn.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index a925cdcb3c7..0e8735b62e8 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -678,7 +678,6 @@ def __init__( self.attn_type = attn_type self.lambda_full = None - # self.subln = nn.RMSNorm(2 * self.head_size, eps=1e-5, elementwise_affine=True) self.subln = self.differential_flash_attention_config["subln"] def split_heads(self, x): @@ -705,9 +704,6 @@ def populate_kv_cache(self, if (kv_cache.numel() > 0): if (key is not None) and (value is not None): updated_slot_mapping = attn_metadata.slot_mapping - # previous_key_cache_sum = key_cache.sum() - # previous_value_cache_sum = value_cache.sum() - torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, @@ -718,12 +714,6 @@ def populate_kv_cache(self, layer._k_scale, layer._v_scale, ) - # assert 
key_cache.sum() - previous_key_cache_sum == key.sum(), "key_cache sum mismatch" - # assert value_cache.sum() - previous_value_cache_sum == value.sum(), "value_cache sum mismatch" - # if key_cache.sum() - previous_key_cache_sum != key.sum(): - # print("key_cache sum mismatch") - # if value_cache.sum() - previous_value_cache_sum != value.sum(): - # print("value_cache sum mismatch") def forward_generate_kv_cache( self, From 67502d8cb1194d261be7a4bcc3e684afc1d43a97 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Sat, 14 Jun 2025 08:41:47 +0000 Subject: [PATCH 08/24] renaming Signed-off-by: Congcong Chen --- .../models/{phi3samba.py => phi4sambay.py} | 22 +++++++++---------- vllm/model_executor/models/registry.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) rename vllm/model_executor/models/{phi3samba.py => phi4sambay.py} (98%) diff --git a/vllm/model_executor/models/phi3samba.py b/vllm/model_executor/models/phi4sambay.py similarity index 98% rename from vllm/model_executor/models/phi3samba.py rename to vllm/model_executor/models/phi4sambay.py index d508be81fb8..ec3189409df 100644 --- a/vllm/model_executor/models/phi3samba.py +++ b/vllm/model_executor/models/phi4sambay.py @@ -49,7 +49,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) -class SambaMLP(nn.Module): +class SambaYMLP(nn.Module): """Gated Linear Unit. Reference: @@ -78,7 +78,7 @@ def get_virtual_engine(): forward_context: ForwardContext = get_forward_context() return forward_context.virtual_engine -class SambaAttention(nn.Module): +class SambaYAttention(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None, @@ -391,7 +391,7 @@ def forward( return contextualized_states, yoco_key_values -class SambaDecoderLayer(nn.Module): +class SambaYDecoderLayer(nn.Module): def __init__(self, config, @@ -403,13 +403,13 @@ def __init__(self, self.config = config self.layer_idx = layer_idx - self.mlp = SambaMLP(config) + self.mlp = SambaYMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.yoco_mb = False self.yoco_kv = False self.yoco_cross = False - assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for samba + yoco' + assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for SambaY + yoco' if layer_idx >= config.num_hidden_layers//2: self.yoco_mb = True self.yoco_kv = (layer_idx >= (config.num_hidden_layers//2 +1)) @@ -420,7 +420,7 @@ def __init__(self, self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx, yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs) else: - self.attn = SambaAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn") + self.attn = SambaYAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn") self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( @@ -469,7 +469,7 @@ def get_kv_cache(layer_name): kv_cache = self.kv_cache[forward_context.virtual_engine] return kv_cache -class SambaModel(nn.Module): +class SambaYModel(nn.Module): def __init__( self, @@ -494,7 +494,7 @@ def __init__( self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: SambaDecoderLayer(config, + lambda prefix: SambaYDecoderLayer(config, int(prefix.split('.')[-1]), cache_config, prefix=prefix), @@ -590,7 +590,7 @@ def forward( 
return hidden_states -class SambaForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): +class Phi4MiniFlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -603,13 +603,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Prefix caching is not supported since there are mamba layers in this # mode. assert not cache_config.enable_prefix_caching, \ - "Samba currently does not support prefix caching" + "SambaY currently does not support prefix caching" super().__init__() self.config = config self.model_config = vllm_config.model_config self.scheduler_config = scheduler_config - self.model = SambaModel( + self.model = SambaYModel( config, cache_config=cache_config, prefix=maybe_prefix(prefix, "model") diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b70e10875b2..6a1dead209e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -110,7 +110,7 @@ "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "SambaForCausalLM": ("phi3samba", "SambaForCausalLM"), + "Phi4MiniFlashForCausalLM": ("phi4sambay", "Phi4MiniFlashForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), From 5ad7dbb82f73aa46d88de8fc1c0036cf58446e6e Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Mon, 16 Jun 2025 23:38:36 +0000 Subject: [PATCH 09/24] renames Signed-off-by: Congcong Chen --- .../models/{phi4sambay.py => phi4flash.py} | 18 ++++-------------- vllm/model_executor/models/registry.py | 2 +- 2 files changed, 5 insertions(+), 15 deletions(-) rename vllm/model_executor/models/{phi4sambay.py => phi4flash.py} (98%) diff --git a/vllm/model_executor/models/phi4sambay.py b/vllm/model_executor/models/phi4flash.py similarity index 98% rename from vllm/model_executor/models/phi4sambay.py rename to vllm/model_executor/models/phi4flash.py index ec3189409df..ccc610671ae 100644 --- a/vllm/model_executor/models/phi4sambay.py +++ b/vllm/model_executor/models/phi4flash.py @@ -45,7 +45,6 @@ class SwiGLUActivation(nn.Module): def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - # print(f"x1 shape: {x1.shape}, x2 shape: {x2.shape}") return x1 * nn.functional.silu(x2) @@ -175,7 +174,7 @@ def forward( return self.out_proj(attn_output) -class Phi3Mamba(nn.Module): +class Phi4Mamba(nn.Module): def __init__( self, d_model, @@ -250,15 +249,6 @@ def __init__( params_dtype=dtype, ) - # # S4D real initialization - # A = repeat( - # torch.arange(1, self.d_state + 1, dtype=torch.float32), - # "n -> d n", - # d=self.d_inner, - # ).contiguous() - # A_log = torch.log(A) # Keep A_log in fp32 - # self.A_log = nn.Parameter(A_log) - # # D "skip" parameter # self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 self.A = nn.Parameter( @@ -417,7 +407,7 @@ def __init__(self, self.use_mamba = config.mb_per_layer > 0 and layer_idx % config.mb_per_layer == 0 if self.use_mamba: factory_kwargs = {"dtype": None} - self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx, + self.attn = Phi4Mamba(config.hidden_size, layer_idx=layer_idx, yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs) else: self.attn = SambaYAttention(config, layer_idx=layer_idx, 
yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn") @@ -590,7 +580,7 @@ def forward( return hidden_states -class Phi4MiniFlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): +class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -603,7 +593,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Prefix caching is not supported since there are mamba layers in this # mode. assert not cache_config.enable_prefix_caching, \ - "SambaY currently does not support prefix caching" + "Phi4flash currently does not support prefix caching" super().__init__() self.config = config diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 6a1dead209e..5f9b145b661 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -110,7 +110,7 @@ "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "Phi4MiniFlashForCausalLM": ("phi4sambay", "Phi4MiniFlashForCausalLM"), + "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), From 9a6607b2d7be2212802b396ec73feb084889e72d Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Tue, 17 Jun 2025 07:22:20 +0000 Subject: [PATCH 10/24] fix sliding window check Signed-off-by: Congcong Chen --- vllm/model_executor/models/phi4flash.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index ccc610671ae..3ed7d2888f0 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -112,8 +112,10 @@ def __init__(self, # disable sliding window for the second half of the model sliding_window = config.interleaved_sliding_window[layer_idx] - if layer_idx >= config.num_hidden_layers // 2 or layer_idx % 2 == 0: - assert sliding_window == None, "sliding_window is not none" + if layer_idx >= config.num_hidden_layers // 2: + assert sliding_window is None, "sliding_window must be none for the second decoder" + else: + assert sliding_window is not None, "sliding_window must be set for the first decoder" assert self.num_heads % 2 == 0, 'num_heads should be even' assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' @@ -397,12 +399,10 @@ def __init__(self, self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.yoco_mb = False - self.yoco_kv = False self.yoco_cross = False assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for SambaY + yoco' if layer_idx >= config.num_hidden_layers//2: self.yoco_mb = True - self.yoco_kv = (layer_idx >= (config.num_hidden_layers//2 +1)) self.yoco_cross = (layer_idx >= (config.num_hidden_layers//2 +2)) self.use_mamba = config.mb_per_layer > 0 and layer_idx % config.mb_per_layer == 0 if self.use_mamba: From 9ba8ea51989af2eb9a81d1a578ba920a39e83fc0 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Tue, 24 Jun 2025 08:27:16 +0000 Subject: [PATCH 11/24] remove warning msg Signed-off-by: Congcong Chen --- vllm/config.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/config.py 
b/vllm/config.py index 3eda6b85881..b1f7f9e57a7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -989,14 +989,6 @@ def _verify_cuda_graph(self) -> None: "to eager mode.", self.hf_config.model_type) self.enforce_eager = True - RECOMMENDED_MODEL_SUPPORTS_CUDA_GRAPH = ['phi3samba'] - if (self.hf_config.model_type in RECOMMENDED_MODEL_SUPPORTS_CUDA_GRAPH - and not self.enforce_eager and self.max_seq_len_to_capture < self.max_model_len): - logger.warning( - "%s model performs best with the CUDA graph explicitly enabled. Set `--max-seq-len-to-capture <#>` " - "when starting vLLM.", self.hf_config.model_type) - - def _verify_bnb_config(self) -> None: """ The current version of bitsandbytes (0.46.1) with 8-bit models does not From 7d1cf256ccdaa9f8cd4dabeef4b2e9891f9ca397 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Wed, 9 Jul 2025 07:31:05 +0000 Subject: [PATCH 12/24] address comments Signed-off-by: Congcong Chen --- vllm/attention/backends/abstract.py | 2 - .../backends/differential_flash_attn.py | 7 +-- vllm/attention/layer.py | 6 +- vllm/model_executor/models/phi4flash.py | 55 +++++-------------- vllm/utils/__init__.py | 8 ++- vllm/worker/worker.py | 27 ++++++++- 6 files changed, 53 insertions(+), 52 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 428bddb0e12..05c098a58a0 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -32,8 +32,6 @@ class AttentionType: ENCODER_ONLY = "encoder_only" # Attention between dec. Q and enc. K/V for encoder-decoder ENCODER_DECODER = "encoder_decoder" - # Attention layer that reuse kv cache - DECODER_DECODER = "decoder_decoder" class AttentionBackend(ABC): diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index 0e8735b62e8..ce398934549 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" from collections import defaultdict from dataclasses import dataclass from itertools import accumulate @@ -55,7 +54,6 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - # return (2, num_blocks, block_size, num_kv_heads, head_size) return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) @staticmethod @@ -634,8 +632,9 @@ def __init__( self.differential_flash_attention_config = differential_flash_attention_config self.used_shared_kv_cache = self.differential_flash_attention_config.get( "used_shared_kv_cache", False) - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + # if kv_sharing_target_layer_name is not None: + # raise NotImplementedError("KV sharing is not supported in V0.") + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name if blocksparse_params is not None: raise ValueError( "FlashAttention does not support block-sparse attention.") diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 3d5746837be..751cfe84299 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -160,9 +160,9 @@ def __init__( self.attn_type = attn_type if kv_sharing_target_layer_name is not None: - if not envs.VLLM_USE_V1: - raise NotImplementedError( - "Cross-layer KV sharing is not supported in V0.") + # 
if not envs.VLLM_USE_V1: + # raise NotImplementedError( + # "Cross-layer KV sharing is not supported in V0.") validate_kv_sharing_target( prefix, diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index 3ed7d2888f0..015a4621f1c 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -138,7 +138,13 @@ def __init__(self, "subln": self.subln, } } - + + if yoco_cross: + kv_shared_layer_index = config.num_hidden_layers//2 + 1 + kv_sharing_target_layer_name = f"model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501 + else: + kv_sharing_target_layer_name = None + self.attn = Attention( self.num_heads, self.head_dim, @@ -147,7 +153,8 @@ def __init__(self, cache_config=cache_config, per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", - attn_type=AttentionType.DECODER_DECODER if self.yoco_cross else AttentionType.DECODER, + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, **params ) @@ -157,9 +164,6 @@ def lambda_init_fn(self, depth): def forward( self, hidden_states: torch.Tensor, - positions: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, ): if not self.yoco_cross: # need to generate kv-cache @@ -168,9 +172,6 @@ def forward( attn_output = self.attn(q, k, v) else: # re-use the kv cache, full attention q = self.Wqkv(hidden_states) - virtual_engine = get_virtual_engine() - if self.attn.kv_cache[virtual_engine].numel() == 0: - self.attn.kv_cache[virtual_engine] = kv_cache attn_output = self.attn(q, None, None) attn_output = attn_output.view(-1, self.num_heads * self.head_dim) return self.out_proj(attn_output) @@ -417,15 +418,14 @@ def forward( self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, ssm_output: Optional[torch.LongTensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if self.use_mamba: - assert kv_cache is None and mamba_cache_params is not None + assert mamba_cache_params is not None else: - assert kv_cache is not None and mamba_cache_params is None + assert mamba_cache_params is None residual = hidden_states hidden_states = self.input_layernorm(hidden_states.to(dtype=self.input_layernorm.weight.dtype)) @@ -441,9 +441,6 @@ def forward( else: attn_outputs = self.attn( hidden_states, - positions, - kv_cache, - attn_metadata, ) hidden_states = residual + attn_outputs residual = hidden_states @@ -452,12 +449,7 @@ def forward( hidden_states = residual + hidden_states return hidden_states, ssm_output - -def get_kv_cache(layer_name): - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] - return kv_cache + class SambaYModel(nn.Module): @@ -513,16 +505,16 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - kv_cache_idx = 0 mamba_state_idx = 0 ssm_output = None for i in range(self.start_layer, self.end_layer): layer = self.layers[i] if i == self.config.num_hidden_layers // 2 + 2: # profile run + kv_cache_idx = self.config.num_hidden_layers//2 + 1 cache_layer = self.layers[kv_cache_idx] - kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name) - if kv_cache.numel() == 0: + kv_cache = cache_layer.attn.attn.kv_cache + if kv_cache[0].numel() == 0: break # Starting from this layer, we do not need to cuculate the kv cache since 
we reuse @@ -546,31 +538,14 @@ def forward( hidden_states, ssm_output = layer( hidden_states, positions, - None, # kv_cache attn_metadata, mamba_cache, ssm_output = ssm_output ) else: - if i < self.config.num_hidden_layers // 2: - # sliding window attention - cache_layer = self.layers[i] - kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name) - kv_cache_idx = i - elif not layer.yoco_cross: - # full attention that generates kv cache - cache_layer = self.layers[i] - kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name) - kv_cache_idx = i - else: - # full attention that reuses kv cache - cache_layer = self.layers[kv_cache_idx] - kv_cache = get_kv_cache(cache_layer.attn.attn.layer_name) - hidden_states, ssm_output = layer( hidden_states, positions, - kv_cache, attn_metadata, None, # mamba_cache_params ssm_output = ssm_output diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 48346c7d6e5..29eea8bedf2 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2890,6 +2890,7 @@ def get_mp_context(): def bind_kv_cache( ctx: dict[str, Any], kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] + shared_kv_cache_layers: dict[str, str], ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -2913,11 +2914,16 @@ def bind_kv_cache( extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)) for layer_name in layer_need_kv_cache: + target_layer_name = shared_kv_cache_layers[layer_name] if layer_name \ + in shared_kv_cache_layers else layer_name kv_cache_idx = layer_index_sorted.index( - extract_layer_index(layer_name)) + extract_layer_index(target_layer_name)) forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) + for ve, ve_kv_cache in enumerate(kv_cache): + assert kv_cache_idx < len(ve_kv_cache), \ + "v0 doesn't support interleaving kv sharing, use v1 instead" forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 21e684a3fb5..065da227d48 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -9,7 +9,8 @@ import torch.distributed import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.attention.layer import Attention +from vllm.config import (VllmConfig, get_layers_from_vllm_config) from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, @@ -26,6 +27,7 @@ SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache, memory_profiling) +from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner @@ -345,8 +347,29 @@ def _init_cache_engine(self): self.cache_engine[ve].gpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) ] + + # Layer pairings for cross-layer KV sharing. + # If an Attention layer `layer_name` is in the keys of this dict, it + # means this layer will perform attention using the keys and values + # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
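# --- Annotation (not part of the patch): the mapping built in this hunk is later
# consumed by bind_kv_cache, which rebinds a sharing layer to the cache tensor
# allocated for its target layer. A minimal, self-contained sketch of that lookup;
# the layer names and indices below are hypothetical, not taken from a real checkpoint.
shared_kv_cache_layers = {
    "model.layers.18.self_attn.attn": "model.layers.17.self_attn.attn",
    "model.layers.19.self_attn.attn": "model.layers.17.self_attn.attn",
}

def resolve_cache_layer(layer_name: str) -> str:
    # A sharing layer reads the KV cache allocated for its target layer;
    # every other layer keeps the cache allocated under its own name.
    return shared_kv_cache_layers.get(layer_name, layer_name)

assert resolve_cache_layer("model.layers.19.self_attn.attn") == \
    "model.layers.17.self_attn.attn"
assert resolve_cache_layer("model.layers.3.self_attn.attn") == \
    "model.layers.3.self_attn.attn"
# --- End annotation.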
+ shared_kv_cache_layers: dict[str, str] = {} + + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + + for layer_name, attn_module in attn_layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + shared_kv_cache_layers[layer_name] = kv_tgt_layer + bind_kv_cache(self.compilation_config.static_forward_context, - self.gpu_cache) + self.gpu_cache, shared_kv_cache_layers) def _warm_up_model(self) -> None: # warm up sizes that are not in cudagraph capture sizes, From 96a84f8f8c3e77a3f808ee4f0d0ad54f275af538 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Wed, 9 Jul 2025 17:22:40 +0000 Subject: [PATCH 13/24] revert unrelated changes Signed-off-by: Congcong Chen --- benchmarks/benchmark_prefix_caching.py | 37 +++++++++----------------- 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 0b7d23a9bd7..b5e2613de1c 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -45,24 +45,13 @@ except ImportError: from backend_request_func import get_tokenizer -# PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 -# PROMPT = "Question: Who is bill gates?\n\nAnswer:" -# content = """China officially the People's Republic of China (PRC), is a country in East Asia. With a population exceeding 1.4 billion, it is the second-most populous country after India, representing 17.4% of the world population. 
China spans the equivalent of five time zones and borders fourteen countries by land[k] across an area of nearly 9.6 million square kilometers (3,700,000 sq mi), making it the third-largest country by total land area.[l] The country is divided into 33 province-level divisions: 22 provinces,[m] five autonomous regions, four municipalities, and two semi-autonomous special administrative regions. Beijing is the country's capital, while Shanghai is its most populous city by urban area and largest financial center. China is considered one of the cradles of civilization: the first human inhabitants in the region arrived during the Paleolithic. By the late 2nd millennium BCE, the earliest dynastic states had emerged in the Yellow River basin. The 8th–3rd centuries BCE saw a breakdown in the authority of the Zhou dynasty, accompanied by the emergence of administrative and military techniques, literature, philosophy, and historiography. In 221 BCE, China was unified under an emperor, ushering in more than two millennia of imperial dynasties including the Qin, Han, Tang, Yuan, Ming, and Qing. With the invention of gunpowder and paper, the establishment of the Silk Road, and the building of the Great Wall, Chinese culture flourished and has heavily influenced both its neighbors and lands further afield. However, China began to cede parts of the country in the late 19th century to various European powers by a series of unequal treaties. After decades of Qing China on the decline, the 1911 Revolution overthrew the Qing dynasty and the monarchy and the Republic of China (ROC) was established the following year. The country under the nascent Beiyang government was unstable and ultimately fragmented during the Warlord Era, which was ended upon the Northern Expedition conducted by the Kuomintang (KMT) to reunify the country. The Chinese Civil War began in 1927, when KMT forces purged members of the rival Chinese Communist Party (CCP), who proceeded to engage in sporadic fighting against the KMT-led Nationalist government. Following the country's invasion by the Empire of Japan in 1937, the CCP and KMT formed the Second United Front to fight the Japanese. The Second Sino-Japanese War eventually ended in a Chinese victory; however, the CCP and the KMT resumed their civil war as soon as the war ended. In 1949, the resurgent Communists established control over most of the country, proclaiming the People's Republic of China and forcing the Nationalist government to retreat to the island of Taiwan. The country was split, with both sides claiming to be the sole legitimate government of China. Following the implementation of land reforms, further attempts by the PRC to realize communism failed: the Great Leap Forward was largely responsible for the Great Chinese Famine that ended with millions of Chinese people having died, and the subsequent Cultural Revolution was a period of social turmoil and persecution characterized by Maoist populism. Following the Sino-Soviet split, the Shanghai Communiqué in 1972 would precipitate the normalization of relations with the United States. Economic reforms that began in 1978 moved the country away from a socialist planned economy towards an increasingly capitalist market economy, spurring significant economic growth. 
A movement for increased democracy and liberalization stalled after the Tiananmen Square protests and massacre in 1989.""" -# PROMPT = f'{content}' -PROMPT = "Question: Tell me about Seatttle?\n\nAnswer:" +PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 + def test_prefix(llm=None, sampling_params=None, prompts=None): start_time = time.time() - # llm.generate(prompts, sampling_params=sampling_params) - outputs = llm.generate(prompts, sampling_params=sampling_params) - # Print the outputs. 
- generated_texts = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + llm.generate(prompts, sampling_params=sampling_params) end_time = time.time() print(f"cost time {end_time - start_time}") @@ -147,16 +136,16 @@ def sample_requests_from_random( min_len, max_len = input_length_range for i in range(num_requests): - # unique_part_token_ids = sample_tokens( - # tokenizer, - # random.randint(min_len - prefix_len, max_len - prefix_len)) - # prompt_token_ids = prefix_token_ids + unique_part_token_ids - # prompt = tokenizer.decode(prompt_token_ids) - # prompt_len = len(prompt_token_ids) - # assert (min_len <= prompt_len <= max_len - # ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" - - requests.append(Request(PROMPT, 10, fixed_output_len)) + unique_part_token_ids = sample_tokens( + tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len) + ) + prompt_token_ids = prefix_token_ids + unique_part_token_ids + prompt = tokenizer.decode(prompt_token_ids) + prompt_len = len(prompt_token_ids) + assert min_len <= prompt_len <= max_len, ( + f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + ) + requests.append(Request(prompt, prompt_len, fixed_output_len)) return requests From 35564f67cd25d6c2fa96250a241258090ce7de09 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Wed, 9 Jul 2025 18:01:33 +0000 Subject: [PATCH 14/24] minor Signed-off-by: Congcong Chen --- vllm/attention/backends/differential_flash_attn.py | 5 +++-- vllm/model_executor/models/phi4flash.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index ce398934549..c60ad5ae06c 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -54,6 +54,7 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") + assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2" return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) @staticmethod @@ -872,7 +873,7 @@ def forward( k1, k2 = self.split_heads(k) v1, v2 = self.split_heads(v) - # kv_cache shape is (2, 2, num_blocks, block_size * num_kv_heads // 2 * head_size) + # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # Split by half along the first dimension. 
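# --- Annotation (not part of the patch): a reading aid for the corrected shape
# comment above. The sizes are arbitrary and the dense tensor is only a stand-in
# for the paged cache vLLM actually allocates.
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64
kv_cache = torch.zeros(2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)

# Leading axis: the two head groups of differential attention.
kv_cache1, kv_cache2 = kv_cache[0], kv_cache[1]
# Next axis of each group: key cache vs. value cache.
key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
assert key_cache1.shape == (num_blocks, block_size, num_kv_heads // 2, head_size)
# --- End annotation.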
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" @@ -909,7 +910,7 @@ def forward( else: # re-use the kv cache, full attention q = q.view(-1, self.num_heads, self.head_size) q1, q2 = self.split_heads(q) - # kv_cache shape is (2, num_blocks, block_size * num_kv_heads * head_size) + # kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index 015a4621f1c..300d14cbfcb 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -517,9 +517,9 @@ def forward( if kv_cache[0].numel() == 0: break - # Starting from this layer, we do not need to cuculate the kv cache since we reuse - # the kv cache from last layer. If in prefill phase, we can prune truncate - # hidden state to save computation cost. + # Starting from this layer, we do not need to calculate the kv cache since we reuse + # the kv cache from last layer. If in prefill phase, we can prune> truncate + # the hidden state to save computation cost. if attn_metadata.prefill_metadata: selected_token_indices = torch.cumsum(attn_metadata.seq_lens_tensor, dim=0) - 1 hidden_states = hidden_states.index_select(0, selected_token_indices) From b98b72cf44fef2d68b9739016623df34ae4c2083 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Wed, 9 Jul 2025 18:52:27 +0000 Subject: [PATCH 15/24] run lint Signed-off-by: Congcong Chen --- .../backends/differential_flash_attn.py | 218 +++++---- vllm/model_executor/models/phi4flash.py | 440 ++++++++++-------- vllm/utils/__init__.py | 11 +- vllm/worker/model_runner.py | 2 +- vllm/worker/worker.py | 5 +- 5 files changed, 372 insertions(+), 304 deletions(-) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index c60ad5ae06c..0de3e8c43e1 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -6,23 +6,23 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch -import torch.nn as nn +from einops import rearrange from vllm import _custom_ops as ops # yapf conflicts with isort for this block # yapf: disable -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, +from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, AttentionMetadata, AttentionMetadataBuilder, AttentionType, is_quantized_kv_cache) +from vllm.attention.backends.flash_attn import FlashAttentionBackend # yapf: enable -from vllm.attention.backends.utils import ( - PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, - compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, - get_seq_len_block_table_args, is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, + is_block_tables_empty) from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, get_flash_attn_version) from vllm.logger import init_logger @@ -30,11 +30,6 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.vllm_flash_attn import 
(flash_attn_varlen_func, flash_attn_with_kvcache) -from vllm.attention.backends.flash_attn import (FlashAttentionBackend, - FlashAttentionImpl, - FlashAttentionMetadata, - FlashAttentionMetadataBuilder) -from einops import rearrange if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -45,6 +40,7 @@ class DifferentialFlashAttentionBackend(FlashAttentionBackend): accept_output_buffer = False + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -72,8 +68,8 @@ def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]: @staticmethod def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]: return DifferentialFlashAttentionMetadataBuilder - - + + @dataclass class DifferentialFlashAttentionMetadata(AttentionMetadata): """Metadata for FlashAttentionBackend. @@ -136,8 +132,10 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata): # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] = None - _cached_prefill_metadata: Optional["DifferentialFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["DifferentialFlashAttentionMetadata"] = None + _cached_prefill_metadata: Optional[ + "DifferentialFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "DifferentialFlashAttentionMetadata"] = None # Begin encoder attn & enc/dec cross-attn fields... @@ -178,7 +176,8 @@ def is_all_cross_attn_metadata_set(self): return is_all_cross_attn_metadata_set(self) @property - def prefill_metadata(self) -> Optional["DifferentialFlashAttentionMetadata"]: + def prefill_metadata( + self) -> Optional["DifferentialFlashAttentionMetadata"]: if self.num_prefills == 0: return None @@ -205,9 +204,10 @@ def prefill_metadata(self) -> Optional["DifferentialFlashAttentionMetadata"]: self.context_lens_tensor[:self.num_prefills]) block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) - cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[:self.num_prefills]) - + cross_layer_shared_block_tables = ( + None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[:self.num_prefills]) + self._cached_prefill_metadata = DifferentialFlashAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, @@ -238,7 +238,8 @@ def prefill_metadata(self) -> Optional["DifferentialFlashAttentionMetadata"]: return self._cached_prefill_metadata @property - def decode_metadata(self) -> Optional["DifferentialFlashAttentionMetadata"]: + def decode_metadata( + self) -> Optional["DifferentialFlashAttentionMetadata"]: if self.num_decode_tokens == 0: return None @@ -254,8 +255,9 @@ def decode_metadata(self) -> Optional["DifferentialFlashAttentionMetadata"]: self.seq_lens_tensor[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) - cross_layer_shared_block_tables = (None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[self.num_prefills:]) + cross_layer_shared_block_tables = ( + None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[self.num_prefills:]) self._cached_decode_metadata = DifferentialFlashAttentionMetadata( num_prefills=0, num_prefill_tokens=0, @@ -448,7 +450,8 @@ def _add_seq_group( else: cross_layer_shared_block_table = block_tables[seq_id][ -curr_sliding_window_block:] - 
self.cross_layer_shared_block_tables.append(cross_layer_shared_block_table) + self.cross_layer_shared_block_tables.append( + cross_layer_shared_block_table) # Compute slot mapping. is_profile_run = is_block_tables_empty(block_tables) @@ -459,10 +462,9 @@ def _add_seq_group( seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]], - graph_block_tables) -> torch.Tensor: + def _get_graph_runner_block_tables(self, num_seqs: int, + block_tables: List[List[int]], + graph_block_tables) -> torch.Tensor: # The shape of graph_block_tables is # [max batch size, max context len // block size]. # max_batch_size, max_blocks = self.runner.graph_block_tables.shape @@ -526,13 +528,16 @@ def build(self, seq_lens: List[int], query_lens: List[int], self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size) - self.cross_layer_shared_block_tables.extend([] * cuda_graph_pad_size) - + self.cross_layer_shared_block_tables.extend([] * + cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens block_tables = self._get_graph_runner_block_tables( num_seqs, self.block_tables, self.runner.graph_block_tables) - cross_layer_shared_block_tables = self._get_graph_runner_block_tables( - num_seqs, self.cross_layer_shared_block_tables, self.runner.cross_layer_shared_graph_block_tables) + cross_layer_shared_block_tables = \ + self._get_graph_runner_block_tables( + num_seqs, self.cross_layer_shared_block_tables, + self.runner.cross_layer_shared_graph_block_tables) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -630,9 +635,11 @@ def __init__( use_irope: bool = False, differential_flash_attention_config: Optional[Dict[str, Any]] = None, ) -> None: - self.differential_flash_attention_config = differential_flash_attention_config - self.used_shared_kv_cache = self.differential_flash_attention_config.get( - "used_shared_kv_cache", False) + self.differential_flash_attention_config = \ + differential_flash_attention_config + self.used_shared_kv_cache = \ + self.differential_flash_attention_config.get( + "used_shared_kv_cache", False) # if kv_sharing_target_layer_name is not None: # raise NotImplementedError("KV sharing is not supported in V0.") self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -686,44 +693,36 @@ def split_heads(self, x): x1 = x[..., 0, :] x2 = x[..., 1, :] return x1.contiguous(), x2.contiguous() - + def split_kv_cache(self, x): # split by num_heads, the stripe pattern is friendly to tensor parallel. 
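# --- Annotation (not part of the patch): the "stripe pattern" mentioned in the
# comment above, shown standalone. This mirrors what split_heads appears to do via
# einops; shapes are made up and this is an illustration, not the backend code.
import torch
from einops import rearrange

num_tokens, num_heads, head_size = 3, 8, 4
x = torch.arange(num_tokens * num_heads * head_size, dtype=torch.float32)
x = x.view(num_tokens, num_heads, head_size)

# Group adjacent head pairs, then peel the two pair members apart.
x_pairs = rearrange(x, "... (H two) D -> ... H two D", two=2)
x1, x2 = x_pairs[..., 0, :], x_pairs[..., 1, :]

# x1 holds heads 0, 2, 4, 6 and x2 holds heads 1, 3, 5, 7, so any contiguous
# tensor-parallel shard of heads contains an equal number from each half.
assert torch.equal(x1, x[:, 0::2]) and torch.equal(x2, x[:, 1::2])
# --- End annotation.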
if x.numel() == 0: return torch.empty(0), torch.empty(0) - + x1, x2 = x[0], x[1] return x1, x2 - def populate_kv_cache(self, - layer: AttentionLayer, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, + def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor, + value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: DifferentialFlashAttentionMetadata): - if (kv_cache.numel() > 0): - if (key is not None) and (value is not None): - updated_slot_mapping = attn_metadata.slot_mapping - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + if kv_cache.numel() > 0 and key is not None and value is not None: + updated_slot_mapping = attn_metadata.slot_mapping + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) def forward_generate_kv_cache( - self, - query: torch.Tensor, - key: Optional[torch.Tensor], - value: Optional[torch.Tensor], - k_cache: torch.Tensor, - v_cache: torch.Tensor, - attn_metadata: AttentionMetadata - ) -> torch.Tensor: + self, query: torch.Tensor, key: Optional[torch.Tensor], + value: Optional[torch.Tensor], k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: head_size = self.head_size num_heads = self.num_heads // 2 @@ -739,9 +738,11 @@ def forward_generate_kv_cache( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" - + assert key.shape[ + 0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" + assert value.shape[ + 0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" + output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] @@ -752,7 +753,8 @@ def forward_generate_kv_cache( value = value[:num_prefill_tokens] assert query.shape[0] == num_prefill_tokens, "query shape mismatch" - assert decode_query.shape[0] == num_decode_tokens, "decode query shape mismatch" + assert decode_query.shape[ + 0] == num_decode_tokens, "decode query shape mismatch" if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. @@ -772,7 +774,8 @@ def forward_generate_kv_cache( alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, ) - assert prefill_output.shape == output[:num_prefill_tokens].shape + assert prefill_output.shape == output[: + num_prefill_tokens].shape output[:num_prefill_tokens] = prefill_output else: raise Exception("prefix caching not supported") @@ -793,14 +796,13 @@ def forward_generate_kv_cache( softcap=self.logits_soft_cap, ).squeeze(1) except Exception as e: - logger.error( - f"Error in PagedAttention.forward_decode: {str(e)}") + logger.error("Error in PagedAttention.forward_decode: %s", + str(e)) raise e # Reshape the output tensor. 
return output.view(-1, num_heads, head_size) - - + def forward_with_kv_cache_only( self, query: torch.Tensor, @@ -808,8 +810,8 @@ def forward_with_kv_cache_only( v_cache: torch.Tensor, attn_metadata: AttentionMetadata, ): - if not attn_metadata.decode_metadata: - block_tables_arg = attn_metadata.cross_layer_shared_block_tables + if not attn_metadata.decode_metadata: + block_tables_arg = attn_metadata.cross_layer_shared_block_tables else: block_tables_arg = attn_metadata.block_tables @@ -854,17 +856,19 @@ def forward( We use torch's .expand() to avoid duplicating values """ if self.lambda_full is None: - self.lambda_init = self.differential_flash_attention_config["lambda_init"] + self.lambda_init = self.differential_flash_attention_config[ + "lambda_init"] lambda_q1 = self.differential_flash_attention_config["lambda_q1"] lambda_k1 = self.differential_flash_attention_config["lambda_k1"] lambda_q2 = self.differential_flash_attention_config["lambda_q2"] lambda_k2 = self.differential_flash_attention_config["lambda_k2"] - lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q) - lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q) + lambda_1 = torch.exp( + torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q) + lambda_2 = torch.exp( + torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q) self.lambda_full = lambda_1 - lambda_2 + self.lambda_init - - if not self.used_shared_kv_cache: # need to generate kv-cache + if not self.used_shared_kv_cache: # need to generate kv-cache q = q.view(-1, self.num_heads, self.head_size) k = k.view(-1, self.num_kv_heads, self.head_size) v = v.view(-1, self.num_kv_heads, self.head_size) @@ -873,29 +877,37 @@ def forward( k1, k2 = self.split_heads(k) v1, v2 = self.split_heads(v) - # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) + # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501 # Split by half along the first dimension. 
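# --- Annotation (not part of the patch): a toy rendering of the combination done
# in the rest of this forward pass, following the differential attention paper
# (https://arxiv.org/pdf/2410.05258). Every tensor and scalar below is a placeholder;
# in the real layer attn1/attn2 come from the four attention calls above and
# lambda_full from the learned lambda parameters.
import math

import torch
import torch.nn as nn

num_tokens, half_heads, head_size = 5, 4, 32
attn1 = torch.randn(num_tokens, half_heads, 2 * head_size)  # stands in for cat([attn11, attn12])
attn2 = torch.randn(num_tokens, half_heads, 2 * head_size)  # stands in for cat([attn21, attn22])

lambda_init = 0.8 - 0.6 * math.exp(-0.3 * 0)  # depth 0, same formula as lambda_init_fn
lambda_full = 0.55                            # placeholder for lambda_1 - lambda_2 + lambda_init
subln = nn.RMSNorm(2 * head_size, eps=1e-5, elementwise_affine=True)

# The second attention map acts as a learned noise estimate: subtract it,
# normalize per head group, and rescale by (1 - lambda_init).
attn = subln(attn1 - lambda_full * attn2) * (1 - lambda_init)
assert attn.shape == (num_tokens, half_heads, 2 * head_size)
# --- End annotation.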
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" - + if kv_cache1.numel() != 0: self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata) self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata) - + key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) else: key_cache1, value_cache1 = torch.empty(0), torch.empty(0) key_cache2, value_cache2 = torch.empty(0), torch.empty(0) - attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1, value_cache1, attn_metadata) - attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1, value_cache2, attn_metadata) + attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1, + value_cache1, + attn_metadata) + attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1, + value_cache2, + attn_metadata) attn11 = attn11.view(q1.shape) attn12 = attn12.view(q1.shape) attn1 = torch.cat([attn11, attn12], dim=-1) - attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2, value_cache1, attn_metadata) - attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2, value_cache2, attn_metadata) + attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2, + value_cache1, + attn_metadata) + attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2, + value_cache2, + attn_metadata) attn21 = attn21.view(q2.shape) attn22 = attn22.view(q2.shape) attn2 = torch.cat([attn21, attn22], dim=-1) @@ -905,24 +917,34 @@ def forward( attn = self.subln(attn) attn = attn * (1 - self.lambda_init) # reshape back to 2 * num_head - attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2) + attn_output = rearrange(attn, + "... H (two D) -> ... (H two) D", + two=2) - else: # re-use the kv cache, full attention + else: # re-use the kv cache, full attention q = q.view(-1, self.num_heads, self.head_size) q1, q2 = self.split_heads(q) - # kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) + # kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501 kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] - - attn11 = self.forward_with_kv_cache_only(q1, key_cache1, value_cache1, attn_metadata) - attn12 = self.forward_with_kv_cache_only(q1, key_cache1, value_cache2, attn_metadata) + + attn11 = self.forward_with_kv_cache_only(q1, key_cache1, + value_cache1, + attn_metadata) + attn12 = self.forward_with_kv_cache_only(q1, key_cache1, + value_cache2, + attn_metadata) attn11 = attn11.view(q1.shape) attn12 = attn12.view(q1.shape) attn1 = torch.cat([attn11, attn12], dim=-1) - attn21 = self.forward_with_kv_cache_only(q2, key_cache2, value_cache1, attn_metadata) - attn22 = self.forward_with_kv_cache_only(q2, key_cache2, value_cache2, attn_metadata) + attn21 = self.forward_with_kv_cache_only(q2, key_cache2, + value_cache1, + attn_metadata) + attn22 = self.forward_with_kv_cache_only(q2, key_cache2, + value_cache2, + attn_metadata) attn21 = attn21.view(q2.shape) attn22 = attn22.view(q2.shape) attn2 = torch.cat([attn21, attn22], dim=-1) @@ -931,6 +953,8 @@ def forward( attn = self.subln(attn) attn = attn * (1 - self.lambda_init) # reshape back to 2 * num_head - attn_output = rearrange(attn, "... H (two D) -> ... (H two) D", two=2) + attn_output = rearrange(attn, + "... H (two D) -> ... 
(H two) D", + two=2) attn_output = attn_output.view(-1, self.num_heads * self.head_size) - return attn_output \ No newline at end of file + return attn_output diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index 300d14cbfcb..1ded2ff476f 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -1,43 +1,37 @@ -from typing import List, Optional, Tuple, Union, Iterable +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math +from collections.abc import Iterable +from typing import Optional, Union import torch import torch.nn as nn - -from einops import rearrange from transformers.activations import ACT2FN -from typing import Iterable, List, Optional, Set, Tuple, Union +from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, VllmConfig -from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size) -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - RowParallelLinear, - ColumnParallelLinear) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.model_executor.models.mamba_cache import (MambaCacheManager, - MambaCacheParams) -from vllm.model_executor.models.interfaces import (HasInnerState, - IsHybrid, SupportsV0Only) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) -from vllm.attention.backends.abstract import (AttentionMetadata, AttentionType) - -from vllm.logger import init_logger -from .utils import (maybe_prefix, make_layers) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.config import CacheConfig, get_current_vllm_config from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + SupportsV0Only) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import make_layers, maybe_prefix logger = init_logger(__name__) @@ -46,7 +40,7 @@ class SwiGLUActivation(nn.Module): def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) - + class SambaYMLP(nn.Module): """Gated Linear Unit. 
@@ -61,8 +55,12 @@ def __init__(self, config): super().__init__() self.config = config - self.fc1 = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.fc1 = nn.Linear(config.hidden_size, + 2 * config.intermediate_size, + bias=False) + self.fc2 = nn.Linear(config.intermediate_size, + config.hidden_size, + bias=False) self.activation_fn = ACT2FN[config.hidden_act] @@ -77,71 +75,90 @@ def get_virtual_engine(): forward_context: ForwardContext = get_forward_context() return forward_context.virtual_engine + class SambaYAttention(nn.Module): - def __init__(self, - config, - layer_idx: Optional[int] = None, - yoco_cross: bool = False, + + def __init__(self, + config, + layer_idx: Optional[int] = None, + yoco_cross: bool = False, cache_config: Optional[CacheConfig] = None, prefix: str = ""): super().__init__() if layer_idx is None: logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + f"Instantiating {self.__class__.__name__} without passing " + "a `layer_idx` is not recommended and will lead to errors " + "during the forward call if caching is used. Please make " + "sure to provide a `layer_idx` when creating this class.") self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.yoco_cross = yoco_cross - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) - self.out_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError("hidden_size must be divisible by num_heads " + f"(got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads}).") + + op_size = self.num_heads * self.head_dim + 2 * ( + self.num_key_value_heads * self.head_dim) + self.out_proj = nn.Linear(self.num_heads * self.head_dim, + self.hidden_size, + bias=True) if yoco_cross: - self.Wqkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.Wqkv = nn.Linear(self.hidden_size, + self.num_heads * self.head_dim, + bias=True) else: self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True) # disable sliding window for the second half of the model sliding_window = config.interleaved_sliding_window[layer_idx] if layer_idx >= config.num_hidden_layers // 2: - assert sliding_window is None, "sliding_window must be none for the second decoder" + assert sliding_window is None, \ + "sliding_window must be none for the second decoder" else: - assert sliding_window is not None, "sliding_window must be set for the first decoder" + assert sliding_window is not None, \ + "sliding_window must be set for the first decoder" assert self.num_heads % 2 == 0, 'num_heads should be even' assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' - + self.lambda_init = self.lambda_init_fn(layer_idx) - self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) - self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) - self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) - self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) - self.subln = nn.RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True) - - params = {'differential_flash_attention_config': - { - 'used_shared_kv_cache': self.yoco_cross, + self.lambda_q1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_q2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.subln = nn.RMSNorm(2 * self.head_dim, + eps=1e-5, + elementwise_affine=True) + + params = { + 'differential_flash_attention_config': { + 'used_shared_kv_cache': self.yoco_cross, 'lambda_init': self.lambda_init, 'lambda_q1': self.lambda_q1, 'lambda_k1': self.lambda_k1, 'lambda_q2': self.lambda_q2, 'lambda_k2': self.lambda_k2, "subln": self.subln, - } + } } if yoco_cross: - kv_shared_layer_index = config.num_hidden_layers//2 + 1 - kv_sharing_target_layer_name = f"model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501 + kv_shared_layer_index = config.num_hidden_layers // 2 + 1 + kv_sharing_target_layer_name = \ + f"model.layers.{kv_shared_layer_index}.self_attn.attn" else: kv_sharing_target_layer_name = None @@ -155,22 +172,25 @@ def __init__(self, prefix=f"{prefix}.attn", attn_type=AttentionType.DECODER, kv_sharing_target_layer_name=kv_sharing_target_layer_name, - **params - ) + **params) def lambda_init_fn(self, depth): return 
0.8 - 0.6 * math.exp(-0.3 * depth) def forward( - self, - hidden_states: torch.Tensor, - ): + self, + hidden_states: torch.Tensor, + ): - if not self.yoco_cross: # need to generate kv-cache + if not self.yoco_cross: # need to generate kv-cache qkv = self.Wqkv(hidden_states) - q, k, v = qkv.split([self.hidden_size, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim], dim=-1) + q, k, v = qkv.split([ + self.hidden_size, self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim + ], + dim=-1) attn_output = self.attn(q, k, v) - else: # re-use the kv cache, full attention + else: # re-use the kv cache, full attention q = self.Wqkv(hidden_states) attn_output = self.attn(q, None, None) attn_output = attn_output.view(-1, self.num_heads * self.head_dim) @@ -178,6 +198,7 @@ def forward( class Phi4Mamba(nn.Module): + def __init__( self, d_model, @@ -187,7 +208,7 @@ def __init__( dt_rank="auto", dt_min=0.001, dt_max=0.1, - dt_init="random", # difference + dt_init="random", # difference dt_scale=1.0, # difference dt_init_floor=1e-4, conv_bias=True, @@ -199,7 +220,7 @@ def __init__( yoco_cross=False, yoco_kv=False, ): - factory_kwargs = {"params_dtype": dtype} # difference + factory_kwargs = {"params_dtype": dtype} # difference super().__init__() self.yoco_cross = yoco_cross self.yoco_kv = yoco_kv @@ -208,13 +229,20 @@ def __init__( self.d_conv = d_conv self.expand = expand self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.dt_rank = math.ceil(self.d_model / + 16) if dt_rank == "auto" else dt_rank self.use_fast_path = use_fast_path self.layer_idx = layer_idx self.swiGluActivation = SwiGLUActivation() if self.yoco_cross: - self.in_proj = MergedColumnParallelLinear(self.d_model, [self.d_inner], bias=bias, **factory_kwargs) - self.out_proj = RowParallelLinear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.in_proj = MergedColumnParallelLinear(self.d_model, + [self.d_inner], + bias=bias, + **factory_kwargs) + self.out_proj = RowParallelLinear(self.d_inner, + self.d_model, + bias=bias, + **factory_kwargs) return self.conv1d = ColumnParallelLinear( input_size=d_conv, @@ -228,11 +256,12 @@ def __init__( # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = MergedColumnParallelLinear(self.d_model, - [self.d_inner] * 2, - bias=bias, - params_dtype=dtype, - ) + self.in_proj = MergedColumnParallelLinear( + self.d_model, + [self.d_inner] * 2, + bias=bias, + params_dtype=dtype, + ) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( @@ -245,12 +274,13 @@ def __init__( # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, # as the bias is added in the selective scan kernel. 
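# --- Annotation (not part of the patch): the comment above is the whole reason for
# skip_bias_add=True. A generic PyTorch illustration of the split, assuming the scan
# kernel folds delta_bias in and then applies softplus (the usual Mamba convention;
# treat the softplus step as an assumption here).
import torch
import torch.nn as nn

dt_rank, d_inner = 8, 64
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
x = torch.randn(3, dt_rank)

dt_no_bias = nn.functional.linear(x, dt_proj.weight)              # what the layer hands to the kernel
dt_in_kernel = nn.functional.softplus(dt_no_bias + dt_proj.bias)  # bias folded in by the kernel
assert torch.allclose(dt_in_kernel, nn.functional.softplus(dt_proj(x)), atol=1e-6)
# --- End annotation.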
- self.dt_proj = ColumnParallelLinear(self.dt_rank, - self.d_inner, - bias=True, - skip_bias_add=True, - params_dtype=dtype, - ) + self.dt_proj = ColumnParallelLinear( + self.dt_rank, + self.d_inner, + bias=True, + skip_bias_add=True, + params_dtype=dtype, + ) # # D "skip" parameter # self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 @@ -271,23 +301,22 @@ def __init__( ) self.activation = "silu" - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams, - yoco_key_values = None - ) -> torch.Tensor: - + def forward(self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + yoco_key_values=None) -> torch.Tensor: + if self.yoco_cross: out = self.in_proj(hidden_states)[0] out = self.swiGluActivation(yoco_key_values, out) out = self.out_proj(out) - return out[0], yoco_key_values + return out[0], yoco_key_values # 1. Gated MLP's linear projection # projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - projected_states = self.in_proj(hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1) + projected_states = self.in_proj( + hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1) hidden_states, gate = projected_states.chunk(2, dim=-2) # 2. Convolution sequence transformation @@ -385,34 +414,48 @@ def forward( class SambaYDecoderLayer(nn.Module): - - def __init__(self, - config, - layer_idx, - cache_config, - prefix: str = "",) -> None: + + def __init__( + self, + config, + layer_idx, + cache_config, + prefix: str = "", + ) -> None: super().__init__() - + self.config = config self.layer_idx = layer_idx self.mlp = SambaYMLP(config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.yoco_mb = False self.yoco_cross = False - assert config.num_hidden_layers % 4 == 0, 'n_layer should be divisible by 4 for SambaY + yoco' - if layer_idx >= config.num_hidden_layers//2: + assert config.num_hidden_layers % 4 == 0, \ + 'n_layer should be divisible by 4 for SambaY + yoco' + if layer_idx >= config.num_hidden_layers // 2: self.yoco_mb = True - self.yoco_cross = (layer_idx >= (config.num_hidden_layers//2 +2)) - self.use_mamba = config.mb_per_layer > 0 and layer_idx % config.mb_per_layer == 0 + self.yoco_cross = (layer_idx + >= (config.num_hidden_layers // 2 + 2)) + self.use_mamba = config.mb_per_layer > 0 and \ + layer_idx % config.mb_per_layer == 0 if self.use_mamba: factory_kwargs = {"dtype": None} - self.attn = Phi4Mamba(config.hidden_size, layer_idx=layer_idx, - yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs) + self.attn = Phi4Mamba(config.hidden_size, + layer_idx=layer_idx, + yoco_cross=self.yoco_cross, + yoco_kv=self.yoco_mb, + **factory_kwargs) else: - self.attn = SambaYAttention(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross, cache_config=cache_config, prefix=f"{prefix}.self_attn") - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = SambaYAttention(config, + layer_idx=layer_idx, + yoco_cross=self.yoco_cross, + cache_config=cache_config, + prefix=f"{prefix}.self_attn") + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) def forward( self, @@ -428,39 +471,35 @@ def forward( assert mamba_cache_params is None residual = hidden_states - hidden_states = 
self.input_layernorm(hidden_states.to(dtype=self.input_layernorm.weight.dtype)) + hidden_states = self.input_layernorm( + hidden_states.to(dtype=self.input_layernorm.weight.dtype)) if self.use_mamba: - attn_outputs, ssm_output = self.attn( - hidden_states, - attn_metadata, - mamba_cache_params, - yoco_key_values = ssm_output - ) + attn_outputs, ssm_output = self.attn(hidden_states, + attn_metadata, + mamba_cache_params, + yoco_key_values=ssm_output) residual = residual.to(torch.float32) else: - attn_outputs = self.attn( - hidden_states, - ) + attn_outputs = self.attn(hidden_states, ) hidden_states = residual + attn_outputs residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype)) + hidden_states = self.post_attention_layernorm( + hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype)) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states, ssm_output - + class SambaYModel(nn.Module): - def __init__( - self, - config, - cache_config = None, - quant_config = None, - lora_config = None, - prefix: str = "" - ) -> None: + def __init__(self, + config, + cache_config=None, + quant_config=None, + lora_config=None, + prefix: str = "") -> None: super().__init__() self.config = config self.vocab_size = config.vocab_size @@ -470,22 +509,24 @@ def __init__( org_num_embeddings=config.vocab_size, ) - # Pipeline parallel is not supported since the second half of the layers share the kv cache. + # Pipeline parallel is not supported since the second half of + # the layers share the kv cache. if get_pp_group().world_size != 1: raise ValueError("Pipeline Parallel not supported") - + self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: SambaYDecoderLayer(config, - int(prefix.split('.')[-1]), - cache_config, - prefix=prefix), + int(prefix.split('.')[-1]), + cache_config, + prefix=prefix), prefix=f"{prefix}.layers") - self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) - + def forward( self, input_ids: Optional[torch.Tensor], @@ -495,7 +536,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - + if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -511,47 +552,49 @@ def forward( layer = self.layers[i] if i == self.config.num_hidden_layers // 2 + 2: # profile run - kv_cache_idx = self.config.num_hidden_layers//2 + 1 + kv_cache_idx = self.config.num_hidden_layers // 2 + 1 cache_layer = self.layers[kv_cache_idx] kv_cache = cache_layer.attn.attn.kv_cache if kv_cache[0].numel() == 0: break - # Starting from this layer, we do not need to calculate the kv cache since we reuse - # the kv cache from last layer. If in prefill phase, we can prune> truncate + # Starting from this layer, we do not need to calculate + # the kv cache since we reuse the kv cache from last layer. + # If in prefill phase, we can prune> truncate # the hidden state to save computation cost. 
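# --- Annotation (not part of the patch): the cumsum(seq_lens) - 1 trick right below
# selects exactly one row per prompt, the last token of each packed sequence, which
# is the only position whose output the KV-reusing half still needs during prefill.
# A tiny worked example with made-up lengths.
import torch

seq_lens = torch.tensor([3, 2, 4])             # three prompts packed into one batch
hidden_states = torch.arange(9).unsqueeze(-1)  # 3 + 2 + 4 = 9 token rows

last_token_indices = torch.cumsum(seq_lens, dim=0) - 1   # tensor([2, 4, 8])
pruned = hidden_states.index_select(0, last_token_indices)
assert pruned.squeeze(-1).tolist() == [2, 4, 8]
# --- End annotation.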
if attn_metadata.prefill_metadata: - selected_token_indices = torch.cumsum(attn_metadata.seq_lens_tensor, dim=0) - 1 - hidden_states = hidden_states.index_select(0, selected_token_indices) - ssm_output = ssm_output.index_select(0, selected_token_indices) + selected_token_indices = torch.cumsum( + attn_metadata.seq_lens_tensor, dim=0) - 1 + hidden_states = hidden_states.index_select( + 0, selected_token_indices) + ssm_output = ssm_output.index_select( + 0, selected_token_indices) if layer.use_mamba: - if i < self.config.num_hidden_layers // 2: - mamba_cache = mamba_cache_params.at_layer_idx(mamba_state_idx) - mamba_state_idx += 1 - elif not layer.yoco_cross: - mamba_cache = mamba_cache_params.at_layer_idx(mamba_state_idx) + if i < self.config.num_hidden_layers // 2 or \ + not layer.yoco_cross: + mamba_cache = mamba_cache_params.at_layer_idx( + mamba_state_idx) mamba_state_idx += 1 else: - mamba_cache = mamba_cache_params.at_layer_idx(mamba_state_idx-1) - - hidden_states, ssm_output = layer( - hidden_states, - positions, - attn_metadata, - mamba_cache, - ssm_output = ssm_output - ) + mamba_cache = mamba_cache_params.at_layer_idx( + mamba_state_idx - 1) + + hidden_states, ssm_output = layer(hidden_states, + positions, + attn_metadata, + mamba_cache, + ssm_output=ssm_output) else: hidden_states, ssm_output = layer( hidden_states, positions, attn_metadata, - None, # mamba_cache_params - ssm_output = ssm_output - ) + None, # mamba_cache_params + ssm_output=ssm_output) - hidden_states = self.final_layernorm(hidden_states.to(dtype=self.final_layernorm.weight.dtype)) + hidden_states = self.final_layernorm( + hidden_states.to(dtype=self.final_layernorm.weight.dtype)) return hidden_states @@ -565,7 +608,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): scheduler_config = vllm_config.scheduler_config self.compilation_config = vllm_config.compilation_config self.vllm_config = vllm_config - # Prefix caching is not supported since there are mamba layers in this + # Prefix caching is not supported since there are mamba layers in this # mode. 
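# --- Annotation (not part of the patch): a hedged usage sketch of the restriction
# asserted just below. The checkpoint id is a placeholder that must be swapped for a
# real Phi4flash/SambaY-style model; the only point is that prefix caching stays off
# for this architecture because of the Mamba state.
from vllm import LLM

llm = LLM(model="your-org/phi4flash-checkpoint",  # placeholder model id
          enable_prefix_caching=False)
# --- End annotation.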
assert not cache_config.enable_prefix_caching, \ "Phi4flash currently does not support prefix caching" @@ -574,11 +617,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.model_config = vllm_config.model_config self.scheduler_config = scheduler_config - self.model = SambaYModel( - config, - cache_config=cache_config, - prefix=maybe_prefix(prefix, "model") - ) + self.model = SambaYModel(config, + cache_config=cache_config, + prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -590,8 +631,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config else lora_config.lora_vocab_padding_size), quant_config=quant_config, ) self.embedding_bias = None @@ -611,25 +651,27 @@ def forward( **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: if self.mamba_cache is None: - num_mamba_layers = self.config.num_hidden_layers // 2 // self.config.mb_per_layer + 1 + num_mamba_layers = self.config.num_hidden_layers \ + // 2 // self.config.mb_per_layer + 1 self.mamba_cache = MambaCacheManager( - self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, *self._get_mamba_cache_shape() - ) + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) attn_metadata = get_forward_context().attn_metadata - hidden_states = self.model(input_ids, positions, - attn_metadata, mamba_cache_params, - intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, attn_metadata, + mamba_cache_params, intermediate_tensors, + inputs_embeds) return hidden_states - def _get_mamba_cache_shape(self) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + def _get_mamba_cache_shape( + self + ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]: world_size = get_tensor_model_parallel_world_size() hidden_size = self.config.hidden_size - mamba_expand = self.config.mamba_expand # 2 - mamba_d_conv = self.config.mamba_d_conv # 4 - mamba_d_state = self.config.mamba_d_state # 16 + mamba_expand = self.config.mamba_expand # 2 + mamba_d_conv = self.config.mamba_d_conv # 4 + mamba_d_state = self.config.mamba_d_state # 16 conv_state_shape = ( mamba_expand * hidden_size // world_size, mamba_d_conv - 1, @@ -652,15 +694,16 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - # If the shape is the same, it means that we have already prune hidden states manually. - prune_hidden_states = hidden_states.size(0) != sampling_metadata.selected_token_indices.size(0) + # If the shape is the same, it means that we have already + # prune hidden states manually. 
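# --- Annotation (not part of the patch): a toy version of the flag computed just
# below. Logits processing must skip its own pruning whenever the model already
# reduced hidden_states to one row per sampling position (the prefill pruning path
# above); shapes are made up.
import torch

selected_token_indices = torch.tensor([2, 4, 8])  # one sampling position per prompt

already_pruned = torch.randn(3, 16)   # model pruned to 3 rows earlier
full_prefill = torch.randn(9, 16)     # model kept every token row

assert already_pruned.size(0) == selected_token_indices.size(0)  # -> prune flag False
assert full_prefill.size(0) != selected_token_indices.size(0)    # -> prune flag True
# --- End annotation.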
+ prune_hidden_states = hidden_states.size( + 0) != sampling_metadata.selected_token_indices.size(0) processed_logits = self.logits_processor( - self.lm_head, - hidden_states, - sampling_metadata, - self.embedding_bias, - prune_hidden_states=prune_hidden_states - ) + self.lm_head, + hidden_states, + sampling_metadata, + self.embedding_bias, + prune_hidden_states=prune_hidden_states) return processed_logits def sample( @@ -673,7 +716,7 @@ def sample( def load_weights( self, - weights: Iterable[Tuple[str, torch.Tensor]], + weights: Iterable[tuple[str, torch.Tensor]], ): weights = {name: weight for name, weight in weights} adjusted_weights = {} @@ -684,14 +727,17 @@ def load_weights( if "inner_cross_attn." in name: name = name.replace("inner_cross_attn.", "") adjusted_weights[name] = weight - adjusted_weights["lm_head.weight"] = weights["model.embed_tokens.weight"] - loaded_params: Set[str] = set() + adjusted_weights["lm_head.weight"] = weights[ + "model.embed_tokens.weight"] + loaded_params: set[str] = set() for name, param in self.named_parameters(): - weight = adjusted_weights.get(name, None) + weight = adjusted_weights.get(name) if weight is not None and weight.shape != param.shape: - logger.warning(f"Shape mismatch: {name} {weight.shape} {param.shape}") + logger.warning("Shape mismatch: %s %s %s", name, weight.shape, + param.shape) loaded_params.add(name) - missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, strict=False) + missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, + strict=False) assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}" assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" - return loaded_params \ No newline at end of file + return loaded_params diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 29eea8bedf2..eea64072f3a 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2888,9 +2888,9 @@ def get_mp_context(): def bind_kv_cache( - ctx: dict[str, Any], - kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] - shared_kv_cache_layers: dict[str, str], + ctx: dict[str, Any], + kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] + shared_kv_cache_layers: dict[str, str], ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -2914,13 +2914,12 @@ def bind_kv_cache( extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)) for layer_name in layer_need_kv_cache: - target_layer_name = shared_kv_cache_layers[layer_name] if layer_name \ - in shared_kv_cache_layers else layer_name + target_layer_name = shared_kv_cache_layers.get(layer_name, layer_name) kv_cache_idx = layer_index_sorted.index( extract_layer_index(target_layer_name)) forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) - + for ve, ve_kv_cache in enumerate(kv_cache): assert kv_cache_idx < len(ve_kv_cache), \ "v0 doesn't support interleaving kv sharing, use v1 instead" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2db44b4f22e..ab926c2d33b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1115,7 +1115,7 @@ def __init__( self.cross_layer_shared_graph_block_tables = np.zeros( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) - + # Attention-free but stateful models like Mamba need a placeholder attn # backend, as the attention metadata is needed to manage internal 
state. # However we must bypass attention selection altogether for some models diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 065da227d48..b2926dbd185 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -10,7 +10,7 @@ import vllm.envs as envs from vllm.attention.layer import Attention -from vllm.config import (VllmConfig, get_layers_from_vllm_config) +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, @@ -27,7 +27,6 @@ SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache, memory_profiling) -from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner @@ -367,7 +366,7 @@ def _init_cache_engine(self): # a given amount of memory to accommodate longer context lengths # or enable more requests to be processed simultaneously. shared_kv_cache_layers[layer_name] = kv_tgt_layer - + bind_kv_cache(self.compilation_config.static_forward_context, self.gpu_cache, shared_kv_cache_layers) From a468c32c60f7246f28cf2c00bfca9875bc8be33c Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Wed, 9 Jul 2025 23:13:45 +0000 Subject: [PATCH 16/24] address lint Signed-off-by: Congcong Chen --- .../backends/differential_flash_attn.py | 51 ++++++++++++++++--- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index 0de3e8c43e1..0009c6b56ce 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""" An implementation of https://arxiv.org/pdf/2410.05258 """ from collections import defaultdict from dataclasses import dataclass from itertools import accumulate @@ -11,14 +12,16 @@ from vllm import _custom_ops as ops # yapf conflicts with isort for this block # yapf: disable -from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, AttentionMetadata, AttentionMetadataBuilder, AttentionType, is_quantized_kv_cache) from vllm.attention.backends.flash_attn import FlashAttentionBackend # yapf: enable -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, +from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, + compute_slot_mapping, compute_slot_mapping_start_idx, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, @@ -38,9 +41,13 @@ logger = init_logger(__name__) -class DifferentialFlashAttentionBackend(FlashAttentionBackend): +class DifferentialFlashAttentionBackend(AttentionBackend): accept_output_buffer = False + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -69,6 +76,33 @@ def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]: def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]: return DifferentialFlashAttentionMetadataBuilder + @staticmethod + def 
get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + + ops.copy_blocks(key_caches, value_caches, src_to_dists) + @dataclass class DifferentialFlashAttentionMetadata(AttentionMetadata): @@ -635,6 +669,8 @@ def __init__( use_irope: bool = False, differential_flash_attention_config: Optional[Dict[str, Any]] = None, ) -> None: + if differential_flash_attention_config is None: + differential_flash_attention_config = {} self.differential_flash_attention_config = \ differential_flash_attention_config self.used_shared_kv_cache = \ @@ -722,7 +758,7 @@ def forward_generate_kv_cache( self, query: torch.Tensor, key: Optional[torch.Tensor], value: Optional[torch.Tensor], k_cache: torch.Tensor, v_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: + attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor: head_size = self.head_size num_heads = self.num_heads // 2 @@ -758,7 +794,9 @@ def forward_generate_kv_cache( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if k_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + if k_cache.numel() == 0 \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0: # normal attention prefill_output = flash_attn_varlen_func( q=query, @@ -808,7 +846,7 @@ def forward_with_kv_cache_only( query: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: DifferentialFlashAttentionMetadata, ): if not attn_metadata.decode_metadata: block_tables_arg = attn_metadata.cross_layer_shared_block_tables @@ -838,6 +876,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: DifferentialFlashAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. 
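The backend above implements the differential attention of the paper cited in its docstring (arXiv:2410.05258). Stripped of the head layout, GroupNorm, and learned lambda reparameterization, the core idea is two softmax attention maps over the same keys and values whose difference weights the output; a simplified sketch, not the backend's actual kernel path:

import torch
import torch.nn.functional as F

def differential_attention(q1, q2, k1, k2, v, lam, scale):
    # Two attention maps; subtracting the second, scaled by lambda,
    # cancels common-mode attention noise.
    a1 = F.softmax((q1 @ k1.transpose(-1, -2)) * scale, dim=-1)
    a2 = F.softmax((q2 @ k2.transpose(-1, -2)) * scale, dim=-1)
    return (a1 - lam * a2) @ v

q1, q2, k1, k2 = (torch.randn(1, 5, 16) for _ in range(4))
v = torch.randn(1, 5, 32)
out = differential_attention(q1, q2, k1, k2, v, lam=0.5, scale=16 ** -0.5)
print(out.shape)  # torch.Size([1, 5, 32])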
From 8b1b0b6e8d81cabb4273e32049fe5b837f4e6d2c Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Wed, 9 Jul 2025 23:40:12 +0000 Subject: [PATCH 17/24] clean up Signed-off-by: Congcong Chen --- vllm/attention/layer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 751cfe84299..f9c2d4f4983 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -160,10 +160,6 @@ def __init__( self.attn_type = attn_type if kv_sharing_target_layer_name is not None: - # if not envs.VLLM_USE_V1: - # raise NotImplementedError( - # "Cross-layer KV sharing is not supported in V0.") - validate_kv_sharing_target( prefix, kv_sharing_target_layer_name, From 0f92033096ad8666edd914ed33a9c6a3f14736ae Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Thu, 10 Jul 2025 08:30:24 +0000 Subject: [PATCH 18/24] update supported_models.md Signed-off-by: Congcong Chen --- docs/models/supported_models.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index ddc920aeb2d..eca37a09058 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -374,6 +374,7 @@ Specified using `--task generate`. | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ | | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-flash-reasoning`, etc. | | | | | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | | `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.
| ✅︎ | ✅︎ | ✅︎ | From a01a2a07a155416154235c84810745f007564796 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 11 Jul 2025 09:07:19 +0000 Subject: [PATCH 19/24] address comment Signed-off-by: Congcong Chen --- tests/models/registry.py | 2 ++ tests/models/test_initialization.py | 3 +++ tests/test_utils.py | 25 +++++++++++++++++++ vllm/attention/backends/blocksparse_attn.py | 3 ++- .../backends/differential_flash_attn.py | 2 -- .../backends/dual_chunk_flash_attn.py | 3 ++- vllm/attention/backends/flash_attn.py | 3 ++- vllm/attention/backends/flashinfer.py | 3 ++- vllm/attention/backends/hpu_attn.py | 3 ++- vllm/attention/backends/rocm_flash_attn.py | 3 ++- vllm/attention/backends/xformers.py | 3 ++- vllm/model_executor/models/phi4flash.py | 11 ++++---- vllm/utils/__init__.py | 13 +++++++--- 13 files changed, 60 insertions(+), 17 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index fa10857313a..02d90ff35fa 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -248,6 +248,8 @@ def check_available_online( "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True, v0_only=True), + "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501 + trust_remote_code=True), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 76726c0c820..038717a129e 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -95,6 +95,9 @@ def _initialize_kv_caches_v1(self, vllm_config): _initialize_kv_caches_v1), monkeypatch.context() as m): if model_info.v0_only: m.setenv("VLLM_USE_V1", "0") + if model_arch == "Phi4FlashForCausalLM": + # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend + m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN") LLM( model_info.default, tokenizer=model_info.tokenizer, diff --git a/tests/test_utils.py b/tests/test_utils.py index f90715fd751..28acacd2519 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -458,6 +458,31 @@ def test_bind_kv_cache(): assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2] assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3] +def test_bind_kv_cache_kv_sharing(): + from vllm.attention import Attention + + ctx = { + 'layers.0.self_attn': Attention(32, 128, 0.1), + 'layers.1.self_attn': Attention(32, 128, 0.1), + 'layers.2.self_attn': Attention(32, 128, 0.1), + 'layers.3.self_attn': Attention(32, 128, 0.1), + } + kv_cache = [ + torch.zeros((1, )), + torch.zeros((1, )), + torch.zeros((1, )), + torch.zeros((1, )), + ] + shared_kv_cache_layers = { + 'layers.2.self_attn': 'layers.1.self_attn', + 'layers.3.self_attn': 'layers.0.self_attn' + } + bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers) + assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0] + assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1] + assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1] + assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0] + def test_bind_kv_cache_non_attention(): from vllm.attention import Attention diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index fe9738d804c..e4338805f56 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -308,7 
+308,8 @@ def __init__( kv_sharing_target_layer_name: Optional[str] = None, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "BLOCK_SPARSE_FLASH_ATTN Backend.") assert blocksparse_params is not None assert alibi_slopes is None, ValueError( "Alibi not support for blocksparse flash attention.") diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index 0009c6b56ce..5435adcd64b 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -676,8 +676,6 @@ def __init__( self.used_shared_kv_cache = \ self.differential_flash_attention_config.get( "used_shared_kv_cache", False) - # if kv_sharing_target_layer_name is not None: - # raise NotImplementedError("KV sharing is not supported in V0.") self.kv_sharing_target_layer_name = kv_sharing_target_layer_name if blocksparse_params is not None: raise ValueError( diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index f62a43b441f..40557a4e8f8 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -295,7 +295,8 @@ def __init__( dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "DUAL_CHUNK_FLASH_ATTN backend.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bf8e373802f..20e67eb9b40 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -622,7 +622,8 @@ def __init__( use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "FLASH_ATTN backend.") if blocksparse_params is not None: raise ValueError( "FlashAttention does not support block-sparse attention.") diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5bbe340b143..1f913ad8952 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1006,7 +1006,8 @@ def __init__( use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "FLASHINFER backend.") if use_irope: logger.warning_once( "Using irope in FlashInfer is not supported yet, it will fall" diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index bf778a1e501..b8fdf763a04 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -115,7 +115,8 @@ def __init__( ) -> None: super(AttentionImpl, self).__init__() if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "HPU_ATTN backend.") if use_irope: logger.warning_once( "Using irope in HPU is not supported yet, it will fall back " diff --git 
a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 0b7783758dd..4653d5267e1 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -501,7 +501,8 @@ def __init__( use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "ROCM_FLASH backend.") if use_irope: logger.warning_once( "Using irope in ROCm Flash Attention is not supported yet, it " diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b583240c73c..3ef79bb6212 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -394,7 +394,8 @@ def __init__( use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "XFORMERS backend.") if blocksparse_params is not None: raise ValueError( "XFormers does not support block-sparse attention.") diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index 1ded2ff476f..ebe027e0558 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -9,6 +9,7 @@ from transformers.activations import ACT2FN from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention.selector import _Backend from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.forward_context import ForwardContext, get_forward_context @@ -173,6 +174,8 @@ def __init__(self, attn_type=AttentionType.DECODER, kv_sharing_target_layer_name=kv_sharing_target_layer_name, **params) + assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\ + "DIFFERENTIAL_FLASH_ATTN required" def lambda_init_fn(self, depth): return 0.8 - 0.6 * math.exp(-0.3 * depth) @@ -433,8 +436,6 @@ def __init__( self.yoco_mb = False self.yoco_cross = False - assert config.num_hidden_layers % 4 == 0, \ - 'n_layer should be divisible by 4 for SambaY + yoco' if layer_idx >= config.num_hidden_layers // 2: self.yoco_mb = True self.yoco_cross = (layer_idx @@ -608,11 +609,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): scheduler_config = vllm_config.scheduler_config self.compilation_config = vllm_config.compilation_config self.vllm_config = vllm_config - # Prefix caching is not supported since there are mamba layers in this - # mode. + # Prefix caching and chunked prefill is not supported for this model. 
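Putting the constraints from this series together (V0 engine only, the DIFFERENTIAL_FLASH_ATTN backend asserted above, no prefix caching or chunked prefill), a possible way to try the model mirrors the registry and test entries elsewhere in these patches; this is a sketch under those assumptions, not an officially documented recipe:

import os

# Values taken from the test/registry entries in this series; adjust as needed.
os.environ["VLLM_USE_V1"] = "0"
os.environ["VLLM_ATTENTION_BACKEND"] = "DIFFERENTIAL_FLASH_ATTN"

from vllm import LLM

llm = LLM(model="microsoft/Phi-4-mini-flash-reasoning",
          trust_remote_code=True,
          max_model_len=10240)
outputs = llm.generate("Question: Who is Bill Gates?\n\nAnswer:")
print(outputs[0].outputs[0].text)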
assert not cache_config.enable_prefix_caching, \ "Phi4flash currently does not support prefix caching" - + assert not scheduler_config.chunked_prefill_enabled, \ + "Phi4Flash currently does not support chunked prefill" super().__init__() self.config = config self.model_config = vllm_config.model_config diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index eea64072f3a..fd8f0ad91b6 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2890,7 +2890,7 @@ def get_mp_context(): def bind_kv_cache( ctx: dict[str, Any], kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] - shared_kv_cache_layers: dict[str, str], + shared_kv_cache_layers: Optional[dict[str, str]] = None ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -2902,6 +2902,10 @@ def bind_kv_cache( # attention of the same layer (e.g., bart's decoder.layers.1.self_attn # and decoder.layers.1.encoder_attn) is mapped to the same kv cache # tensor + # 5. Some models have attention layers that share kv cache with previous + # layers, this is specified through shared_kv_cache_layers + if shared_kv_cache_layers is None: + shared_kv_cache_layers = {} from vllm.attention import AttentionType from vllm.model_executor.models.utils import extract_layer_index layer_need_kv_cache = [ @@ -2913,16 +2917,19 @@ def bind_kv_cache( set( extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)) + for layer_name in layer_need_kv_cache: + # 1. Get the kv_cache_idx of the target_layer_name. target_layer_name = shared_kv_cache_layers.get(layer_name, layer_name) kv_cache_idx = layer_index_sorted.index( extract_layer_index(target_layer_name)) + + # 2. Bind kv_cache to forward_ctx.
forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) - for ve, ve_kv_cache in enumerate(kv_cache): assert kv_cache_idx < len(ve_kv_cache), \ - "v0 doesn't support interleaving kv sharing, use v1 instead" + "v0 doesn't support interleaving kv sharing" forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] From bc52add7abf08b20584d87db2b4a83a30ac66884 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 11 Jul 2025 09:23:08 +0000 Subject: [PATCH 20/24] address comments Signed-off-by: Congcong Chen --- tests/models/registry.py | 3 ++- vllm/model_executor/models/phi4flash.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 02d90ff35fa..026d5c74b96 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -249,7 +249,8 @@ def check_available_online( trust_remote_code=True, v0_only=True), "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501 - trust_remote_code=True), + trust_remote_code=True, + v0_only=True), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index ebe027e0558..498f33aef0c 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -8,6 +8,7 @@ import torch.nn as nn from transformers.activations import ACT2FN +import vllm.envs as envs from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.selector import _Backend from vllm.config import CacheConfig, VllmConfig @@ -563,7 +564,7 @@ def forward( # the kv cache since we reuse the kv cache from last layer. # If in prefill phase, we can prune> truncate # the hidden state to save computation cost. 
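The KV-sharing path exercised by the new test_bind_kv_cache_kv_sharing test boils down to a dictionary lookup with a self-default: a layer listed in shared_kv_cache_layers resolves to its target layer's cache slot, and every other layer resolves to itself. A toy sketch using the layer names from that test:

# Mapping taken from test_bind_kv_cache_kv_sharing in this series.
shared_kv_cache_layers = {
    "layers.2.self_attn": "layers.1.self_attn",
    "layers.3.self_attn": "layers.0.self_attn",
}
layers = ["layers.0.self_attn", "layers.1.self_attn",
          "layers.2.self_attn", "layers.3.self_attn"]
for layer_name in layers:
    target = shared_kv_cache_layers.get(layer_name, layer_name)
    print(f"{layer_name} -> {target}")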
- if attn_metadata.prefill_metadata: + if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1: selected_token_indices = torch.cumsum( attn_metadata.seq_lens_tensor, dim=0) - 1 hidden_states = hidden_states.index_select( From 88cb7969f532757e67ee9402f1ee883109d61e5e Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 11 Jul 2025 09:39:56 +0000 Subject: [PATCH 21/24] address comments Signed-off-by: Congcong Chen --- vllm/attention/backends/differential_flash_attn.py | 4 +--- vllm/model_executor/models/phi4flash.py | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index 5435adcd64b..0e2c868893d 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -673,9 +673,7 @@ def __init__( differential_flash_attention_config = {} self.differential_flash_attention_config = \ differential_flash_attention_config - self.used_shared_kv_cache = \ - self.differential_flash_attention_config.get( - "used_shared_kv_cache", False) + self.used_shared_kv_cache = kv_sharing_target_layer_name is not None self.kv_sharing_target_layer_name = kv_sharing_target_layer_name if blocksparse_params is not None: raise ValueError( diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index 498f33aef0c..10f8b6552af 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -147,7 +147,6 @@ def __init__(self, params = { 'differential_flash_attention_config': { - 'used_shared_kv_cache': self.yoco_cross, 'lambda_init': self.lambda_init, 'lambda_q1': self.lambda_q1, 'lambda_k1': self.lambda_k1, @@ -661,6 +660,8 @@ def forward( mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) attn_metadata = get_forward_context().attn_metadata + # input_ids and hidden_states isn't a one-to-one mapping in prefill + # stage due to YOCO optimization. hidden_states = self.model(input_ids, positions, attn_metadata, mamba_cache_params, intermediate_tensors, inputs_embeds) From 1ff7eeb66dd6ed26b7b97aa684a3dcea757fa6bf Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Fri, 11 Jul 2025 20:36:10 +0000 Subject: [PATCH 22/24] refactor Signed-off-by: Congcong Chen --- vllm/utils/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index fd8f0ad91b6..53afc4b0f62 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2918,13 +2918,15 @@ def bind_kv_cache( extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)) + # Map from layer_name to the kv cache layer idx. + layer_name_2_kv_cache_index = dict() for layer_name in layer_need_kv_cache: - # 1. Get the kv_cache_idx of the target_layer_name. target_layer_name = shared_kv_cache_layers.get(layer_name, layer_name) kv_cache_idx = layer_index_sorted.index( extract_layer_index(target_layer_name)) + layer_name_2_kv_cache_index[layer_name] = kv_cache_idx - # 2. Bind kv_cache to forward_ctx. 
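The comment added above, that input_ids and hidden_states are not one-to-one during prefill because of the YOCO optimization, is why compute_logits first checks whether the hidden states were already pruned before indexing them. A toy sketch of that size check, with made-up tensor contents:

import torch

selected_token_indices = torch.tensor([2, 4, 8])  # last token of each sequence
already_pruned = torch.randn(3, 8)                 # one row per sequence (YOCO prefill)
not_pruned = torch.randn(9, 8)                     # one row per packed token

for hidden_states in (already_pruned, not_pruned):
    prune_hidden_states = (hidden_states.size(0)
                           != selected_token_indices.size(0))
    print(hidden_states.size(0), "rows -> prune again:", prune_hidden_states)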
+ for layer_name, kv_cache_idx in layer_name_2_kv_cache_index.items(): forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) for ve, ve_kv_cache in enumerate(kv_cache): From e1f2ca9a84a6f60ab1575ff0785975c48b57fe83 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Sat, 12 Jul 2025 05:02:19 +0000 Subject: [PATCH 23/24] update Signed-off-by: Congcong Chen --- tests/models/registry.py | 3 ++- vllm/attention/backends/differential_flash_attn.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 026d5c74b96..c10d375683e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -250,7 +250,8 @@ def check_available_online( v0_only=True), "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501 trust_remote_code=True, - v0_only=True), + v0_only=True, + max_model_len=10240), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index 0e2c868893d..7c35e58967d 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -432,6 +432,11 @@ def _add_seq_group( 2. block table. 3. slot mapping. """ + # TODO: add support for chunked prefill and prefix caching. + assert not chunked_prefill_enabled, \ + "chunked prefill is not supported for now" + assert not prefix_cache_hit, "prefix caching is not supported for now" + is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables From aeb9282f7b6e34adb3ce5a1a962ffba561e4b844 Mon Sep 17 00:00:00 2001 From: Congcong Chen Date: Sat, 12 Jul 2025 06:50:29 +0000 Subject: [PATCH 24/24] refactor Signed-off-by: Congcong Chen --- vllm/utils/__init__.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 53afc4b0f62..495e359aa6d 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2911,28 +2911,26 @@ def bind_kv_cache( layer_need_kv_cache = [ layer_name for layer_name in ctx if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type - in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) + in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \ + and ctx[layer_name].kv_sharing_target_layer_name is None ] layer_index_sorted = sorted( set( extract_layer_index(layer_name) for layer_name in layer_need_kv_cache)) - - # Map from layer_name to the kv cache layer idx. 
- layer_name_2_kv_cache_index = dict() for layer_name in layer_need_kv_cache: - target_layer_name = shared_kv_cache_layers.get(layer_name, layer_name) kv_cache_idx = layer_index_sorted.index( - extract_layer_index(target_layer_name)) - layer_name_2_kv_cache_index[layer_name] = kv_cache_idx - - for layer_name, kv_cache_idx in layer_name_2_kv_cache_index.items(): + extract_layer_index(layer_name)) forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) for ve, ve_kv_cache in enumerate(kv_cache): - assert kv_cache_idx < len(ve_kv_cache), \ - "v0 doesn't support interleaving kv sharing" forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] + if shared_kv_cache_layers is not None: + for layer_name, target_layer_name in shared_kv_cache_layers.items(): + assert extract_layer_index(target_layer_name) < \ + extract_layer_index(layer_name), \ + "v0 doesn't support interleaving kv sharing" + ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any],