diff --git a/tests/ops/test_rotary_embedding.py b/tests/ops/test_rotary_embedding.py index 2d5ec18da..35127b17f 100644 --- a/tests/ops/test_rotary_embedding.py +++ b/tests/ops/test_rotary_embedding.py @@ -10,7 +10,12 @@ import torch import torch.nn as nn -import vllm_ascend.platform # noqa: F401 +from vllm_ascend.utils import try_register_lib + +try_register_lib( + "vllm_ascend.vllm_ascend_C", + exc_info= + "Warning: Failed to register custom ops, all custom ops will be disabled.") # Only Neox style true scenario is supported for now IS_NEOX_STYLE = [True] diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 0c2a00afb..ec7ff296b 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -22,11 +22,16 @@ from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) -from vllm_ascend.platform import CUSTOM_OP_ENABLED +from vllm_ascend.utils import try_register_lib def custom_rotary_embedding_enabled(query, neox_style, head_size): - return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and CUSTOM_OP_ENABLED + try_register_lib( + "vllm_ascend.vllm_ascend_C", + exc_info= + "Warning: Failed to register custom ops, all custom ops will be disabled." + ) + return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 def rope_forward_oot( diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 2d8834b1b..2c104ad47 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -15,8 +15,6 @@ # This file is a part of the vllm-ascend project. # -import logging -import os from typing import TYPE_CHECKING, Optional, Tuple import torch @@ -27,18 +25,6 @@ from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes -CUSTOM_OP_ENABLED = False -try: - # register custom ops into torch_library here - import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 - -except ImportError: - logging.warning( - "Warning: Failed to register custom ops, all custom ops will be disabled" - ) -else: - CUSTOM_OP_ENABLED = True - if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig from vllm.utils import FlexibleArgumentParser @@ -47,8 +33,6 @@ VllmConfig = None FlexibleArgumentParser = None -os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1" - class NPUPlatform(Platform): diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index cd83faed4..17340b338 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -41,7 +41,7 @@ ASCEND_QUATIZATION_METHOD = "ascend" -def try_register_lib(lib_name: str, lib_info: str = ""): +def try_register_lib(lib_name: str, lib_info: str = "", exc_info: str = ""): import importlib import importlib.util try: @@ -51,6 +51,8 @@ def try_register_lib(lib_name: str, lib_info: str = ""): if lib_info: logger.info(lib_info) except Exception: + if exc_info: + logger.info(exc_info) pass