Commit 95e7aa4

wangxiyuan authored
[Platform] format platform to make it more clear (#610)
The platform module should only contain functions that are based on vLLM. This PR moves the unrelated functions to the right place to make the platform module clearer.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent b917361 commit 95e7aa4

File tree

3 files changed: +17 -33 lines changed


vllm_ascend/patch/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@
 # inside of the repo, and needs a common interface to destroy them, this patch add the interface of destroy
 # platform owned `CoordinatorGroup` to make sure all the CoordinateGroup can be properly destroyed
 # How:
-# Call platform method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup`
+# Call `vllm_ascend.distributed.parallel_state` method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup`
 # Related PR (if no, explain why): no related PR, we want add this ability into vllm
 # Future Plan:
 # Remove those patch when vllm merged them
@@ -73,7 +73,7 @@
 # call to the `stateless_init_torch_distributed_process_group`, to enable other platform which may support
 # stateless process group initialize method
 # How:
-# Call platform method `platform_has_backend_register` to judge if there is a stateless process group initialize
+# rewrite stateless_init_torch_distributed_process_group to judge if there is a stateless process group initialize
 # method and call platform method `platform_register_backend` to initialize them
 # Related PR (if no, explain why): no related PR, we want add this ability into vllm
 # Future Plan:
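
For context, here is a minimal sketch (not taken from this commit) of how patched functions such as the ones described above are typically installed over their vLLM counterparts. The upstream module paths vllm.distributed.parallel_state and vllm.distributed.utils are assumptions about vLLM's layout, not part of this diff; only the ascend_* function names come from the files changed here.

import vllm.distributed.parallel_state as parallel_state
import vllm.distributed.utils as dist_utils

from vllm_ascend.patch.platform.patch_common.patch_distributed import (
    ascend_destroy_model_parallel,
    ascend_stateless_init_torch_distributed_process_group)

# Swap the upstream implementations for the Ascend-aware ones at import time.
parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
dist_utils.stateless_init_torch_distributed_process_group = (
    ascend_stateless_init_torch_distributed_process_group)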

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 15 additions & 5 deletions
@@ -42,8 +42,9 @@ def ascend_destroy_model_parallel():
     if _DP:
         _DP.destroy()
     _DP = None
-    from vllm.platforms import current_platform
-    current_platform.destroy_platform_model_parallel()
+    from vllm_ascend.distributed.parallel_state import \
+        destory_ascend_model_parallel
+    destory_ascend_model_parallel()


 def ascend_stateless_init_torch_distributed_process_group(
@@ -100,7 +101,6 @@ def ascend_stateless_init_torch_distributed_process_group(
         group_rank,
         group_size,
     )
-    from vllm.platforms import current_platform
     if backend == "gloo":
         from torch.distributed.distributed_c10d import ProcessGroupGloo
         backend_class = ProcessGroupGloo(prefix_store,
@@ -120,8 +120,18 @@
                                          backend_options)
         backend_type = ProcessGroup.BackendType.NCCL
         device = torch.device("cuda")
-    elif current_platform.platform_has_backend_register():
-        current_platform.platform_register_backend()
+    elif backend == "hccl":
+        from torch.distributed import is_hccl_available
+        assert is_hccl_available()
+        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+        backend_options = ProcessGroupHCCL.Options()
+        backend_options._timeout = timeout
+        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        device = torch.device("npu")
+        backend_class._set_sequence_number_for_group()
+        backend_type = ProcessGroup.BackendType.CUSTOM
+        pg._register_backend(device, backend_type, backend_class)
         return pg
     else:
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
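
A hedged usage sketch of the patched initializer with the new "hccl" branch: the (host, port, rank, world_size, backend) signature is assumed to mirror vLLM's stateless_init_torch_distributed_process_group, and the address, port, and rank values below are illustrative only.

from vllm_ascend.patch.platform.patch_common.patch_distributed import \
    ascend_stateless_init_torch_distributed_process_group

# Rank 0 of a 2-rank stateless group; backend="hccl" routes into the new
# ProcessGroupHCCL branch shown in the hunk above and registers the backend
# against an NPU device before returning the ProcessGroup.
pg = ascend_stateless_init_torch_distributed_process_group(
    host="127.0.0.1",   # illustrative rendezvous address
    port=29500,         # illustrative port
    rank=0,
    world_size=2,
    backend="hccl",
)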

vllm_ascend/platform.py

Lines changed: 0 additions & 26 deletions
@@ -226,29 +226,3 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         model configuration.
         """
         return True
-
-    @classmethod
-    def destroy_platform_model_parallel(cls) -> None:
-        from vllm_ascend.distributed.parallel_state import \
-            destory_ascend_model_parallel
-        destory_ascend_model_parallel()
-
-    @classmethod
-    def platform_has_backend_register(cls) -> bool:
-        return True
-
-    @classmethod
-    def platform_register_backend(cls, pg, prefix_store, group_rank,
-                                  group_size, backend_options,
-                                  timeout) -> None:
-        from torch.distributed import ProcessGroup, is_hccl_available
-        assert is_hccl_available()
-        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
-        backend_options = ProcessGroupHCCL.Options()
-        backend_options._timeout = timeout
-        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        device = torch.device("npu")
-        backend_class._set_sequence_number_for_group()
-        backend_type = ProcessGroup.BackendType.CUSTOM
-        pg._register_backend(device, backend_type, backend_class)
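
The deleted destroy_platform_model_parallel helper delegated to vllm_ascend.distributed.parallel_state.destory_ascend_model_parallel (identifier spelled as in the diff). That function's body is not part of this commit; a minimal sketch, assuming it follows the same pattern as the `_DP.destroy()` code in ascend_destroy_model_parallel and using a hypothetical _EP coordinator group, might look like:

_EP = None  # hypothetical Ascend-owned coordinator group; name not from the diff

def destory_ascend_model_parallel():
    # Sketch only: release any coordinator groups created by vllm-ascend.
    global _EP
    if _EP:
        _EP.destroy()
    _EP = None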
