Commit b665443

umeiko (子阳梅) authored and committed
first stage support of eagle 1
Signed-off-by: umeiko <umeko@stu.xmu.edu.cn>
1 parent 15592c0 commit b665443

6 files changed: +712 −178 lines changed

vllm_ascend/envs.py

Lines changed: 9 additions & 0 deletions
@@ -111,6 +111,15 @@
     # 1: enable moe_all2all_buffer.
     "MOE_ALL2ALL_BUFFER":
     lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
+    # VLLM_ASCEND_ACL_OP_INIT_MODE:
+    # 0: default, normal init.
+    # 1: delay init until aclops are launched.
+    # 2: forbid aclop init and launch.
+    # Find more details at https://gitee.com/ascend/pytorch/pulls/18094
+    # We default this var to `1` in vllm-ascend to avoid a segment fault when
+    # `pin_memory` is enabled while creating a tensor using `torch.tensor`.
+    "VLLM_ASCEND_ACL_OP_INIT_MODE":
+    lambda: os.getenv("VLLM_ASCEND_ACL_OP_INIT_MODE", '1'),
     # Some models are optimized by vllm ascend. While in some case, e.g. rlhf
     # training, the optimized model may not be suitable. In this case, set this
     # value to False to disable the optimized model.
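Because each entry in this registry is a zero-argument lambda, the environment is consulted at access time rather than once at import time. Below is a minimal, self-contained sketch of that pattern using the new variable; the `env_variables` dict is an illustrative stand-in, not the module's actual wiring.

```python
import os

# Illustrative stand-in for the lazy registry in vllm_ascend/envs.py:
# each value is a zero-argument lambda, evaluated on every access.
env_variables = {
    "VLLM_ASCEND_ACL_OP_INIT_MODE":
    lambda: os.getenv("VLLM_ASCEND_ACL_OP_INIT_MODE", '1'),
}

# Overriding before first use takes effect; note the value remains a str.
os.environ["VLLM_ASCEND_ACL_OP_INIT_MODE"] = "2"  # forbid aclop init/launch
print(env_variables["VLLM_ASCEND_ACL_OP_INIT_MODE"]())  # -> "2"
```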

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 113 additions & 73 deletions
@@ -17,16 +17,17 @@
 # Adapted from vllm/model_executor/models/qwen2_vl.py
 # This file is a part of the vllm-ascend project.
 
+import torch
 import torch
 import vllm
 import vllm.distributed
 import vllm.envs as envs
 from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import (Backend, PrefixStore,
+                                                _get_default_timeout,
+                                                is_nccl_available)
+from torch.distributed.rendezvous import rendezvous
 from vllm.config import ParallelConfig
-from vllm.distributed.utils import \
-    stateless_init_torch_distributed_process_group
-
-from vllm_ascend.utils import NullHandle, is_310p
 
 
 def ascend_destroy_model_parallel():
@@ -48,6 +49,112 @@ def ascend_destroy_model_parallel():
     destory_ascend_model_parallel()
 
 
+def stateless_init_torch_distributed_process_group(
+        host: str, port: int, rank: int, world_size: int,
+        backend: str) -> ProcessGroup:
+    """
+    A replacement for `torch.distributed.init_process_group` that does not
+    pollute the global state. The created ProcessGroup object can be used for
+    some operations such as `allreduce`, because it does not depend on the
+    global rank. However, some operations such as `broadcast` cannot be used
+    because it depends on the global rank.
+
+    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.
+
+    This function is useful when we are not sure about the total number of
+    processes in the process group. For example, we may have process
+    1, 2, ..., 8 who want to communicate, and process 9 might be the same
+    process as process 1, or it might be a different process; process 10
+    might be the same process as process 5, or it might be a different process.
+    In this case, how can we reliably form a communication channel within
+    process 9 and 10, without affecting the communication channel within
+    process 1, 2, ..., 8?
+
+    One possible solution is to figure out if process 9 and 10 are the same
+    as process 1 and 5 beforehand, and then form a communication channel
+    based on the information, adjusting the ranks and world_size etc. However,
+    figuring out the information is not always easy, and it will interfere
+    with the main communication channel.
+
+    Our solution is to always form a communication channel with process 1, 2,
+    ..., 8, and then use this function to form another communication channel
+    with process 9 and 10. This way, regardless of whether process 9 and 10
+    are the same as process 1 and 5, the main communication channel is
+    always formed with process 1, 2, ..., 8, and the additional communication
+    channel is formed with process 9 and 10.
+    """
+    init_method = f"tcp://{host}:{port}"
+    backend = Backend(backend)  # it is basically a string
+    timeout = _get_default_timeout(backend)
+
+    store, rank, world_size = next(
+        rendezvous(init_method, rank, world_size, timeout=timeout))
+    store.set_timeout(timeout)
+
+    group_rank = rank
+    group_size = world_size
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    prefix_store = PrefixStore(init_method, store)
+
+    # TODO(Yizhou): The reason we need to set options while vllm does not
+    # seems to be related to the version of PyTorch. In the latest version,
+    # there is no need to set options. While in the older version, 2.5.1
+    # specifically, we need to set options.
+    options = ProcessGroup.Options(backend=backend)
+    pg: ProcessGroup = ProcessGroup(
+        prefix_store,
+        group_rank,
+        group_size,
+        options,
+    )
+    if backend == "gloo":
+        from torch.distributed.distributed_c10d import ProcessGroupGloo
+        backend_class = ProcessGroupGloo(prefix_store,
+                                         group_rank,
+                                         group_size,
+                                         timeout=timeout)
+        backend_type = ProcessGroup.BackendType.GLOO
+        device = torch.device("cpu")
+    elif backend == "nccl":
+        assert is_nccl_available()
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+    elif backend == "hccl":
+        from torch.distributed import is_hccl_available
+        assert is_hccl_available()
+        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+        backend_options = ProcessGroupHCCL.Options()
+        backend_options._timeout = timeout
+        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        device = torch.device("npu")
+        backend_class._set_sequence_number_for_group()
+        backend_type = ProcessGroup.BackendType.CUSTOM
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
+    else:
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+
+    # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
+    # implemented in the 2.5.1 version of PyTorch. But we need to set it
+    # after the latest version is released.
+    # pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+
+    return pg
+
+
 def parallel_config_get_dp_port(self) -> int:
     """
     We might need to initialize process groups in multiple
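A hypothetical usage sketch for the `stateless_init_torch_distributed_process_group` added in the hunk above: two processes form a gloo side channel without touching the default process group. Host, port, and ranks are placeholder values that both processes must agree on.

```python
import torch

# On process A (a second process would pass rank=1). The returned group
# supports rank-independent collectives such as all_reduce, but not
# broadcast, as the docstring above notes.
pg = stateless_init_torch_distributed_process_group(
    host="127.0.0.1", port=29600, rank=0, world_size=2, backend="gloo")

t = torch.ones(4)
torch.distributed.all_reduce(t, group=pg)  # reduces only over the side channel
```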
@@ -65,7 +172,7 @@ def parallel_config_get_dp_port(self) -> int:
     return port
 
 
-def stateless_init_dp_group(self) -> "ProcessGroup":
+def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
     # TODO(Yizhou): Currently we have to set the backend to gloo
     # because in vllm.config.ParallelConfig.has_unfinished_dp the
     # device is set to cpu. We need to fix this in the future.
@@ -83,71 +190,4 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
 
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
-ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
-
-
-def communication_adaptation_310p():
-
-    def broadcast310p(tensor, src, group=None, async_op=False):
-        rank = torch.distributed.get_rank(group)
-        world_size = torch.distributed.get_world_size(group)
-        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
-        tensor_list[rank] = tensor
-        torch.distributed.all_gather(tensor_list, tensor, group=group)
-        tensor[...] = tensor_list[src]
-        if async_op:
-            return NullHandle()
-        else:
-            return None
-
-    torch.distributed.broadcast = broadcast310p
-    torch.distributed.distributed_c10d.broadcast = broadcast310p
-
-    def all_reduce_wrapper_310p(fn):
-
-        def all_reduce(
-            tensor,
-            op=torch.distributed.ReduceOp.SUM,
-            group=None,
-            async_op=False,
-        ):
-            if tensor.dtype != torch.int64:
-                return fn(tensor, op, group, async_op)
-            rank = torch.distributed.get_rank(group)
-            world_size = torch.distributed.get_world_size(group)
-            tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
-            tensor_list[rank] = tensor
-            torch.distributed.all_gather(tensor_list, tensor, group=group)
-            if op == torch.distributed.ReduceOp.SUM:
-                return torch.stack(tensor_list).sum(0)
-            elif op == torch.distributed.ReduceOp.MAX:
-                return torch.tensor(
-                    torch.stack(tensor_list).cpu().numpy().max(0),
-                    device=tensor.device,
-                )
-            else:
-                raise RuntimeError(f"not implement op {op}")
-
-        return all_reduce
-
-    torch.distributed.all_reduce = all_reduce_wrapper_310p(
-        torch.distributed.all_reduce)
-    torch.distributed.distributed_c10d.all_reduce = all_reduce_wrapper_310p(
-        torch.distributed.distributed_c10d.all_reduce)
-
-    def reduce_scatter_310p(output_tensor, input_tensor, group=None):
-        rank = torch.distributed.get_rank(group)
-        world_size = torch.distributed.get_world_size(group)
-        torch.distributed.all_reduce(input_tensor,
-                                     torch.distributed.ReduceOp.SUM,
-                                     group,
-                                     async_op=False)
-        interval = input_tensor.shape[0] // world_size
-        output_tensor[:] = input_tensor[rank * interval:(rank + 1) * interval]
-
-    torch.distributed._reduce_scatter_base = reduce_scatter_310p
-    torch.distributed.distributed_c10d._reduce_scatter_base = reduce_scatter_310p
-
-
-if is_310p():
-    communication_adaptation_310p()
+ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
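The assignments at the end of this file are plain monkey-patches: vllm's classes keep their interface while every caller transparently gets the Ascend variants. A self-contained sketch of the pattern, with stand-in classes rather than vllm's real ones:

```python
# Stand-in for vllm.config.ParallelConfig, not the real class.
class ParallelConfig:

    def stateless_init_dp_group(self):
        return "default dp group"


def ascend_stateless_init_dp_group(self):
    return "gloo-backed dp group for NPU hosts"


# Rebinding the function on the class affects every instance, including
# instances created before the patch module was imported.
ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
print(ParallelConfig().stateless_init_dp_group())  # -> gloo-backed dp group...
```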

vllm_ascend/platform.py

Lines changed: 0 additions & 45 deletions
@@ -17,13 +17,10 @@
 
 import gc
 import os
-from datetime import timedelta
 from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
 import vllm.envs as envs
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import PrefixStore
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
 
@@ -263,45 +260,3 @@ def get_piecewise_backend_cls(cls) -> str:
         Get piecewise backend class for piecewise graph.
         """
         return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa
-
-    @classmethod
-    def stateless_init_device_torch_dist_pg(
-        cls,
-        backend: str,
-        prefix_store: PrefixStore,
-        group_rank: int,
-        group_size: int,
-        timeout: timedelta,
-    ) -> ProcessGroup:
-        from torch.distributed import is_hccl_available
-        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
-
-        assert is_hccl_available()
-
-        # TODO(Yizhou): The reason we need to set options while vllm does not
-        # seems to be related to the version of PyTorch. In the latest version,
-        # there is no need to set options. While in the older version, 2.5.1
-        # specifically, we need to set options.
-        options = ProcessGroup.Options(backend=backend)
-        pg: ProcessGroup = ProcessGroup(
-            prefix_store,
-            group_rank,
-            group_size,
-            options,
-        )
-
-        backend_options = ProcessGroupHCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        device = torch.device("npu")
-        # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
-        # implemented in the 2.5.1 version of PyTorch. But we need to set it
-        # after the latest version is released.
-        # pg._set_default_backend(backend_type)
-        backend_class._set_sequence_number_for_group()
-        backend_type = ProcessGroup.BackendType.CUSTOM
-
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
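The classmethod removed here and the new helper in patch_distributed.py share one mechanic: build a bare `ProcessGroup` over a `PrefixStore`, construct a concrete backend, and register it for a device. Below is a condensed single-process sketch of that mechanic using the gloo backend, assuming the same PyTorch 2.5.x private APIs the diff itself relies on (`ProcessGroup.Options`, `_set_sequence_number_for_group`, and `_register_backend` are not stable public API and may differ in other PyTorch versions).

```python
from datetime import timedelta

import torch
from torch.distributed import PrefixStore, ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import ProcessGroupGloo

# Single-member rendezvous store; real code shares host/port across ranks.
store = TCPStore("127.0.0.1", 29601, world_size=1, is_master=True)
prefix_store = PrefixStore("sketch/", store)
timeout = timedelta(seconds=60)

# Bare group plus a concrete gloo backend, wired together by hand.
pg = ProcessGroup(prefix_store, 0, 1, ProcessGroup.Options(backend="gloo"))
backend = ProcessGroupGloo(prefix_store, 0, 1, timeout=timeout)
backend._set_sequence_number_for_group()
pg._register_backend(torch.device("cpu"), ProcessGroup.BackendType.GLOO,
                     backend)
```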
