Commit a3e4e85

[XPU][CI] enhance xpu test support (#20652)
Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
Co-authored-by: zhenwei-intel <zhenweiliu@habana.ai>
1 parent eb58f59 commit a3e4e85

File tree

5 files changed: +18 −12 lines


tests/conftest.py

Lines changed: 3 additions & 2 deletions
@@ -759,7 +759,8 @@ class VllmRunner:
     - `trust_remote_code`: Set to `True` instead of `False` for convenience.
     - `seed`: Set to `0` instead of `None` for test reproducibility.
     - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
-    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
+    - `block_size`: To reduce memory usage, set default to `64` if on XPU
+      devices, otherwise default to `16`.
     - `enable_chunked_prefill`: Set to `False` instead of `None` for
       test reproducibility.
     - `enforce_eager`: Set to `False` to test CUDA graph.

@@ -777,7 +778,7 @@ def __init__(
         dtype: str = "auto",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
-        block_size: int = 16,
+        block_size: int = 16 if not torch.xpu.is_available() else 64,
         enable_chunked_prefill: Optional[bool] = False,
         swap_space: int = 4,
         enforce_eager: Optional[bool] = False,
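
The new default ties the test runner's KV-cache block size to the platform: XPU runs get 64, everything else keeps 16. Because the `torch.xpu.is_available()` call sits in a parameter default, it is evaluated once when `conftest.py` is imported, which is harmless here since device availability does not change within a test session. A minimal sketch of the selection logic (the helper name is hypothetical; the commit inlines the expression in the signature):

```python
import torch

def default_block_size() -> int:
    # XPU kernels in vLLM want a larger KV-cache block size, hence 64 there.
    # The hasattr guard keeps this safe on PyTorch builds without XPU support.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return 64
    return 16
```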

vllm/distributed/device_communicators/xpu_communicator.py

Lines changed: 3 additions & 0 deletions
@@ -53,3 +53,6 @@ def gather(self,
         else:
             output_tensor = None
         return output_tensor
+
+    def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
+        dist.broadcast(input_, src=src, group=self.device_group)
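
The added `broadcast` mirrors the communicator's other collectives: it delegates to `torch.distributed.broadcast`, which overwrites the tensor in place on every rank other than `src` and returns nothing, hence the `None` annotation. A hedged usage sketch, assuming an initialized process group and a PyTorch build with XPU support (the real method passes the communicator's `self.device_group` rather than the default group):

```python
import torch
import torch.distributed as dist

def demo_broadcast(local_rank: int) -> torch.Tensor:
    # Rank 0 holds the payload; every other rank allocates an empty buffer.
    if dist.get_rank() == 0:
        t = torch.arange(4, device=f"xpu:{local_rank}")
    else:
        t = torch.empty(4, dtype=torch.int64, device=f"xpu:{local_rank}")
    dist.broadcast(t, src=0)  # mutates t in place on non-source ranks
    return t
```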

vllm/distributed/parallel_state.py

Lines changed: 6 additions & 4 deletions
@@ -240,6 +240,8 @@ def __init__(
 
         if current_platform.is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
+        elif current_platform.is_xpu():
+            self.device = torch.device(f"xpu:{local_rank}")
         elif current_platform.is_out_of_tree():
             self.device = torch.device(
                 f"{current_platform.device_name}:{local_rank}")

@@ -1317,13 +1319,13 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
 
 def is_global_first_rank() -> bool:
     """
-    Check if the current process is the first rank globally across all 
+    Check if the current process is the first rank globally across all
     parallelism strategies (PP, TP, DP, EP, etc.).
-    
+
     Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
     or `get_pp_group().is_first_rank`, this function checks the global rank
     across all parallelism dimensions.
-    
+
     Returns:
         bool: True if this is the global first rank (rank 0), False otherwise.
         Returns True if distributed is not initialized (single process).

@@ -1352,7 +1354,7 @@ def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
 
     Args:
         pg: The process group to analyze
-    
+
     Returns:
         int: The total number of nodes
     """

vllm/platforms/xpu.py

Lines changed: 5 additions & 5 deletions
@@ -91,6 +91,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         # FIXME: Temporarily forcing eager mode
         # remove after t.compile support stabilizes.
+
         if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
                 and not vllm_config.model_config.enforce_eager):
             from vllm.config import CompilationLevel

@@ -111,9 +112,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "mode.")
             model_config.enforce_eager = True
 
-        if vllm_config.device_config is not None:
-            assert vllm_config.device_config.device_type == "xpu"
-
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"

@@ -131,8 +129,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
             logger.warning(
                 "Please use spawn as start method if you want to use mp.")
-        elif parallel_config.distributed_executor_backend != "ray" and \
-                parallel_config.distributed_executor_backend != "uni":
+        elif (parallel_config.distributed_executor_backend != "ray"
+              and parallel_config.distributed_executor_backend != "uni"
+              and parallel_config.distributed_executor_backend
+              != "external_launcher"):
             logger.warning(
                 "%s is not supported on XPU, fallback to ray distributed"
                 " executor backend.",

vllm/v1/worker/xpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def __init__(
         self.cascade_attn_enabled = False
 
     def _init_device_properties(self) -> None:
-        pass
+        self.num_sms = None
 
     def _sync_device(self) -> None:
         torch.xpu.synchronize()
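
Replacing `pass` with `self.num_sms = None` makes the attribute exist on XPU. In the CUDA model runner, `_init_device_properties` records the device's streaming-multiprocessor count, which heuristics such as cascade attention consult; XPU offers no equivalent query, so pinning the attribute to `None` is safer than leaving it undefined. A rough sketch of the contrast (class layout simplified, names illustrative):

```python
import torch

class _CudaRunnerSketch:
    def _init_device_properties(self) -> None:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        self.num_sms = props.multi_processor_count  # SM count, CUDA only

class _XpuRunnerSketch:
    def _init_device_properties(self) -> None:
        self.num_sms = None  # no SM-count query on XPU; attribute still set
```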
