Commit 4d36693
[Refactor] Create a function util and cache the results for has_deepgemm, has_deepep, has_pplx (#20187)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent daec9de commit 4d36693
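
Every file in this commit now imports the availability checks from `vllm.utils` (`has_deep_ep`, `has_deep_gemm`, `has_pplx`), but the `vllm/utils` diff that defines them is not part of this excerpt. The sketch below is a minimal, assumed reconstruction of what such cached helpers plausibly look like, inferred from the commit title and from the `importlib.util.find_spec` checks they replace; the three public names come from the call sites in the diffs, while `_has_module` is a hypothetical helper and the real implementation may differ.

# Sketch only: assumed shape of the cached helpers added to vllm.utils.
# has_deep_ep / has_deep_gemm / has_pplx match the call sites below;
# _has_module is a hypothetical internal helper, not taken from the commit.
import importlib.util
from functools import cache


def _has_module(module_name: str) -> bool:
    """Return True if the given module can be found on this system."""
    return importlib.util.find_spec(module_name) is not None


@cache
def has_deep_ep() -> bool:
    """Whether the DeepEP kernels (deep_ep) are importable; result is cached."""
    return _has_module("deep_ep")


@cache
def has_deep_gemm() -> bool:
    """Whether the DeepGEMM kernels (deep_gemm) are importable; result is cached."""
    return _has_module("deep_gemm")


@cache
def has_pplx() -> bool:
    """Whether the pplx_kernels package is importable; result is cached."""
    return _has_module("pplx_kernels")

Because the result is cached, call sites can cheaply guard optional imports and pytest markers, e.g. `if has_deep_gemm(): import deep_gemm`, exactly as the diffs below do.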

File tree: 12 files changed, +61 −58 lines

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 7 additions & 13 deletions
@@ -6,7 +6,6 @@
 """

 import dataclasses
-import importlib
 from typing import Optional

 import pytest
@@ -21,38 +20,33 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep, has_deep_gemm

 from .utils import ProcessGroupInfo, parallel_launch

-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-
-try:
-    import deep_gemm
-    has_deep_gemm = True
-except ImportError:
-    has_deep_gemm = False
-
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
         DeepEPLLPrepareAndFinalize)

     from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

-if has_deep_gemm:
+if has_deep_gemm():
+    import deep_gemm
+
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts)
     from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
         DeepGemmExperts)

 requires_deep_ep = pytest.mark.skipif(
-    not has_deep_ep,
+    not has_deep_ep(),
     reason="Requires deep_ep kernels",
 )

 requires_deep_gemm = pytest.mark.skipif(
-    not has_deep_gemm,
+    not has_deep_gemm(),
     reason="Requires deep_gemm kernels",
 )

tests/kernels/moe/test_deepep_moe.py

Lines changed: 3 additions & 5 deletions
@@ -4,7 +4,6 @@
 """

 import dataclasses
-import importlib
 from typing import Optional, Union

 import pytest
@@ -22,12 +21,11 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
+from vllm.utils import has_deep_ep

 from .utils import ProcessGroupInfo, parallel_launch

-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
@@ -36,7 +34,7 @@
     from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

 requires_deep_ep = pytest.mark.skipif(
-    not has_deep_ep,
+    not has_deep_ep(),
     reason="Requires deep_ep kernels",
 )

vllm/distributed/device_communicators/all2all.py

Lines changed: 5 additions & 5 deletions
@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import importlib.util
 from typing import TYPE_CHECKING, Any

 import torch
 import torch.distributed as dist

 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
+from vllm.utils import has_deep_ep, has_pplx

 from .base_device_communicator import All2AllManagerBase, Cache

@@ -80,8 +80,8 @@ class PPLXAll2AllManager(All2AllManagerBase):
     """

     def __init__(self, cpu_group):
-        has_pplx = importlib.util.find_spec("pplx_kernels") is not None
-        assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
+        assert has_pplx(
+        ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
         super().__init__(cpu_group)

         if self.internode:
@@ -133,8 +133,8 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     """

     def __init__(self, cpu_group):
-        has_deepep = importlib.util.find_spec("deep_ep") is not None
-        assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
+        assert has_deep_ep(
+        ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
         super().__init__(cpu_group)
         self.handle_cache = Cache()

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 0 additions & 3 deletions
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import importlib.util
 from typing import Optional

 import torch
@@ -11,8 +10,6 @@

 logger = init_logger(__name__)

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-

 @triton.jit
 def _silu_mul_fp8_quant_deep_gemm(

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 5 additions & 7 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
-import importlib.util
 from typing import Optional

 import torch
@@ -12,14 +11,13 @@
     _moe_permute)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
-from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, per_token_group_quant_fp8)
-from vllm.utils import round_up
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
+from vllm.utils import has_deep_gemm, round_up

 logger = init_logger(__name__)

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-

 @functools.cache
 def deep_gemm_block_shape() -> list[int]:
@@ -41,7 +39,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
     gemm kernel. All of M, N, K and the quantization block_shape must be
     aligned by `dg.get_m_alignment_for_contiguous_layout()`.
     """
-    if not has_deep_gemm:
+    if not has_deep_gemm():
         logger.debug("DeepGemm disabled: deep_gemm not available.")
         return False

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 3 additions & 7 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import importlib
 from abc import abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass
@@ -32,20 +31,17 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import direct_register_custom_op
-
-has_pplx = importlib.util.find_spec("pplx_kernels") is not None
-has_deepep = importlib.util.find_spec("deep_ep") is not None
+from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx

 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
     from .fused_moe import TritonExperts, fused_experts
     from .modular_kernel import (FusedMoEModularKernel,
                                  FusedMoEPermuteExpertsUnpermute,
                                  FusedMoEPrepareAndFinalize)
-    if has_pplx:
+    if has_pplx():
         from .pplx_prepare_finalize import PplxPrepareAndFinalize
-    if has_deepep:
+    if has_deep_ep():
         from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
         from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
                                                  DeepEPLLPrepareAndFinalize)

vllm/model_executor/layers/fused_moe/utils.py

Lines changed: 1 addition & 1 deletion
@@ -104,4 +104,4 @@ def find_free_port():
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
         s.bind(('', 0))
         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        return s.getsockname()[1]
+        return s.getsockname()[1]

(The removed and re-added lines are textually identical; the change is most likely whitespace only, such as adding a missing newline at end of file.)

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 3 additions & 5 deletions
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import enum
-import importlib
 from enum import Enum
 from typing import Callable, Optional

@@ -29,13 +28,12 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-
-has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+from vllm.utils import has_pplx

 if current_platform.is_cuda_alike():
     from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
         BatchedPrepareAndFinalize)
-    if has_pplx:
+    if has_pplx():
         from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
             PplxPrepareAndFinalize)

@@ -577,7 +575,7 @@ def select_gemm_impl(self, prepare_finalize, moe):
                 use_batched_format=True,
             )

-        if has_pplx and isinstance(
+        if has_pplx() and isinstance(
                 prepare_finalize,
                 (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
             # no expert_map support in this case

vllm/model_executor/layers/quantization/deepgemm.py

Lines changed: 2 additions & 4 deletions
@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
-import importlib.util
 import logging

 import torch

 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, has_deep_gemm

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-if has_deep_gemm:
+if has_deep_gemm():
     import deep_gemm

 logger = logging.getLogger(__name__)

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 2 additions & 4 deletions
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import functools
-import importlib.util
 from typing import Any, Callable, Optional, Union

 import torch
@@ -38,13 +37,12 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils import has_deep_gemm

 ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = init_logger(__name__)

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-

 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
@@ -451,7 +449,7 @@ def __init__(self, quant_config: Fp8Config):
         # Check for DeepGemm support.
         self.allow_deep_gemm = False
         if envs.VLLM_USE_DEEP_GEMM:
-            if not has_deep_gemm:
+            if not has_deep_gemm():
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
                 logger.warning_once("Model is not block quantized. Not using "

0 commit comments

Comments
 (0)