
Commit 837e37a

zhuo97 and wangxiaoxin (A) authored and committed

Fix the device error when using ray as vllm-ascend backend (vllm-project#884)

1. Remove RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES
2. Add lazy init for vllm_ascend_C

Signed-off-by: zhuo97 <1103045176@qq.com>
Signed-off-by: wangxiaoxin (A) <wangxiaoxin7@huawei.com>

1 parent 9f9dfde · commit 837e37a

File tree: 7 files changed (+40, −20 lines)

examples/offline_multi_step_custom_ops.py
Lines changed: 0 additions & 3 deletions

@@ -19,9 +19,6 @@
 
 from vllm import LLM, SamplingParams
 
-import vllm_ascend.platform as pf
-
-pf.CUSTOM_OP_ENABLED = True # set True for custom Ops of Multi-Step.
 prompts = [
     "Hello, my name is",
     "The president of the United States is",

tests/singlecard/ops/test_rotary_embedding.py
Lines changed: 3 additions & 1 deletion

@@ -10,7 +10,9 @@
 import torch
 import torch.nn as nn
 
-import vllm_ascend.platform # noqa: F401
+from vllm_ascend.utils import enable_custom_op
+
+enable_custom_op()
 
 # Only Neox style true scenario is supported for now
 IS_NEOX_STYLE = [True]

vllm_ascend/attention/attention.py
Lines changed: 2 additions & 2 deletions

@@ -36,7 +36,7 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.cache import concat_and_cache_mla
-from vllm_ascend.platform import CUSTOM_OP_ENABLED
+from vllm_ascend.utils import enable_custom_op
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)

@@ -462,7 +462,7 @@ def advance_step(self,
         for i in range(num_queries):
             self.seq_lens[i] += 1
         self.max_decode_seq_len = max(self.seq_lens)
-        if CUSTOM_OP_ENABLED:
+        if enable_custom_op():
             # advance a step on NPU for existing inputs for a multi-step runner if custom ops is enabled
             torch.ops._C.advance_step_flashattn_ascendc(
                 num_seqs=num_seqs,

vllm_ascend/ops/rotary_embedding.py
Lines changed: 3 additions & 2 deletions

@@ -22,11 +22,12 @@
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
-from vllm_ascend.platform import CUSTOM_OP_ENABLED
+from vllm_ascend.utils import enable_custom_op
 
 
 def custom_rotary_embedding_enabled(query, neox_style, head_size):
-    return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and CUSTOM_OP_ENABLED
+    return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
+    )
 
 
 def rope_forward_oot(
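For orientation, a hypothetical call to the predicate above; the tensor shape and head size are illustrative, and the result is True only when every condition holds, including a successful vllm_ascend_C import:

# Hypothetical usage of custom_rotary_embedding_enabled (illustrative values).
import torch

query = torch.zeros(4, 64, dtype=torch.float16)
# float16 dtype, neox style, head_size divisible by 32, extension loaded:
use_custom = custom_rotary_embedding_enabled(query, neox_style=True, head_size=64)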

vllm_ascend/platform.py
Lines changed: 0 additions & 12 deletions

@@ -16,7 +16,6 @@
 #
 
 import gc
-import logging
 import os
 from datetime import timedelta
 from typing import TYPE_CHECKING, Optional, Tuple

@@ -32,16 +31,6 @@
 from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config
 from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes
 
-CUSTOM_OP_ENABLED = False
-try:
-    # register custom ops into torch_library here
-    import vllm_ascend.vllm_ascend_C  # type: ignore # noqa: F401
-    CUSTOM_OP_ENABLED = True
-except ImportError as e:
-    logging.warning(
-        "Failed to import 'vllm_ascend.vllm_ascend_C': %s. All custom ops will be disabled. ",
-        e)
-
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
     from vllm.utils import FlexibleArgumentParser

@@ -50,7 +39,6 @@
     VllmConfig = None
     FlexibleArgumentParser = None
 
-os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
 os.environ["ACL_OP_INIT_MODE"] = ascend_envs.VLLM_ASCEND_ACL_OP_INIT_MODE
vllm_ascend/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353

5454
ASCEND_QUATIZATION_METHOD = "ascend"
5555

56+
CUSTOM_OP_ENABLED = None
57+
5658

5759
def try_register_lib(lib_name: str, lib_info: str = ""):
5860
import importlib
@@ -67,6 +69,31 @@ def try_register_lib(lib_name: str, lib_info: str = ""):
6769
pass
6870

6971

72+
def enable_custom_op():
73+
"""
74+
Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
75+
Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
76+
"""
77+
global CUSTOM_OP_ENABLED
78+
79+
if CUSTOM_OP_ENABLED is not None:
80+
return CUSTOM_OP_ENABLED
81+
82+
else:
83+
try:
84+
# register custom ops into torch_library here
85+
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
86+
CUSTOM_OP_ENABLED = True
87+
88+
except ImportError:
89+
CUSTOM_OP_ENABLED = False
90+
logger.warning(
91+
"Warning: Failed to register custom ops, all custom ops will be disabled"
92+
)
93+
94+
return CUSTOM_OP_ENABLED
95+
96+
7097
def find_hccl_library() -> str:
7198
"""
7299
We either use the library file specified by the `HCCL_SO_PATH`
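One design note on enable_custom_op(): initializing CUSTOM_OP_ENABLED to None distinguishes "not yet attempted" from "attempted and failed", so the import runs at most once and every later call returns the cached flag. A small sanity check of that memoization, assuming vllm-ascend is installed:

# The helper memoizes its result, so gating hot paths on it is cheap.
from vllm_ascend.utils import enable_custom_op

first = enable_custom_op()   # attempts the vllm_ascend_C import once
second = enable_custom_op()  # returns the cached flag, no re-import
assert first == second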

vllm_ascend/worker/worker_v1.py
Lines changed: 5 additions & 0 deletions

@@ -117,6 +117,11 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
         allocator = CaMemAllocator.get_instance()
         allocator.wake_up(tags=tags)
 
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
     def init_device(self):
         if self.device_config.device.type == "npu":
             self.device = torch.device(f"npu:{self.local_rank_across_dp}")
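Hedged note on the worker change: initialize_cache only records the KV-cache block counts that the engine determines after memory profiling onto the worker's cache config. A hypothetical caller-side snippet, assuming `worker` is an already-constructed worker instance from this file (block counts illustrative):

# Hypothetical engine-side call; 4096/512 are made-up block counts.
worker.initialize_cache(num_gpu_blocks=4096, num_cpu_blocks=512)
assert worker.cache_config.num_gpu_blocks == 4096
assert worker.cache_config.num_cpu_blocks == 512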
