
Commit 888e406

Upd
Signed-off-by: ilmarkov <markovilya197@gmail.com>
1 parent 65fea63 commit 888e406

File tree

3 files changed: +121 additions, -11 deletions

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import random
+import typing
+
+import pytest
+import ray
+import torch
+import torch.distributed as dist
+
+from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
+from vllm.distributed.device_communicators.cuda_communicator import (
+    CudaCommunicator)
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
+                                             get_tp_group)
+from vllm.platforms import current_platform
+
+from ..utils import (ensure_model_parallel_initialized,
+                     init_test_distributed_environment, multi_process_parallel)
+
+torch.manual_seed(42)
+random.seed(44)
+
+test_size_elements = 4 * 1024 * 1024
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def symm_mem_allreduce(
+    monkeypatch: pytest.MonkeyPatch,
+    tp_size,
+    pp_size,
+    rank,
+    distributed_init_port,
+):
+    with monkeypatch.context() as m:
+        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+        device = torch.device(f"cuda:{rank}")
+        torch.cuda.set_device(device)
+
+        init_test_distributed_environment(tp_size, pp_size, rank,
+                                          distributed_init_port)
+        ensure_model_parallel_initialized(tp_size, pp_size)
+
+        dtype = torch.bfloat16
+
+        cuda_communicator = typing.cast(CudaCommunicator,
+                                        get_tp_group().device_communicator)
+        symm_mem_comm = cuda_communicator.symm_mem_comm
+        if symm_mem_comm is None or symm_mem_comm.disabled:
+            pytest.skip("SymmMemCommunicator is not available or disabled.")
+
+        inp_direct_symm_mem = torch.randint(1,
+                                            23, (test_size_elements, ),
+                                            dtype=dtype,
+                                            device=device)
+        if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
+            pytest.skip(
+                "SymmMemCommunicator isn't used for this world and input size."
+            )
+
+        original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
+        out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
+        assert out_direct_symm_mem is not None
+
+        group = get_tensor_model_parallel_group().device_group
+        dist.all_reduce(original_inp_direct_symm_mem, group=group)
+        torch.testing.assert_close(out_direct_symm_mem,
+                                   original_inp_direct_symm_mem,
+                                   atol=2.5,
+                                   rtol=0.1)
+
+        # Test tensor_model_parallel_all_reduce which should use symm_mem
+        inp_tensor_parallel = torch.randint(-23,
+                                            1, (test_size_elements, ),
+                                            dtype=dtype,
+                                            device=device)
+        original_inp_tensor_parallel = inp_tensor_parallel.clone()
+        out_tensor_parallel = tensor_model_parallel_all_reduce(
+            inp_tensor_parallel)
+        dist.all_reduce(original_inp_tensor_parallel, group=group)
+        torch.testing.assert_close(out_tensor_parallel,
+                                   original_inp_tensor_parallel,
+                                   atol=2.5,
+                                   rtol=0.1)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="SymmMemAllreduce is only available for CUDA platforms.")
+@pytest.mark.parametrize("tp_size", [2, 4])
+@pytest.mark.parametrize("pipeline_parallel_size", [1])
+@pytest.mark.parametrize("test_target", [symm_mem_allreduce])
+def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
+                            pipeline_parallel_size, test_target):
+    world_size = tp_size * pipeline_parallel_size
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+
+    # Enable SymmMemCommunicator
+    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
+
+    multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
+                           test_target)
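
The heart of this new test is a reference check: each rank runs the symmetric-memory all-reduce on one tensor and the stock torch.distributed.all_reduce on a clone, then compares the two with loose tolerances, since bf16 reductions are not bitwise identical across algorithms. A minimal sketch of that pattern, pulled out of the test for clarity (check_against_nccl_reference is a hypothetical helper, not part of the commit):

import torch
import torch.distributed as dist


def check_against_nccl_reference(candidate_all_reduce, inp: torch.Tensor,
                                 group) -> None:
    """Compare a custom all-reduce against dist.all_reduce on a clone."""
    reference = inp.clone()
    out = candidate_all_reduce(inp)           # e.g. symm_mem_comm.all_reduce
    dist.all_reduce(reference, group=group)   # stock reference result
    # Loose tolerances, mirroring the test: bf16 sums differ slightly
    # depending on reduction order.
    torch.testing.assert_close(out, reference, atol=2.5, rtol=0.1)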

vllm/distributed/device_communicators/all_reduce_utils.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 CUSTOM_ALL_REDUCE_MAX_SIZES = {
     "9.0": {
         2: 64 * MiB,  # 64 MB
-        4: 32 * MiB,  # 32 MB
+        4: 1 * MiB,  # 1 MB
         6: MiB // 2,  # 512 KB
         8: MiB // 4,  # 256 KB
     },
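
CUSTOM_ALL_REDUCE_MAX_SIZES maps device capability and tensor-parallel world size to the largest message the custom all-reduce path handles; this commit lowers the SM 9.0, TP=4 cap from 32 MB to 1 MB, presumably so larger messages are served by the symmetric-memory path introduced alongside it. A hedged sketch of how such a cap table could be consulted (fits_custom_all_reduce is illustrative, not the vLLM API):

# Illustrative only: a (capability, world size) -> max-bytes table like
# CUSTOM_ALL_REDUCE_MAX_SIZES gating a fast path. Values mirror the
# post-commit "9.0" entries above.
MiB = 1024 * 1024

CUSTOM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {
        2: 64 * MiB,  # 64 MB
        4: 1 * MiB,   # 1 MB (lowered by this commit from 32 MB)
        6: MiB // 2,  # 512 KB
        8: MiB // 4,  # 256 KB
    },
}


def fits_custom_all_reduce(capability: str, world_size: int,
                           message_bytes: int) -> bool:
    sizes = CUSTOM_ALL_REDUCE_MAX_SIZES.get(capability)
    if sizes is None or world_size not in sizes:
        return False
    return message_bytes <= sizes[world_size]


# Example: a 4 MiB all-reduce at TP=4 on SM 9.0 no longer fits the cap.
assert not fits_custom_all_reduce("9.0", 4, 4 * MiB)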

vllm/distributed/device_communicators/symm_mem.py

Lines changed: 16 additions & 10 deletions
@@ -23,6 +23,12 @@
 
 class SymmMemCommunicator:
 
+    # World sizes where multi-mem all-reduce performs the best
+    _WORLD_SIZES_MULTIMEM = {
+        "9.0": [4, 6, 8],
+        "10.0": [6, 8],
+    }
+
     def __init__(self, group: ProcessGroup, device: Union[int, str,
                                                           torch.device]):
         self.disabled = True
@@ -43,25 +49,26 @@ def __init__(self, group: ProcessGroup, device: Union[int, str,
         self.device = device
         self.group = group
         self.world_size = dist.get_world_size(self.group)
-        device_capability = current_platform.get_device_capability(
+        self.device_capability = current_platform.get_device_capability(
         ).as_version_str()
 
-        if device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
+        if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES \
+                or self.device_capability not in self._WORLD_SIZES_MULTIMEM:
             logger.warning(
                 "SymmMemCommunicator: Device capability %s not supported, "
                 "communicator is not available.",
-                device_capability,
+                self.device_capability,
             )
             return
         if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
-                device_capability]:
+                self.device_capability]:
             logger.warning(
                 "SymmMemCommunicator: World size %d not supported, "
                 "communicator is not available.",
                 self.world_size,
             )
             return
-        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[device_capability][
+        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
             self.world_size]
         self.buffer = torch_symm_mem.empty(
             self.max_size // self.dtype.itemsize,
@@ -95,14 +102,13 @@ def all_reduce(
         if out is None:
             out = torch.empty_like(inp)
         self.buffer[:inp.numel()].copy_(inp.view(-1))
-        if self.world_size in [2, 4]:
-            # Use two-shot all-reduce for 2 and 4 GPUs
-            torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
+        if self.world_size in self._WORLD_SIZES_MULTIMEM[
+                self.device_capability]:
+            torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
                                                     "sum",
                                                     self.group.group_name)
         else:
-            # Use multi-mem all-reduce for 6 and 8 GPUs
-            torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
+            torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
                                                     "sum",
                                                     self.group.group_name)
         out.copy_(self.buffer[:inp.numel()].view(out.shape))
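
The net behavioral change in all_reduce is the dispatch rule: multi-mem all-reduce is now used only for the (device capability, world size) pairs listed in _WORLD_SIZES_MULTIMEM, and two-shot all-reduce everywhere else, replacing the previous hard-coded split of two-shot for 2 and 4 GPUs and multi-mem for 6 and 8. A standalone restatement of that rule (pick_symm_mem_kernel is illustrative, not part of the commit):

# Sketch of the new dispatch rule, outside the communicator class: pick the
# symmetric-memory kernel by device capability and tensor-parallel world size.
_WORLD_SIZES_MULTIMEM = {
    "9.0": [4, 6, 8],   # SM 9.0 (Hopper)
    "10.0": [6, 8],     # SM 10.0 (Blackwell)
}


def pick_symm_mem_kernel(device_capability: str, world_size: int) -> str:
    """Return the torch.ops.symm_mem op name the communicator would call."""
    if world_size in _WORLD_SIZES_MULTIMEM.get(device_capability, []):
        return "multimem_all_reduce_"
    return "two_shot_all_reduce_"


# Examples mirroring the diff: TP=2 on SM 9.0 stays on two-shot, while TP=4
# now uses multi-mem.
assert pick_symm_mem_kernel("9.0", 2) == "two_shot_all_reduce_"
assert pick_symm_mem_kernel("9.0", 4) == "multimem_all_reduce_"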
