Commit 289199f

[Core] Use platform-agnostic device control for DP engine core (#17245)
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent: b9fd0d7

File tree: 4 files changed (+30, -39 lines)

vllm/platforms/cuda.py

Lines changed: 4 additions & 22 deletions

@@ -34,24 +34,6 @@
 torch.backends.cuda.enable_cudnn_sdp(False)


-def device_id_to_physical_device_id(device_id: int) -> int:
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
-        if device_ids == [""]:
-            msg = (
-                "CUDA_VISIBLE_DEVICES is set to empty string, which means"
-                " GPU support is disabled. If you are using ray, please unset"
-                " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
-                " worker/actor. "
-                "Check https://github.com/vllm-project/vllm/issues/8402 for"
-                " more information.")
-            raise RuntimeError(msg)
-        physical_device_id = device_ids[device_id]
-        return int(physical_device_id)
-    else:
-        return device_id
-
-
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:

     @wraps(fn)
@@ -338,7 +320,7 @@ def get_device_capability(cls,
                               device_id: int = 0
                               ) -> Optional[DeviceCapability]:
         try:
-            physical_device_id = device_id_to_physical_device_id(device_id)
+            physical_device_id = cls.device_id_to_physical_device_id(device_id)
             handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
             major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
             return DeviceCapability(major=major, minor=minor)
@@ -360,20 +342,20 @@ def has_device_capability(
     @classmethod
     @with_nvml_context
     def get_device_name(cls, device_id: int = 0) -> str:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
         return cls._get_physical_device_name(physical_device_id)

     @classmethod
     @with_nvml_context
     def get_device_uuid(cls, device_id: int = 0) -> str:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
         handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
         return pynvml.nvmlDeviceGetUUID(handle)

     @classmethod
     @with_nvml_context
     def get_device_total_memory(cls, device_id: int = 0) -> int:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
         handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
         return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
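All four NVML-backed lookups now resolve logical device indices through the classmethod inherited from Platform instead of the deleted module-level helper. A minimal sketch of the mapping that helper performs, assuming the control variable is CUDA_VISIBLE_DEVICES as on the CUDA platform (the standalone resolve function is illustrative, not part of the diff):

import os

# Minimal sketch: logical -> physical mapping via the control env var.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

def resolve(device_id: int) -> int:
    ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    return int(ids[device_id])

assert resolve(0) == 2  # logical device 0 is physical GPU 2
assert resolve(1) == 3  # logical device 1 is physical GPU 3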

vllm/platforms/interface.py

Lines changed: 19 additions & 0 deletions

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import enum
+import os
 import platform
 import random
 from platform import uname
@@ -161,6 +162,24 @@ def is_cuda_alike(self) -> bool:
     def is_sleep_mode_available(self) -> bool:
         return self._enum == PlatformEnum.CUDA

+    @classmethod
+    def device_id_to_physical_device_id(cls, device_id: int):
+        if cls.device_control_env_var in os.environ:
+            device_ids = os.environ[cls.device_control_env_var].split(",")
+            if device_ids == [""]:
+                msg = (f"{cls.device_control_env_var} is set to empty string, "
+                       "which means current platform support is disabled. If "
+                       "you are using ray, please unset the environment "
+                       f"variable `{cls.device_control_env_var}` inside the "
+                       "worker/actor. Check "
+                       "https://github.com/vllm-project/vllm/issues/8402 for "
+                       "more information.")
+                raise RuntimeError(msg)
+            physical_device_id = device_ids[device_id]
+            return int(physical_device_id)
+        else:
+            return device_id
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
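The new base-class helper keys off the platform's device_control_env_var attribute rather than a hard-coded CUDA_VISIBLE_DEVICES, so any platform that declares its own visibility variable inherits the mapping for free. A minimal sketch of how a subclass is expected to hook in; the NpuLikePlatform class, the FOO_VISIBLE_DEVICES name, and the base-class default shown here are hypothetical, not part of this commit:

import os


class Platform:
    # Assumed to exist on vLLM's Platform base class already; this commit
    # only adds the classmethod that reads it. The default is illustrative.
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int):
        # Condensed version of the method added in the diff above.
        if cls.device_control_env_var in os.environ:
            device_ids = os.environ[cls.device_control_env_var].split(",")
            if device_ids == [""]:
                raise RuntimeError(
                    f"{cls.device_control_env_var} is set to empty string")
            return int(device_ids[device_id])
        return device_id


class NpuLikePlatform(Platform):
    # Hypothetical accelerator that declares its own visibility variable.
    device_control_env_var = "FOO_VISIBLE_DEVICES"


os.environ["FOO_VISIBLE_DEVICES"] = "4,5"
print(NpuLikePlatform.device_id_to_physical_device_id(1))  # -> 5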

vllm/platforms/rocm.py

Lines changed: 1 addition & 10 deletions

@@ -95,15 +95,6 @@ def wrapper(*args, **kwargs):
     return wrapper


-def device_id_to_physical_device_id(device_id: int) -> int:
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
-        physical_device_id = device_ids[device_id]
-        return int(physical_device_id)
-    else:
-        return device_id
-
-
 @cache
 def on_mi250_mi300() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
@@ -238,7 +229,7 @@ def is_fully_connected(physical_device_ids: List[int]) -> bool:
     @with_amdsmi_context
     @lru_cache(maxsize=8)
     def get_device_name(cls, device_id: int = 0) -> str:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
         handle = amdsmi_get_processor_handles()[physical_device_id]
         asic_info = amdsmi_get_gpu_asic_info(handle)
         device_name: str = asic_info["device_id"]
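One behavioral side effect worth noting: the deleted ROCm copy never guarded against an empty control variable, while the inherited base-class helper does. A tiny sketch of the edge case, assuming ROCm keeps CUDA_VISIBLE_DEVICES as its device_control_env_var:

import os

# The old ROCm helper would hit int('') and raise a bare ValueError here;
# the inherited helper raises a descriptive RuntimeError instead.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
print(device_ids)          # ['']
print(device_ids == [""])  # True -> inherited helper raises RuntimeError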

vllm/v1/engine/core.py

Lines changed: 6 additions & 7 deletions

@@ -622,13 +622,12 @@ def __init__(
         assert 0 <= local_dp_rank <= dp_rank < dp_size

         from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            from vllm.platforms.cuda import device_id_to_physical_device_id
-            tp_size = vllm_config.parallel_config.tensor_parallel_size
-            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
-                str(device_id_to_physical_device_id(i))
-                for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
-                               tp_size))
+        device_control_env_var = current_platform.device_control_env_var
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        os.environ[device_control_env_var] = ",".join(
+            str(current_platform.device_id_to_physical_device_id(i))
+            for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
+                           tp_size))

         self.local_dp_rank = local_dp_rank
         self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
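Each data-parallel engine core now claims a contiguous block of tp_size logical devices and publishes it through whatever control variable the current platform declares, instead of being gated on is_cuda_alike(). A worked example of the range arithmetic; the rank and size values are made up, and when the control variable is not already set, logical and physical indices coincide, so the slice is written verbatim:

# Worked example of the visibility slice computed above (values hypothetical).
local_dp_rank = 1
tp_size = 2
logical_ids = range(local_dp_rank * tp_size, (local_dp_rank + 1) * tp_size)
visible = ",".join(str(i) for i in logical_ids)
print(visible)  # "2,3" -> e.g. CUDA_VISIBLE_DEVICES="2,3" on CUDA; other
# platforms write the same slice to their own device_control_env_var.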
