Commit a5dd03c

Revert "[V0 deprecation] Remove V0 CPU/XPU/TPU backends (vllm-project#20412)"
This reverts commit e202dd2.
1 parent c18b3b8

File tree

20 files changed: +5034 −46 lines changed

.buildkite/scripts/hardware_ci/run-cpu-test.sh

Lines changed: 4 additions & 4 deletions
@@ -66,10 +66,10 @@ function cpu_tests() {
     tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
 
   # Run AWQ test
-  # docker exec cpu-test-"$NUMA_NODE" bash -c "
-  #   set -e
-  #   VLLM_USE_V1=0 pytest -s -v \
-  #   tests/quantization/test_ipex_quant.py"
+  docker exec cpu-test-"$NUMA_NODE" bash -c "
+    set -e
+    VLLM_USE_V1=0 pytest -s -v \
+    tests/quantization/test_ipex_quant.py"
 
   # Run chunked-prefill and prefix-cache test
   docker exec cpu-test-"$NUMA_NODE" bash -c "

.buildkite/scripts/hardware_ci/run-xpu-test.sh

Lines changed: 2 additions & 0 deletions
@@ -26,5 +26,7 @@ docker run \
   --name "${container_name}" \
   "${image_name}" \
   sh -c '
+    VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
+    VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
 '

examples/online_serving/chart-helm/values.yaml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ image:
   # -- Image tag
   tag: "latest"
   # -- Container launch command
-  command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--enforce-eager", "--dtype", "bfloat16", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
+  command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
 
 # -- Container port
 containerPort: 8000

tests/kernels/attention/test_attention_selector.py

Lines changed: 7 additions & 10 deletions
@@ -36,8 +36,7 @@ def clear_cache():
 DEVICE_MLA_BLOCK_SIZES = {
     "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
     "hip": [16, 1],  # HIP requires special handling for block_size=1
-    # "cpu": [16]  # CPU uses fixed block size from test cases
-    "cpu": []  # FIXME(woosuk): Temporarily disable CPU tests
+    "cpu": [16]  # CPU uses fixed block size from test cases
 }
 
 
@@ -82,14 +81,14 @@ def test_env(
         m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
 
         if device == "cpu":
-            if not use_v1:
-                pytest.skip("CPU backend only supports V1")
-
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
                 backend = get_attn_backend(16, torch.float16, torch.float16,
                                            block_size, False)
-                assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+                if use_v1:
+                    assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+                else:
+                    assert backend.get_name() == "TORCH_SDPA"
 
         elif device == "hip":
             with patch("vllm.attention.selector.current_platform",
@@ -205,14 +204,12 @@ def test_fp32_fallback(
         m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
 
         if device == "cpu":
-            if not use_v1:
-                pytest.skip("CPU backend only supports V1")
-
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
                 backend = get_attn_backend(16, torch.float32, torch.float32,
                                            16, False)
-                assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
+                assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
+                        if use_v1 else "TORCH_SDPA")
 
         elif device == "cuda":
             with patch("vllm.attention.selector.current_platform",
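The restored branches of test_env now check both engine versions on CPU instead of skipping V0: V1 should select TORCH_SDPA_VLLM_V1 and V0 plain TORCH_SDPA. A condensed, standalone sketch of that check (an illustration, not part of the commit; it assumes a CPU-only environment with both backends importable, and it glosses over the selector-cache clearing that the real test fixture performs between parametrizations):

import os
from unittest.mock import patch

import torch

from vllm.attention.selector import get_attn_backend
from vllm.platforms.cpu import CpuPlatform

# Mirror the restored assertions for both engine versions.
for use_v1, expected in ((True, "TORCH_SDPA_VLLM_V1"), (False, "TORCH_SDPA")):
    os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"
    with patch("vllm.attention.selector.current_platform", CpuPlatform()):
        backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
        assert backend.get_name() == expected, backend.get_name()
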

vllm/attention/backends/cpu_mla.py

Lines changed: 307 additions & 0 deletions
@@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

import vllm._custom_ops as ops
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadataBuilder,
                                              AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder


class CPUMLABackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "CPU_MLA"

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        return MLACommonState

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        return CPUMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [576]


@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    # New for MLA
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor = None

    # required by MLACommonImpl
    is_profile_run: bool = False


class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        assert not self.chunked_prefill, \
            "chunked prefill is currently not supported"

    def prepare(self):
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        input_data = self.input_data
        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # metadata for prefill
        if input_data.num_prefills > 0:
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

            # for chunked-prefill
            if self.chunked_prefill:
                prefill_block_tables = make_tensor_with_pad(
                    self.input_data.prefill_block_tables,
                    pad=0,
                    dtype=torch.int32,
                    device="cpu",
                )
            else:
                prefill_block_tables = None

        else:
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None
            prefill_block_tables = None

        # metadata for decode
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:input_data.num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=input_data.num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            input_positions=torch.tensor([self.input_data.input_positions]))


class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")

        # states is implemented.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
            self,
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:

        prefill_metadata = attn_metadata.prefill_metadata
        assert prefill_metadata is not None

        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)

        output = torch.empty_like(q)
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=output,
            seqlen_q=prefill_metadata.prefill_query_start_loc,
            seqlen_k=prefill_metadata.prefill_query_start_loc,
            max_seqlen_q=prefill_metadata.max_query_len,
            max_seqlen_k=prefill_metadata.max_query_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
            window_size_left=-1,
            window_size_right=-1,
            alibi_slopes=None,
        )

        # remove padding
        output = output.view(-1, self.num_heads,
                             q.shape[-1])[..., :v.shape[-1]]
        return output.reshape(-1, self.num_heads * v.shape[-1])

    def _forward_decode(
            self,
            q_nope: torch.Tensor,
            q_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)

        # Run MQA
        ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
                                   decode_meta.block_tables,
                                   decode_meta.seq_lens_tensor)
        return self._v_up_proj(o)
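A small illustration of the restored backend's cache-layout contract: get_kv_cache_shape collapses the KV heads because MLA stores a single compressed latent per token, and the sole supported head size of 576 matches a DeepSeek-style split of a 512-wide latent plus a 64-wide rotary component (those two numbers are an assumption for the example, not stated in this file).

# Sketch only: shows the shape contract of CPUMLABackend, not a full backend setup.
from vllm.attention.backends.cpu_mla import CPUMLABackend

assert CPUMLABackend.get_supported_head_sizes() == [576]

shape = CPUMLABackend.get_kv_cache_shape(
    num_blocks=1024,   # hypothetical cache size
    block_size=16,
    num_kv_heads=1,    # assumed to be 1 for MLA (see the signature above)
    head_size=576,     # kv_lora_rank (512) + qk_rope_head_dim (64), assumed split
)
assert shape == (1024, 16, 576)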
