|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import Any, Dict, List, Optional, Tuple, Type |
| 6 | + |
| 7 | +import torch |
| 8 | + |
| 9 | +import vllm._custom_ops as ops |
| 10 | +from vllm._ipex_ops import ipex_ops |
| 11 | +from vllm.attention.backends.abstract import (AttentionBackend, |
| 12 | + AttentionMetadataBuilder, |
| 13 | + AttentionType, |
| 14 | + is_quantized_kv_cache) |
| 15 | +from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState |
| 16 | +from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata |
| 17 | +from vllm.utils import make_tensor_with_pad |
| 18 | +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder |
| 19 | + |
| 20 | + |
| 21 | +class CPUMLABackend(AttentionBackend): |
| 22 | + |
| 23 | + @staticmethod |
| 24 | + def get_name() -> str: |
| 25 | + return "CPU_MLA" |
| 26 | + |
| 27 | + @staticmethod |
| 28 | + def get_metadata_cls() -> Type["CPUMLAMetadata"]: |
| 29 | + return CPUMLAMetadata |
| 30 | + |
| 31 | + @staticmethod |
| 32 | + def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]: |
| 33 | + return CPUMLAMetadataBuilder |
| 34 | + |
| 35 | + @staticmethod |
| 36 | + def get_state_cls() -> Type["MLACommonState"]: |
| 37 | + return MLACommonState |
| 38 | + |
| 39 | + @staticmethod |
| 40 | + def get_impl_cls() -> Type["CPUMLAImpl"]: |
| 41 | + return CPUMLAImpl |
| 42 | + |
| 43 | + @staticmethod |
| 44 | + def get_kv_cache_shape( |
| 45 | + num_blocks: int, |
| 46 | + block_size: int, |
| 47 | + num_kv_heads: int, # assumed to be 1 for MLA |
| 48 | + head_size: int, |
| 49 | + ) -> Tuple[int, ...]: |
| 50 | + return (num_blocks, block_size, head_size) |
| 51 | + |
| 52 | + @staticmethod |
| 53 | + def swap_blocks( |
| 54 | + src_kv_cache: torch.Tensor, |
| 55 | + dst_kv_cache: torch.Tensor, |
| 56 | + src_to_dst: torch.Tensor, |
| 57 | + ) -> None: |
| 58 | + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) |
| 59 | + |
| 60 | + @staticmethod |
| 61 | + def copy_blocks( |
| 62 | + kv_caches: List[torch.Tensor], |
| 63 | + src_to_dists: torch.Tensor, |
| 64 | + ) -> None: |
| 65 | + ops.copy_blocks_mla(kv_caches, src_to_dists) |
| 66 | + |
| 67 | + @staticmethod |
| 68 | + def get_supported_head_sizes() -> List[int]: |
| 69 | + return [576] |
| 70 | + |
| 71 | + |
| 72 | +@dataclass |
| 73 | +class CPUMLAMetadata(TorchSDPAMetadata): |
| 74 | + # New for MLA |
| 75 | + # Input positions for rotrary embeddings since for MLA the rotary |
| 76 | + # position embeddings are applied inside the attention backend |
| 77 | + input_positions: torch.Tensor = None |
| 78 | + |
| 79 | + # required by MLACommonImpl |
| 80 | + is_profile_run: bool = False |
| 81 | + |
| 82 | + |
| 83 | +class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]): |
| 84 | + |
| 85 | + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: |
| 86 | + self.chunked_prefill = input_builder.chunked_prefill |
| 87 | + self.input_builder = input_builder |
| 88 | + assert not self.chunked_prefill, \ |
| 89 | + "chunked prefill is currently not supported" |
| 90 | + |
| 91 | + def prepare(self): |
| 92 | + self.input_data = self.input_builder.input_data |
| 93 | + |
| 94 | + def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size): |
| 95 | + input_data = self.input_data |
| 96 | + prefill_seq_lens = seq_lens[0:input_data.num_prefills] |
| 97 | + prefill_query_lens = query_lens[0:input_data.num_prefills] |
| 98 | + slot_mapping = torch.tensor(input_data.slot_mapping, |
| 99 | + dtype=torch.long, |
| 100 | + device="cpu") |
| 101 | + |
| 102 | + # metadata for prefill |
| 103 | + if input_data.num_prefills > 0: |
| 104 | + query_lens_tensor = torch.tensor(prefill_query_lens, |
| 105 | + dtype=torch.int32, |
| 106 | + device="cpu") |
| 107 | + kv_lens_tensor = torch.tensor(prefill_seq_lens, |
| 108 | + dtype=torch.int32, |
| 109 | + device="cpu") |
| 110 | + query_start_loc = torch.zeros(input_data.num_prefills + 1, |
| 111 | + dtype=torch.int32, |
| 112 | + device="cpu") |
| 113 | + kv_start_loc = torch.zeros(input_data.num_prefills + 1, |
| 114 | + dtype=torch.int32, |
| 115 | + device="cpu") |
| 116 | + torch.cumsum(query_lens_tensor, |
| 117 | + dim=0, |
| 118 | + dtype=torch.int32, |
| 119 | + out=query_start_loc[1:]) |
| 120 | + torch.cumsum(kv_lens_tensor, |
| 121 | + dim=0, |
| 122 | + dtype=torch.int32, |
| 123 | + out=kv_start_loc[1:]) |
| 124 | + max_query_len = max(prefill_query_lens) |
| 125 | + max_kv_len = max(prefill_seq_lens) |
| 126 | + |
| 127 | + # for chunked-prefill |
| 128 | + if self.chunked_prefill: |
| 129 | + prefill_block_tables = make_tensor_with_pad( |
| 130 | + self.input_data.prefill_block_tables, |
| 131 | + pad=0, |
| 132 | + dtype=torch.int32, |
| 133 | + device="cpu", |
| 134 | + ) |
| 135 | + else: |
| 136 | + prefill_block_tables = None |
| 137 | + |
| 138 | + else: |
| 139 | + query_start_loc = None |
| 140 | + kv_start_loc = None |
| 141 | + max_query_len = None |
| 142 | + max_kv_len = None |
| 143 | + prefill_block_tables = None |
| 144 | + |
| 145 | + # metadata for decode |
| 146 | + if input_data.num_decode_tokens != 0: |
| 147 | + seq_lens_tensor = torch.tensor( |
| 148 | + input_data.seq_lens[input_data.num_prefills:], |
| 149 | + dtype=torch.int32, |
| 150 | + device="cpu", |
| 151 | + ) |
| 152 | + block_tables = make_tensor_with_pad( |
| 153 | + self.input_data.decode_block_tables, |
| 154 | + pad=0, |
| 155 | + dtype=torch.int32, |
| 156 | + device="cpu", |
| 157 | + ) |
| 158 | + else: |
| 159 | + block_tables = torch.tensor([]) |
| 160 | + seq_lens_tensor = torch.tensor( |
| 161 | + input_data.seq_lens[:input_data.num_prefills], |
| 162 | + dtype=torch.int32, |
| 163 | + device="cpu", |
| 164 | + ) |
| 165 | + |
| 166 | + # For multi-modal models |
| 167 | + placeholder_index_maps = None |
| 168 | + if len(input_data.multi_modal_inputs_list) != 0: |
| 169 | + placeholder_index_maps = { |
| 170 | + modality: placeholder_map.index_map() |
| 171 | + for modality, placeholder_map in |
| 172 | + input_data.multi_modal_placeholder_maps.items() |
| 173 | + } |
| 174 | + |
| 175 | + return CPUMLAMetadata( |
| 176 | + chunked_prefill=self.chunked_prefill, |
| 177 | + seq_lens=prefill_seq_lens, |
| 178 | + seq_lens_tensor=seq_lens_tensor, |
| 179 | + max_query_len=max_query_len, |
| 180 | + max_kv_len=max_kv_len, |
| 181 | + prefill_query_start_loc=query_start_loc, |
| 182 | + kv_start_loc=kv_start_loc, |
| 183 | + max_decode_seq_len=input_data.max_decode_seq_len, |
| 184 | + num_prefills=input_data.num_prefills, |
| 185 | + num_prefill_tokens=input_data.num_prefill_tokens, |
| 186 | + num_decode_tokens=input_data.num_decode_tokens, |
| 187 | + block_tables=block_tables, |
| 188 | + prefill_block_tables=prefill_block_tables, |
| 189 | + slot_mapping=slot_mapping, |
| 190 | + multi_modal_placeholder_index_maps=placeholder_index_maps, |
| 191 | + enable_kv_scales_calculation=False, |
| 192 | + input_positions=torch.tensor([self.input_data.input_positions])) |
| 193 | + |
| 194 | + |
| 195 | +class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]): |
| 196 | + |
| 197 | + def __init__( |
| 198 | + self, |
| 199 | + num_heads: int, |
| 200 | + head_size: int, |
| 201 | + scale: float, |
| 202 | + num_kv_heads: int, |
| 203 | + alibi_slopes: Optional[List[float]], |
| 204 | + sliding_window: Optional[int], |
| 205 | + kv_cache_dtype: str, |
| 206 | + blocksparse_params: Optional[Dict[str, Any]], |
| 207 | + logits_soft_cap: Optional[float], |
| 208 | + attn_type: str, |
| 209 | + kv_sharing_target_layer_name: Optional[str], |
| 210 | + # MLA Specific Arguments |
| 211 | + **mla_args) -> None: |
| 212 | + super().__init__(num_heads, head_size, scale, num_kv_heads, |
| 213 | + alibi_slopes, sliding_window, kv_cache_dtype, |
| 214 | + blocksparse_params, logits_soft_cap, attn_type, |
| 215 | + kv_sharing_target_layer_name, **mla_args) |
| 216 | + |
| 217 | + unsupported_features = [ |
| 218 | + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap |
| 219 | + ] |
| 220 | + if any(unsupported_features): |
| 221 | + raise NotImplementedError( |
| 222 | + "CPUMLAImpl does not support one of the following: " |
| 223 | + "alibi_slopes, sliding_window, blocksparse_params, " |
| 224 | + "logits_soft_cap") |
| 225 | + |
| 226 | + if attn_type != AttentionType.DECODER: |
| 227 | + raise NotImplementedError("Encoder self-attention and " |
| 228 | + "encoder/decoder cross-attention " |
| 229 | + "are not implemented for " |
| 230 | + "CPUMLAImpl") |
| 231 | + |
| 232 | + # states is implemented. |
| 233 | + if is_quantized_kv_cache(self.kv_cache_dtype): |
| 234 | + raise NotImplementedError( |
| 235 | + "CPUMLAImpl with FP8 KV cache not yet supported") |
| 236 | + |
| 237 | + def _forward_prefill( |
| 238 | + self, |
| 239 | + q: torch.Tensor, |
| 240 | + kv_c_normed: torch.Tensor, |
| 241 | + k_pe: torch.Tensor, |
| 242 | + kv_c_and_k_pe_cache: torch.Tensor, |
| 243 | + attn_metadata: CPUMLAMetadata, # type: ignore[override] |
| 244 | + ) -> torch.Tensor: |
| 245 | + |
| 246 | + prefill_metadata = attn_metadata.prefill_metadata |
| 247 | + assert prefill_metadata is not None |
| 248 | + |
| 249 | + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ |
| 250 | + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) |
| 251 | + k_nope, v = kv_nope\ |
| 252 | + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) |
| 253 | + |
| 254 | + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) |
| 255 | + |
| 256 | + # For MLA the v head dim is smaller than qk head dim so we pad out |
| 257 | + # v with 0s to match the qk head dim |
| 258 | + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], |
| 259 | + value=0) |
| 260 | + |
| 261 | + output = torch.empty_like(q) |
| 262 | + ipex_ops.varlen_attention( |
| 263 | + query=q, |
| 264 | + key=k, |
| 265 | + value=v_padded, |
| 266 | + out=output, |
| 267 | + seqlen_q=prefill_metadata.prefill_query_start_loc, |
| 268 | + seqlen_k=prefill_metadata.prefill_query_start_loc, |
| 269 | + max_seqlen_q=prefill_metadata.max_query_len, |
| 270 | + max_seqlen_k=prefill_metadata.max_query_len, |
| 271 | + pdropout=0.0, |
| 272 | + softmax_scale=self.scale, |
| 273 | + zero_tensors=False, |
| 274 | + is_causal=True, |
| 275 | + return_softmax=False, |
| 276 | + gen_=None, |
| 277 | + logits_soft_cap=0.0, |
| 278 | + window_size_left=-1, |
| 279 | + window_size_right=-1, |
| 280 | + alibi_slopes=None, |
| 281 | + ) |
| 282 | + |
| 283 | + # remove padding |
| 284 | + output = output.view(-1, self.num_heads, |
| 285 | + q.shape[-1])[..., :v.shape[-1]] |
| 286 | + return output.reshape(-1, self.num_heads * v.shape[-1]) |
| 287 | + |
| 288 | + def _forward_decode( |
| 289 | + self, |
| 290 | + q_nope: torch.Tensor, |
| 291 | + q_pe: torch.Tensor, |
| 292 | + kv_c_and_k_pe_cache: torch.Tensor, |
| 293 | + attn_metadata: CPUMLAMetadata, # type: ignore[override] |
| 294 | + ) -> torch.Tensor: |
| 295 | + assert kv_c_and_k_pe_cache.numel() > 0 |
| 296 | + |
| 297 | + decode_meta = attn_metadata.decode_metadata |
| 298 | + assert decode_meta is not None |
| 299 | + |
| 300 | + q = torch.cat([q_nope, q_pe], dim=-1) |
| 301 | + o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank) |
| 302 | + |
| 303 | + # Run MQA |
| 304 | + ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale, |
| 305 | + decode_meta.block_tables, |
| 306 | + decode_meta.seq_lens_tensor) |
| 307 | + return self._v_up_proj(o) |
0 commit comments