Commit b69781f

[Hardware][Intel GPU] Add v1 Intel GPU support with Flash attention backend. (#19560)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
1 parent 0bceac9 commit b69781f

File tree: 10 files changed (+394, −43 lines)


.buildkite/scripts/hardware_ci/run-xpu-test.sh

Lines changed: 1 addition & 0 deletions
@@ -28,4 +28,5 @@ docker run \
 sh -c '
   VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
   VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
+  VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
 '
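
For reference, a rough offline-inference equivalent of the new V1 invocation, assuming an XPU build of vLLM is installed; the prompt and sampling settings below are illustrative, not part of the commit:

# Select the V1 engine before vLLM reads its environment flags.
import os
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          block_size=64,       # the V1 XPU path uses 64-token blocks
          enforce_eager=True)  # CUDA graphs are not supported on XPU
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)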

docker/Dockerfile.xpu

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \
     if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
 
 ENV VLLM_TARGET_DEVICE=xpu
+ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
 
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=bind,source=.git,target=.git \
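
The spawn default matters because spawn re-imports the entry script in every worker process. A minimal sketch of an import-safe multi-GPU script under that assumption (model and tensor-parallel size are illustrative):

import os

# Mirror the Dockerfile default when running outside the container.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

from vllm import LLM


def main() -> None:
    llm = LLM(model="facebook/opt-125m",
              tensor_parallel_size=2,
              enforce_eager=True)
    print(llm.generate(["ping"])[0].outputs[0].text)


if __name__ == "__main__":  # required by the spawn start method
    main()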

requirements/xpu.txt

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ setuptools>=77.0.3,<80.0.0
 wheel
 jinja2>=3.1.6
 datasets # for benchmark scripts
+numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
 
 torch==2.7.0+xpu
 torchaudio

vllm/_ipex_ops.py

Lines changed: 105 additions & 0 deletions
@@ -228,6 +228,111 @@ def reshape_and_cache(
         ipex.llm.modules.PagedAttention.reshape_and_cache(
             key, value, key_cache, value_cache, slot_mapping)
 
+    @staticmethod
+    def reshape_and_cache_flash(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
+        k_scale_float: float = 1.0,
+        v_scale_float: float = 1.0,
+    ) -> None:
+        assert kv_cache_dtype == "auto"
+        # TODO: support FP8 kv cache.
+        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+            key, value, key_cache, value_cache, slot_mapping)
+
+    @staticmethod
+    def flash_attn_varlen_func(
+        out: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        softmax_scale: float,
+        causal: bool,
+        block_table: torch.Tensor,
+        alibi_slopes: Optional[torch.Tensor],
+        window_size: Optional[list[int]] = None,
+        softcap: Optional[float] = 0.0,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
+        # The following parameters are not used in ipex kernel currently,
+        # we keep API compatible to CUDA's.
+        scheduler_metadata=None,
+        fa_version: int = 2,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+    ):
+        if cu_seqlens_k is None:
+            # cu_seqlens_k is not used in ipex kernel.
+            cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
+            cu_seqlens_k = torch.cat([
+                torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
+                cu_seqlens_k
+            ]).to(torch.int32)
+
+        real_window_size: tuple[int, int]
+        if window_size is None:
+            real_window_size = (-1, -1)
+        else:
+            assert len(window_size) == 2
+            real_window_size = (window_size[0], window_size[1])
+        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            q.contiguous(),
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            softmax_scale,
+            causal,
+            block_table,
+            alibi_slopes,
+            softcap=softcap,
+            window_size_left=real_window_size[0],
+            window_size_right=real_window_size[1],
+            k_scale=1.0,
+            v_scale=1.0,
+        )
+
+    @staticmethod
+    def get_scheduler_metadata(
+        batch_size,
+        max_seqlen_q,
+        max_seqlen_k,
+        num_heads_q,
+        num_heads_kv,
+        headdim,
+        cache_seqlens: torch.Tensor,
+        qkv_dtype=torch.bfloat16,
+        headdim_v=None,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_k_new: Optional[torch.Tensor] = None,
+        cache_leftpad: Optional[torch.Tensor] = None,
+        page_size: Optional[int] = None,
+        max_seqlen_k_new=0,
+        causal=False,
+        window_size=(-1, -1),  # -1 means infinite context window
+        has_softcap=False,
+        num_splits=0,  # Can be tuned for speed
+        pack_gqa=None,  # Can be tuned for speed
+        sm_margin=0,  # Can be tuned if some SMs are used for communication
+    ) -> None:
+        logger.warning_once(
+            "get_scheduler_metadata is not implemented for ipex_ops, "
+            "returning None.")
+        return None
+
     @staticmethod
     def copy_blocks(key_caches: list[torch.Tensor],
                     value_caches: list[torch.Tensor],
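
The cu_seqlens_k fallback above simply prepends a zero to the running sum of per-sequence KV lengths. A small, self-contained illustration of that arithmetic (the values are made up):

import torch

# Number of cached KV tokens for each sequence in the batch.
seqused_k = torch.tensor([5, 3, 7], dtype=torch.int32)

# Prepend 0 and take the cumulative sum to get per-sequence offsets,
# matching the wrapper's behaviour when cu_seqlens_k is None.
cu_seqlens_k = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(seqused_k, dim=0)
]).to(torch.int32)

print(cu_seqlens_k.tolist())  # [0, 5, 8, 15]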

vllm/attention/utils/fa_utils.py

Lines changed: 14 additions & 1 deletion
@@ -4,13 +4,27 @@
 
 from vllm import envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
+if current_platform.is_cuda():
+    from vllm import _custom_ops as ops
+    reshape_and_cache_flash = ops.reshape_and_cache_flash
+    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                      get_scheduler_metadata)
+elif current_platform.is_xpu():
+    from vllm._ipex_ops import ipex_ops as ops
+    reshape_and_cache_flash = ops.reshape_and_cache_flash
+    flash_attn_varlen_func = ops.flash_attn_varlen_func
+    get_scheduler_metadata = ops.get_scheduler_metadata
+
 
 def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
     # import here to avoid circular dependencies
     from vllm.platforms import current_platform
+    if current_platform.is_xpu():
+        return 2
     try:
         from vllm.vllm_flash_attn.flash_attn_interface import (
             fa_version_unsupported_reason, is_fa_version_supported)
@@ -50,6 +64,5 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
 
 
 def flash_attn_supports_fp8() -> bool:
-    from vllm.platforms import current_platform
     return get_flash_attn_version() == 3 and \
         current_platform.get_device_capability().major == 9
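
With this shim in place, callers can import the resolved kernels from fa_utils instead of a CUDA-specific module. A minimal sketch of such a call site, assuming a vLLM build that includes this change and a CUDA or XPU device:

from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                           get_flash_attn_version,
                                           reshape_and_cache_flash)

# On XPU this short-circuits to 2; on CUDA it probes vllm_flash_attn.
print("flash-attention version in use:", get_flash_attn_version())

# flash_attn_varlen_func / reshape_and_cache_flash now point at either the
# CUDA kernels or the ipex_ops wrappers, depending on current_platform.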

vllm/executor/ray_distributed_executor.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
 
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
-        if envs.VLLM_USE_V1:
+        if envs.VLLM_USE_V1 and not current_platform.is_xpu():
             # V1 uses SPMD worker and compiled DAG
             os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
             os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

vllm/platforms/xpu.py

Lines changed: 70 additions & 34 deletions
@@ -1,18 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import os
 from typing import TYPE_CHECKING, Optional
 
 import torch
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None
 
 logger = init_logger(__name__)
@@ -35,8 +38,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              use_mla: bool) -> str:
         if selected_backend != _Backend.IPEX:
             logger.info("Cannot use %s backend on XPU.", selected_backend)
-        logger.info("Using IPEX attention backend.")
-        return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
+        use_v1 = envs.VLLM_USE_V1
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+        else:
+            logger.info("Using IPEX attention backend.")
+            return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
 
     @classmethod
     def get_device_capability(
@@ -67,25 +75,27 @@ def inference_mode(cls):
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
+        # in V1(or with ipex chunked prefill) block_size is 64
         if cache_config and cache_config.block_size is None:
-            cache_config.block_size = 16
-
-        # check and update model config
-        model_config = vllm_config.model_config
-        if model_config.dtype == torch.bfloat16:
-            bf16_supported = cls.device_support_bf16()
-            if not bf16_supported:
+            if envs.VLLM_USE_V1:
+                cache_config.block_size = 64
+            else:
+                cache_config.block_size = 16
+
+        # Instances created using VllmConfig() typically have model_config as
+        # None by default. The modification involves adding a check to prevent
+        # potential null exceptions check and update model config.
+        if vllm_config.model_config is not None:
+            model_config = vllm_config.model_config
+            if model_config.dtype == torch.bfloat16:
+                bf16_supported = cls.device_support_bf16()
+                if not bf16_supported:
+                    model_config.dtype = torch.float16
+            if not model_config.enforce_eager:
                 logger.warning(
-                    "bfloat16 is only supported on Intel Data Center GPU, "
-                    "Intel Arc GPU is not supported yet. Your device is %s,"
-                    " which is not supported. will fallback to float16",
-                    cls.get_device_name())
-                model_config.dtype = torch.float16
-        if not model_config.enforce_eager:
-            logger.warning(
-                "CUDA graph is not supported on XPU, fallback to the eager "
-                "mode.")
-            model_config.enforce_eager = True
+                    "CUDA graph is not supported on XPU, fallback to the eager "
+                    "mode.")
+                model_config.enforce_eager = True
 
         if vllm_config.speculative_config is not None:
             raise NotImplementedError(
@@ -96,21 +106,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
-        if parallel_config.worker_cls == "auto":
+        if envs.VLLM_USE_V1:
+            parallel_config.worker_cls =\
+                "vllm.v1.worker.xpu_worker.XPUWorker"
+        else:
             parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
 
         if parallel_config.distributed_executor_backend is None:
-            parallel_config.distributed_executor_backend = "ray"
+            if parallel_config.world_size > 1:
+                parallel_config.distributed_executor_backend = "ray"
+            else:
+                parallel_config.distributed_executor_backend = "uni"
         elif parallel_config.distributed_executor_backend == "mp":
             # FIXME(kunshang):
             # spawn needs calling `if __name__ == '__main__':``
             # fork is not supported for xpu start new process.
-            logger.error(
-                "Both start methods (spawn and fork) have issue "
-                "on XPU if you use mp backend, setting it to ray instead.")
-            parallel_config.distributed_executor_backend = "ray"
-
-        elif parallel_config.distributed_executor_backend != "ray":
+            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
+                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+                logger.warning(
+                    "Please use spawn as start method if you want to use mp.")
+        elif parallel_config.distributed_executor_backend != "ray" and \
+                parallel_config.distributed_executor_backend != "uni":
             logger.warning(
                 "%s is not supported on XPU, fallback to ray distributed"
                 " executor backend.",
@@ -142,15 +158,35 @@ def get_current_memory_usage(cls,
     @classmethod
     def device_support_bf16(cls) -> bool:
        device_name = cls.get_device_name().lower()
-        if device_name.count("arc") > 0:
+        if cls.is_client_gpu_a770():
+            logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
+                           " fallback to float16")
             return False
-        elif device_name.count("data center gpu") > 0:
-            return True
         else:
-            logger.warning("Unknown device name %s, always use float16",
-                           device_name)
-            return False
+            logger.info(
+                "Device name %s supports bfloat16. Please file an issue "
+                "if you encounter any accuracy problems with bfloat16.",
+                device_name)
+            return True
+
+    @classmethod
+    def is_data_center_gpu(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        return device_name.count("data center gpu") > 0
+
+    @classmethod
+    def is_client_gpu_a770(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        return device_name.count("a770") > 0
 
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        return True
+
+    @classmethod
+    def device_count(cls) -> int:
+        return torch.xpu.device_count()
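
Paraphrasing the hooks above, the defaults applied on XPU now depend on the engine version and world size. A small standalone sketch of that decision table (illustrative only, not vLLM's actual code path):

def xpu_defaults(use_v1: bool, world_size: int) -> dict:
    # Illustrative summary of check_and_update_config's XPU defaults.
    return {
        "block_size": 64 if use_v1 else 16,
        "worker_cls": ("vllm.v1.worker.xpu_worker.XPUWorker" if use_v1 else
                       "vllm.worker.xpu_worker.XPUWorker"),
        # Single-process executor for one device, Ray once tensor or
        # pipeline parallelism spans several devices.
        "distributed_executor_backend": "ray" if world_size > 1 else "uni",
    }


print(xpu_defaults(use_v1=True, world_size=1))
print(xpu_defaults(use_v1=True, world_size=2))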

vllm/v1/attention/backends/flash_attn.py

Lines changed: 5 additions & 7 deletions
@@ -14,10 +14,12 @@
 from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)
+                                           flash_attn_varlen_func,
+                                           get_flash_attn_version,
+                                           get_scheduler_metadata,
+                                           reshape_and_cache_flash)
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
@@ -28,10 +30,6 @@
 if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
-if current_platform.is_cuda():
-    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                      get_scheduler_metadata)
-
 logger = init_logger(__name__)
 
 
@@ -443,7 +441,7 @@ def forward(
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
+            reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,