Commit a0389e0

[UT][intel GPU] use current_platform instead of device hardcode in v1 tests (#20169)
Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
1 parent 3be8d31 commit a0389e0

10 files changed (+44 / -26 lines)
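
Across the touched tests, the common pattern is replacing hardcoded "cuda" strings with vllm.platforms.current_platform, which resolves the device type and device count for whichever backend vLLM was built against. A minimal sketch of that pattern (the tensor shape and helper name are illustrative, not part of the diff):

import torch

from vllm.platforms import current_platform

# device_type is e.g. "cuda" on NVIDIA builds or "xpu" on Intel GPU builds.
DEVICE = current_platform.device_type

def make_dummy_logits(batch_size: int = 4, vocab_size: int = 100) -> torch.Tensor:
    # Allocate on the platform-selected device instead of hardcoding "cuda".
    return torch.randn(batch_size, vocab_size, device=DEVICE)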

tests/conftest.py

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,6 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import cuda_device_count_stateless

 logger = init_logger(__name__)

@@ -1094,7 +1093,8 @@ def num_gpus_available():
     """Get number of GPUs without initializing the CUDA context
     in current process."""

-    return cuda_device_count_stateless()
+    from vllm.platforms import current_platform
+    return current_platform.device_count()


 temp_dir = tempfile.gettempdir()
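
Tests that consume this helper should not need any change; a rough usage sketch, assuming num_gpus_available is exposed to tests as a pytest fixture (as its placement in conftest.py suggests) and with an illustrative test body:

import pytest

def test_needs_two_devices(num_gpus_available):
    # The count now comes from current_platform.device_count(), so the same
    # skip works on CUDA, ROCm, and Intel XPU builds.
    if num_gpus_available < 2:
        pytest.skip("Requires at least two devices on the current platform.")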

tests/v1/sample/test_rejection_sampler.py

Lines changed: 5 additions & 4 deletions
@@ -6,12 +6,13 @@
 import torch
 import torch.nn.functional as F

+from vllm.platforms import current_platform
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                               RejectionSampler)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

-DEVICE = "cuda"
+DEVICE = current_platform.device_type


 @pytest.fixture
@@ -21,7 +22,7 @@ def rejection_sampler():

 def create_logits_tensor(output_token_ids: list[list[int]],
                          vocab_size: int = 100) -> torch.Tensor:
-    """Helper function to create logits tensor that
+    """Helper function to create logits tensor that
     will produce desired token ids on argmax"""
     token_ids = [tokens[:-1] for tokens in output_token_ids]
     num_total_tokens = sum(len(tokens) for tokens in token_ids)
@@ -41,8 +42,8 @@ def create_sampling_metadata(
     top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
 ) -> SamplingMetadata:
-    """Create a v1 sampling metadata object with all_greedy set
-    to the given value. Either all greedy or all random sampling
+    """Create a v1 sampling metadata object with all_greedy set
+    to the given value. Either all greedy or all random sampling
     is used.
     """
     generators = generators or {}

tests/v1/sample/test_topk_topp_sampler.py

Lines changed: 6 additions & 5 deletions
@@ -2,25 +2,26 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 import torch
-from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 from torch import Generator

 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
                                                   is_flashinfer_available)

-DEVICE = "cuda"
+DEVICE = current_platform.device_type

 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024

 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
+if is_flashinfer_available:
+    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs


 @pytest.fixture(autouse=True)
 def reset_default_device():
     """
-    Explicitly set the default device, which can affect subsequent tests.
+    Explicitly set the default device, which can affect subsequent tests.
     Adding this fixture helps avoid this problem.
     """
     original_device = torch.get_default_device()
@@ -58,8 +59,8 @@ def test_flashinfer_sampler():
     This test verifies that the FlashInfer top-k and top-p sampling
     implementation produces the same results as the Python implementation.

-    NOTE: FlashInfer did not directly expose an interface for fused top-k and
-    top-p prob renorm (it did provide fused sampling but we cannot compare
+    NOTE: FlashInfer did not directly expose an interface for fused top-k and
+    top-p prob renorm (it did provide fused sampling but we cannot compare
     sampling results due to randomness), so we will compare the probability
     renormed consequently by top-k and then top-p of FlashInfer implementation.
     '''

tests/v1/spec_decode/test_eagle.py

Lines changed: 13 additions & 10 deletions
@@ -10,6 +10,7 @@
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.platforms import current_platform
 from vllm.v1.spec_decode.eagle import EagleProposer

 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
         num_speculative_tokens=k,
     )

-    vllm_config = VllmConfig(model_config=model_config,
-                             cache_config=CacheConfig(),
-                             speculative_config=speculative_config,
-                             device_config=DeviceConfig(device="cuda"),
-                             parallel_config=ParallelConfig(),
-                             load_config=LoadConfig(),
-                             scheduler_config=SchedulerConfig())
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())

-    return EagleProposer(vllm_config=vllm_config, device='cuda')
+    return EagleProposer(vllm_config=vllm_config,
+                         device=current_platform.device_type)


 def test_prepare_inputs():
@@ -59,7 +62,7 @@ def test_prepare_inputs():
         a, a + 1, ..., a + b - n2 - 1,
         a + b, a + b + 1, ..., a + b + c - n3 - 1]
     """
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)

     # a = 4, b = 7, c = 5
     # n1 = 1, n2 = 3, n3 = 2
@@ -198,7 +201,7 @@ class _TargetModelStub(LlamaForCausalLM):
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
 def test_propose(num_speculative_tokens):
     # Use GPU device
-    device = torch.device('cuda')
+    device = torch.device(current_platform.device_type)

     # Setup test parameters
     batch_size = 2

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 3 additions & 1 deletion
@@ -8,6 +8,7 @@
 import pytest
 import torch

+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.pool.metadata import PoolingMetadata
@@ -19,7 +20,8 @@
 NUM_OUTPUT_TOKENS = 20
 MAX_PROMPT_SIZE = 100
 CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    f"{current_platform.device_type}:{i}"
+    for i in range(min(current_platform.device_count(), 2))
 ]
 MAX_NUM_PROMPT_TOKENS = 64

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from vllm.attention import Attention
1010
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
1111
SchedulerConfig, VllmConfig, set_current_vllm_config)
12+
from vllm.platforms import current_platform
1213
from vllm.sampling_params import SamplingParams
1314
from vllm.utils import GiB_bytes
1415
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
@@ -23,7 +24,7 @@
2324

2425
BLOCK_SIZE = 16
2526
NUM_BLOCKS = 10
26-
DEVICE = "cuda"
27+
DEVICE = current_platform.device_type
2728

2829

2930
def initialize_kv_cache(runner: GPUModelRunner):

vllm/platforms/cuda.py

Lines changed: 5 additions & 1 deletion
@@ -18,7 +18,7 @@
 import vllm._C  # noqa
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import import_pynvml
+from vllm.utils import cuda_device_count_stateless, import_pynvml

 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

@@ -401,6 +401,10 @@ def stateless_init_device_torch_dist_pg(
         pg._register_backend(device, backend_type, backend_class)
         return pg

+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()
+

 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/rocm.py

Lines changed: 5 additions & 0 deletions
@@ -12,6 +12,7 @@

 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import cuda_device_count_stateless

 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

@@ -446,3 +447,7 @@ def stateless_init_device_torch_dist_pg(

         pg._register_backend(device, backend_type, backend_class)
         return pg
+
+    @classmethod
+    def device_count(cls) -> int:
+        return cuda_device_count_stateless()

vllm/platforms/xpu.py

Lines changed: 1 addition & 1 deletion
@@ -189,4 +189,4 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:

     @classmethod
     def device_count(cls) -> int:
-        return torch.xpu.device_count()
+        return torch.xpu.device_count()
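
With device_count() defined on the CUDA, ROCm, and XPU platform classes, callers can query the number of devices without touching torch.cuda directly. A small sketch of such a caller (the helper name is illustrative, not part of the diff):

from vllm.platforms import current_platform

def describe_devices() -> str:
    # CUDA and ROCm report via cuda_device_count_stateless(); XPU reports
    # via torch.xpu.device_count(); the caller does not need to care which.
    count = current_platform.device_count()
    return f"{count} x {current_platform.device_type}"

if __name__ == "__main__":
    print(describe_devices())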

vllm/v1/attention/backends/mla/common.py

Lines changed: 2 additions & 1 deletion
@@ -217,7 +217,8 @@
     is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
-    from flash_attn import flash_attn_varlen_func
+    if current_platform.is_rocm():
+        from flash_attn import flash_attn_varlen_func
     is_vllm_fa = False

 if TYPE_CHECKING:
