-import torch
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from vllm/model_executor/models/qwen2_vl.py
+# This file is a part of the vllm-ascend project.
+
 import vllm
 import vllm.distributed
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import (Backend, PrefixStore,
-                                                 _get_default_timeout,
-                                                 is_nccl_available)
-from torch.distributed.rendezvous import rendezvous
 from vllm.config import ParallelConfig

+from vllm_ascend.patch.platform.patch_0_8_4.patch_distributed import (
+    ascend_destroy_model_parallel,
+    ascend_stateless_init_torch_distributed_process_group,
+    parallel_config_get_dp_port)

-def ascend_destroy_model_parallel():
-    """Set the groups to none and destroy them."""
-    from vllm.distributed.parallel_state import _DP, _PP, _TP
-    if _TP:
-        _TP.destroy()
-    _TP = None
-
-    if _PP:
-        _PP.destroy()
-    _PP = None
-
-    if _DP:
-        _DP.destroy()
-    _DP = None
-    from vllm.platforms import current_platform
-    current_platform.destroy_platform_model_parallel()
-
-
-def ascend_stateless_init_torch_distributed_process_group(
-        host: str, port: int, rank: int, world_size: int,
-        backend: str) -> ProcessGroup:
-    """
-    A replacement for `torch.distributed.init_process_group` that does not
-    pollute the global state. The created ProcessGroup object can be used for
-    some operations such as `allreduce`, because it does not depend on the
-    global rank. However, some operations such as `broadcast` cannot be used
-    because it depends on the global rank.
-
-    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.
-
-    This function is useful when we are not sure about the total number of
-    processes in the process group. For example, we may have process
-    1, 2, ..., 8 who want to communicate, and process 9 might be the same
-    process as process 1, or it might be a different process; process 10
-    might be the same process as process 5, or it might be a different process.
-    In this case, how can we reliably form a communication channel within
-    process 9 and 10, without affecting the communication channel within
-    process 1, 2, ..., 8?
-
-    One possible solution is to figure out if process 9 and 10 are the same
-    as process 1 and 5 beforehand, and then form a communication channel
-    based on the information, adjusting the ranks and world_size etc. However,
-    figuring out the information is not always easy, and it will interfere
-    with the main communication channel.
-
-    Our solution is to always form a communication channel with process 1, 2,
-    ..., 8, and then use this function to form another communication channel
-    with process 9 and 10. This way, regardless of whether process 9 and 10
-    are the same as process 1 and 5, the main communication channel is
-    always formed with process 1, 2, ..., 8, and the additional communication
-    channel is formed with process 9 and 10.
-    """
-    init_method = f"tcp://{host}:{port}"
-    backend = Backend(backend)  # it is basically string
-    timeout = _get_default_timeout(backend)
-
-    store, rank, world_size = next(
-        rendezvous(init_method, rank, world_size, timeout=timeout))
-    store.set_timeout(timeout)
-
-    group_rank = rank
-    group_size = world_size
-
-    # Use a PrefixStore to avoid accidental overrides of keys used by
-    # different systems (e.g. RPC) in case the store is multi-tenant.
-    prefix_store = PrefixStore(init_method, store)
-
-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        group_size,
-    )
-    from vllm.platforms import current_platform
-    if backend == "gloo":
-        from torch.distributed.distributed_c10d import ProcessGroupGloo
-        backend_class = ProcessGroupGloo(prefix_store,
-                                         group_rank,
-                                         group_size,
-                                         timeout=timeout)
-        backend_type = ProcessGroup.BackendType.GLOO
-        device = torch.device("cpu")
-    elif backend == "nccl":
-        assert is_nccl_available()
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-    elif current_platform.platform_has_backend_register():
-        current_platform.platform_register_backend()
-        return pg
-    else:
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
-
-    pg._set_default_backend(backend_type)
-    backend_class._set_sequence_number_for_group()
-
-    pg._register_backend(device, backend_type, backend_class)
-
-    return pg
-
-
-def parallel_config_get_dp_port(self) -> int:
-    """
-    We might need to initialize process groups in multiple
-    processes that is related to data parallelism,
-    e.g. both in the worker and in the engine, which
-    can live in different processes. To avoid port conflicts, we
-    increment the port number each time we need to initialize a
-    new process group related to data parallelism.
-    """
-    answer = self.data_parallel_master_port
-    self.data_parallel_master_port += 1
-    import os
-
-    # NOTE: Get port from envs directly when using torchrun
-    port = int(os.environ.get("MASTER_PORT", answer))  # type: ignore
-    return port
-
-
+# For details of these patches, please refer to vllm_ascend/patch/platform/patch_0_8_4/patch_distributed.py
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
 vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
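
With the function bodies moved out, this file now only rebinds the three vLLM entry points to their Ascend replacements. The snippet below is a minimal sanity-check sketch, not part of the commit: it assumes the patch file shown in the diff has already been imported (that import is what performs the assignments), and it only verifies that the monkey-patched attributes point at the functions re-exported from patch_distributed.py.

# Sanity-check sketch (illustrative only, not part of the commit).
import vllm
import vllm.distributed
from vllm.config import ParallelConfig

from vllm_ascend.patch.platform.patch_0_8_4.patch_distributed import (
    ascend_destroy_model_parallel,
    ascend_stateless_init_torch_distributed_process_group,
    parallel_config_get_dp_port)

# These checks hold only after the patch file above has been imported,
# since that module-level code performs the rebinding.
assert (vllm.distributed.parallel_state.destroy_model_parallel
        is ascend_destroy_model_parallel)
assert (vllm.distributed.stateless_init_torch_distributed_process_group
        is ascend_stateless_init_torch_distributed_process_group)
assert ParallelConfig.get_next_dp_init_port is parallel_config_get_dp_port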