
Commit b05ed25

ilmarkov authored and committed
Add pytorch symm memory communicator
Signed-off-by: ilmarkov <imarkov@redhat.com>
1 parent 37a7d5d commit b05ed25

File tree

4 files changed: +125 -0

vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/symm_mem.py (new)
vllm/envs.py

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 15 additions & 0 deletions

@@ -44,6 +44,8 @@ def __init__(self,
             PyNcclCommunicator)
         from vllm.distributed.device_communicators.quick_all_reduce import (
             QuickAllReduce)
+        from vllm.distributed.device_communicators.symm_mem import (
+            SymmMemCommunicator)

         self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
@@ -54,6 +56,7 @@ def __init__(self,

         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
+        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
@@ -69,6 +72,12 @@ def __init__(self,
                 # currently be an MI300 series.
                 self.qr_comm = QuickAllReduce(group=self.cpu_group,
                                               device=self.device)
+        if envs.VLLM_USE_SYMM_MEM and current_platform.is_cuda():
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
             if all2all_backend == "naive":
@@ -105,6 +114,12 @@ def all_reduce(self, input_):
             out = ca_comm.custom_all_reduce(input_)
             assert out is not None
             return out
+        symm_mem_comm = self.symm_mem_comm
+        if symm_mem_comm is not None and not symm_mem_comm.disabled and \
+                symm_mem_comm.should_use_symm_mem(input_):
+            out = symm_mem_comm.all_reduce(input_)
+            assert out is not None
+            return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
         out = pynccl_comm.all_reduce(input_)

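Note on the dispatch order: after this change, `all_reduce` tries the custom all-reduce kernel first, then the new symmetric-memory communicator, and finally falls back to PyNccl. A minimal sketch of that first-match-wins pattern, with hypothetical stand-in callables rather than the real vLLM communicator classes:

# Illustrative sketch only; backend callables are stand-ins, not vLLM classes.
from typing import Callable, Optional

import torch

AllReduceFn = Callable[[torch.Tensor], Optional[torch.Tensor]]


def all_reduce_with_fallback(input_: torch.Tensor,
                             backends: list[AllReduceFn]) -> torch.Tensor:
    # Each backend returns None when it declines the tensor (disabled, wrong
    # dtype, too large, ...); the first non-None result wins.
    for backend in backends:
        out = backend(input_)
        if out is not None:
            return out
    raise RuntimeError("no all-reduce backend accepted the tensor")


# Order after this commit: custom allreduce, then symmetric memory, then NCCL,
# e.g. all_reduce_with_fallback(x, [ca_all_reduce, symm_mem_all_reduce,
#                                   nccl_all_reduce])
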
vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 10 additions & 0 deletions

@@ -49,6 +49,14 @@ def is_weak_contiguous(inp: torch.Tensor):
 class CustomAllreduce:

     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+    MB = 1024 * 1024
+    # Max sizes for each world size in case symmetric memory is available
+    _MAX_SIZES = {
+        2: MB,  # 1 MB
+        4: MB,  # 1 MB
+        6: MB // 2,  # 512 KB
+        8: MB // 2,  # 512 KB
+    }

     # max_size: max supported allreduce size
     def __init__(self,
@@ -109,6 +117,8 @@ def __init__(self,
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
+        if current_platform.is_cuda() and envs.VLLM_USE_SYMM_MEM:
+            max_size = CustomAllreduce._MAX_SIZES[world_size]

         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
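
The intent of the new `_MAX_SIZES` table is to cap the custom all-reduce buffer when symmetric memory is enabled, so small tensors stay on the custom kernel while larger ones are handed to `SymmMemCommunicator`. A small sketch of that size selection; `effective_ca_max_size` and `default_max_size` are illustrative names, not part of the commit:

# Sketch, not vLLM code: pick the CustomAllreduce buffer cap.
MB = 1024 * 1024
_MAX_SIZES = {2: MB, 4: MB, 6: MB // 2, 8: MB // 2}


def effective_ca_max_size(world_size: int, default_max_size: int,
                          use_symm_mem: bool) -> int:
    # With VLLM_USE_SYMM_MEM enabled, only small tensors use the custom
    # all-reduce kernel; anything larger goes to SymmMemCommunicator.
    if use_symm_mem and world_size in _MAX_SIZES:
        return _MAX_SIZES[world_size]
    return default_max_size


assert effective_ca_max_size(8, 8 * MB, use_symm_mem=True) == MB // 2
assert effective_ca_max_size(8, 8 * MB, use_symm_mem=False) == 8 * MB
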
vllm/distributed/device_communicators/symm_mem.py (new file)

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.logger import init_logger
+
+try:
+    import torch.distributed._symmetric_memory as torch_symm_mem
+
+    symm_mem_available = True
+except ImportError:
+    symm_mem_available = False
+
+logger = init_logger(__name__)
+
+
+class SymmMemCommunicator:
+    MB = 1024 * 1024
+    # Max sizes for each world size
+    _MAX_SIZES = {
+        2: 8 * MB,
+        4: 32 * MB,
+        6: 64 * MB,
+        8: 256 * MB,
+    }
+
+    def __init__(self, group: ProcessGroup,
+                 device: Union[int, str, torch.device]):
+        self.disabled = True
+
+        if not symm_mem_available:
+            return
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        torch.cuda.set_device(device)
+        self.dtype = torch.bfloat16
+        self.device = device
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+        if self.world_size not in self._MAX_SIZES:
+            logger.warning(
+                "SymmMemCommunicator: World size %d not supported, "
+                "communicator is not available.",
+                self.world_size,
+            )
+            return
+        self.buffer = torch_symm_mem.empty(
+            self._MAX_SIZES[self.world_size] // self.dtype.itemsize,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
+        if handle.multicast_ptr == 0:
+            logger.warning("SymmMemCommunicator: symmetric memory "
+                           "multicast operations are not supported.")
+            return
+        self.disabled = False
+
+    def should_use_symm_mem(self, inp: torch.Tensor):
+        if self.disabled:
+            return False
+        if inp.dtype != self.dtype:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        if inp_size % 4 != 0:
+            return False
+        return inp_size <= self._MAX_SIZES[self.world_size]
+
+    def all_reduce(
+            self,
+            inp: torch.Tensor,
+            *,
+            out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
+        if not self.should_use_symm_mem(inp):
+            return None
+        if out is None:
+            out = torch.empty_like(inp)
+        self.buffer[:inp.numel()].copy_(inp.view(-1))
+        if self.world_size in [2, 4]:
+            # Use two-shot all-reduce for 2 and 4 GPUs
+            torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
+                                                    "sum",
+                                                    self.group.group_name)
+        else:
+            # Use multi-mem all-reduce for 6 and 8 GPUs
+            torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
+                                                    "sum",
+                                                    self.group.group_name)
+        out.copy_(self.buffer[:inp.numel()].view(out.shape))
+        return out

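For reference, a rough standalone usage sketch of the new communicator under `torchrun`. The process-group setup, tensor shape, and fallback message below are assumptions for illustration; it requires CUDA GPUs with symmetric-memory support (and NVLink multicast for the 6/8-rank path):

# symm_mem_demo.py -- run with: torchrun --nproc-per-node=2 symm_mem_demo.py
# Illustrative sketch; group/device choices are assumptions, not vLLM defaults.
import os

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator

local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)

comm = SymmMemCommunicator(group=dist.group.WORLD, device=local_rank)

x = torch.ones(4096, dtype=torch.bfloat16, device=f"cuda:{local_rank}")
if not comm.disabled and comm.should_use_symm_mem(x):
    y = comm.all_reduce(x)          # new tensor holding the elementwise sum
    print(local_rank, y[0].item())  # expect world_size on every rank
else:
    print(local_rank, "symm-mem unavailable; vLLM would fall back to NCCL")

dist.destroy_process_group()
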
vllm/envs.py

Lines changed: 4 additions & 0 deletions

@@ -139,6 +139,7 @@
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
+    VLLM_USE_SYMM_MEM: bool = False


 def get_default_cache_root():
@@ -964,6 +965,9 @@ def get_vllm_port() -> Optional[int]:
     # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
     "VLLM_USE_TRTLLM_DECODE_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
+    # Whether to use pytorch symmetric memory for allreduce
+    "VLLM_USE_SYMM_MEM":
+    lambda: bool(int(os.getenv("VLLM_USE_SYMM_MEM", "0"))),
 }

 # --8<-- [end:env-vars-definition]

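Because the added lambda parses the variable with `bool(int(...))`, only integer strings such as "0" or "1" are accepted ("true" would raise). A hedged end-to-end example of opting in; the model name and tensor-parallel size are arbitrary, and a multi-GPU CUDA host is assumed:

import os

# Must be set before vllm (and therefore vllm.envs) is imported.
os.environ["VLLM_USE_SYMM_MEM"] = "1"

from vllm import LLM, SamplingParams

# Tensor parallelism across 2 GPUs so an all-reduce actually happens;
# the model choice is arbitrary.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)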