Commit c8ea982

Update deprecated type hinting in platform, plugins, triton_utils, vllm_flash_attn (#18129)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent dc372b9 commit c8ea982

File tree

6 files changed: +18 −24 lines

pyproject.toml
vllm/platforms/cuda.py
vllm/platforms/interface.py
vllm/platforms/rocm.py
vllm/platforms/tpu.py
vllm/plugins/__init__.py

pyproject.toml

Lines changed: 0 additions & 5 deletions
@@ -78,13 +78,8 @@ exclude = [
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
-"vllm/platforms/**/*.py" = ["UP006", "UP035"]
-"vllm/plugins/**/*.py" = ["UP006", "UP035"]
 "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
 "vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
-"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
-"vllm/triton_utils/**/*.py" = ["UP006", "UP035"]
-"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"]
 "vllm/worker/**/*.py" = ["UP006", "UP035"]
 "vllm/utils.py" = ["UP006", "UP035"]
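
The deleted entries were per-file ignores for Ruff's pyupgrade rules UP006 (non-PEP-585 annotation) and UP035 (deprecated import), so the matching directories are now linted like the rest of the tree. As a rough sketch of what the two rules flag (rule intents paraphrased, not verbatim Ruff output):

from typing import Dict, List  # UP035: importing these aliases from typing is deprecated


def group_ids(ids: List[int]) -> Dict[str, List[int]]:  # UP006: use builtin generics
    ...


# The PEP 585 spelling this commit applies throughout:
def group_ids_modern(ids: list[int]) -> dict[str, list[int]]:
    ...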

vllm/platforms/cuda.py

Lines changed: 6 additions & 7 deletions
@@ -5,8 +5,7 @@
 
 import os
 from functools import wraps
-from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
-                    Union)
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 import torch
 from typing_extensions import ParamSpec
@@ -56,7 +55,7 @@ class CudaPlatformBase(Platform):
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
 
     @property
-    def supported_dtypes(self) -> List[torch.dtype]:
+    def supported_dtypes(self) -> list[torch.dtype]:
         if self.has_device_capability(80):
             # Ampere and Hopper or later NVIDIA GPUs.
             return [torch.bfloat16, torch.float16, torch.float32]
@@ -93,7 +92,7 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
         return True
 
     @classmethod
-    def is_fully_connected(cls, device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, device_ids: list[int]) -> bool:
         raise NotImplementedError
 
     @classmethod
@@ -335,7 +334,7 @@ def get_device_capability(cls,
     @with_nvml_context
     def has_device_capability(
         cls,
-        capability: Union[Tuple[int, int], int],
+        capability: Union[tuple[int, int], int],
         device_id: int = 0,
     ) -> bool:
         try:
@@ -365,7 +364,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
 
     @classmethod
     @with_nvml_context
-    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
         """
         query if the set of gpus are fully connected by nvlink (1 hop)
         """
@@ -430,7 +429,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
         return device_props.total_memory
 
     @classmethod
-    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
         logger.exception(
             "NVLink detection not possible, as context support was"
             " not found. Assuming no NVLink available.")

vllm/platforms/interface.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 import platform
 import random
 from platform import uname
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, NamedTuple, Optional, Union
 
 import numpy as np
 import torch
@@ -200,7 +200,7 @@ def get_device_capability(
     @classmethod
     def has_device_capability(
         cls,
-        capability: Union[Tuple[int, int], int],
+        capability: Union[tuple[int, int], int],
         device_id: int = 0,
     ) -> bool:
         """
@@ -362,7 +362,7 @@ def get_punica_wrapper(cls) -> str:
         raise NotImplementedError
 
     @classmethod
-    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
+    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
         """
         Return the platform specific values for (-inf, inf)
         """

vllm/platforms/rocm.py

Lines changed: 5 additions & 5 deletions
@@ -2,7 +2,7 @@
 
 import os
 from functools import cache, lru_cache, wraps
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
@@ -35,15 +35,15 @@
     logger.warning("Failed to import from vllm._rocm_C with %r", e)
 
 # Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS: List[str] = []
+_ROCM_UNSUPPORTED_MODELS: list[str] = []
 
 # Models partially supported by ROCm.
 # Architecture -> Reason.
 _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                     "Triton flash attention. For half-precision SWA support, "
                     "please use CK flash attention by setting "
                     "`VLLM_USE_TRITON_FLASH_ATTN=0`")
-_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
+_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
     "Qwen2ForCausalLM":
     _ROCM_SWA_REASON,
     "MistralForCausalLM":
@@ -58,7 +58,7 @@
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
 }
-_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = {
+_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
     "0x74a0": "AMD_Instinct_MI300A",
     "0x74a1": "AMD_Instinct_MI300X",
     "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
@@ -203,7 +203,7 @@ def get_device_capability(cls,
 
     @staticmethod
     @with_amdsmi_context
-    def is_fully_connected(physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(physical_device_ids: list[int]) -> bool:
         """
         Query if the set of gpus are fully connected by xgmi (1 hop)
         """

vllm/platforms/tpu.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import torch
 from tpu_info import device
@@ -73,7 +73,7 @@ def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
 
     @classmethod
-    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
+    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
         return torch.finfo(dtype).min, torch.finfo(dtype).max
 
     @classmethod
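
The TPU override returns the dtype's finite extremes from torch.finfo rather than IEEE ±inf, so the values stay representable in the target dtype. For example:

import torch

# torch.finfo exposes the representable range of a floating-point dtype.
lo, hi = torch.finfo(torch.bfloat16).min, torch.finfo(torch.bfloat16).max
print(lo, hi)  # roughly -3.39e38 and 3.39e38 for bfloat16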

vllm/plugins/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 
 import logging
 import os
-from typing import Callable, Dict
+from typing import Callable
 
 import torch
 
@@ -14,7 +14,7 @@
 plugins_loaded = False
 
 
-def load_plugins_by_group(group: str) -> Dict[str, Callable]:
+def load_plugins_by_group(group: str) -> dict[str, Callable]:
     import sys
     if sys.version_info < (3, 10):
         from importlib_metadata import entry_points
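
load_plugins_by_group builds a dict[str, Callable] from installed entry points, using the importlib_metadata backport before Python 3.10, the release in which importlib.metadata gained the group= selection shown here. A stripped-down sketch of the same pattern (illustrative; it omits vLLM's logging and activation logic):

import sys
from typing import Callable

if sys.version_info < (3, 10):
    from importlib_metadata import entry_points  # third-party backport
else:
    from importlib.metadata import entry_points


def load_by_group(group: str) -> dict[str, Callable]:
    """Map each entry-point name in the group to its loaded callable."""
    return {ep.name: ep.load() for ep in entry_points(group=group)}


# Usage, assuming a plugin group name such as vLLM's "vllm.general_plugins":
# plugins = load_by_group("vllm.general_plugins")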
