[PERF] Symmetric memory allreduce #20759

Draft · wants to merge 5 commits into base: main
Changes from 3 commits
15 changes: 15 additions & 0 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -44,6 +44,8 @@ def __init__(self,
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)

self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
@@ -54,6 +56,7 @@ def __init__(self,

self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
@@ -69,6 +72,12 @@ def __init__(self,
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)

if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
@@ -105,6 +114,12 @@ def all_reduce(self, input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and not symm_mem_comm.disabled and \
symm_mem_comm.should_use_symm_mem(input_):
Comment on lines +118 to +119 (Contributor, medium):

The check not symm_mem_comm.disabled is redundant because should_use_symm_mem already performs this check. Removing the redundant check will make the code more concise.

Suggested change
if symm_mem_comm is not None and not symm_mem_comm.disabled and \
symm_mem_comm.should_use_symm_mem(input_):
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):

out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
10 changes: 10 additions & 0 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -49,6 +49,14 @@ def is_weak_contiguous(inp: torch.Tensor):
class CustomAllreduce:

_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
MiB = 1024 * 1024
# Max sizes for each world size in case symmetric memory is available
_MAX_SIZES = {
2: 2 * MiB, # 1 MB
4: 2 * MiB, # 1 MB
6: MiB, # 512 KB
8: MiB // 2, # 512 KB
}
Comment (Contributor, medium):

The comments for the sizes in the _MAX_SIZES dictionary are inaccurate. Update the comments to reflect the correct sizes.

Suggested change
# Max sizes for each world size in case symmetric memory is available
_MAX_SIZES = {
2: 2 * MiB, # 1 MB
4: 2 * MiB, # 1 MB
6: MiB, # 512 KB
8: MiB // 2, # 512 KB
}
# Max sizes for each world size in case symmetric memory is available
_MAX_SIZES = {
2: 2 * MiB, # 2 MiB
4: 2 * MiB, # 2 MiB
6: MiB, # 1 MiB
8: MiB // 2, # 512 KiB
}


# max_size: max supported allreduce size
def __init__(self,
@@ -109,6 +117,8 @@ def __init__(self,
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
if current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM:
max_size = CustomAllreduce._MAX_SIZES[world_size]

cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
96 changes: 96 additions & 0 deletions vllm/distributed/device_communicators/symm_mem.py
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.logger import init_logger

try:
import torch.distributed._symmetric_memory as torch_symm_mem

symm_mem_available = True
except ImportError:
symm_mem_available = False

logger = init_logger(__name__)


class SymmMemCommunicator:
MiB = 1024 * 1024
# Max sizes for each world size
_MAX_SIZES = {
2: 8 * MiB,
4: 32 * MiB,
6: 128 * MiB,
8: 128 * MiB,
}

def __init__(self, group: ProcessGroup, device: Union[int, str,
torch.device]):
self.disabled = True

if not symm_mem_available:
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
torch.cuda.set_device(device)
self.dtype = torch.bfloat16
Comment (Contributor, high):

The SymmMemCommunicator is hardcoded to use torch.bfloat16, limiting its use with models using other dtypes. Consider initializing buffers based on the input tensor's dtype during the first all_reduce call to increase flexibility.
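
A possible shape for that suggestion (a sketch only, not part of this PR: it assumes __init__ sets self.buffer = None instead of allocating eagerly, and _maybe_init_buffer is a hypothetical helper):

    def _maybe_init_buffer(self, dtype: torch.dtype) -> bool:
        # Lazily allocate and rendezvous the symmetric-memory buffer with the
        # dtype of the first tensor seen, instead of hardcoding bfloat16.
        # rendezvous() is a collective, so this relies on every rank reaching
        # all_reduce with the same dtype.
        if self.buffer is not None:
            return dtype == self.dtype
        self.dtype = dtype
        self.buffer = torch_symm_mem.empty(
            self._MAX_SIZES[self.world_size] // dtype.itemsize,
            device=self.device,
            dtype=dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            self.disabled = True
            return False
        return True

should_use_symm_mem would then call this helper in place of its fixed dtype check, while all_reduce keeps the existing copy-in/copy-out flow.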

self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
if self.world_size not in self._MAX_SIZES:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
self.world_size,
)
return
self.buffer = torch_symm_mem.empty(
self._MAX_SIZES[self.world_size] // self.dtype.itemsize,
device=self.device,
dtype=self.dtype,
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.disabled = False

def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled:
return False
if inp.dtype != self.dtype:
return False
inp_size = inp.numel() * inp.element_size()
if inp_size % 4 != 0:
return False
return inp_size <= self._MAX_SIZES[self.world_size]

def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in [2, 4]:
# Use two-shot all-reduce for 2 and 4 GPUs
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
else:
# Use multi-mem all-reduce for 6 and 8 GPUs
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
return out
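
For orientation, the communicator can also be exercised standalone, outside the CudaCommunicator wiring above. A minimal harness sketch (assuming two multicast-capable GPUs launched with torchrun; note that vLLM itself passes its CPU process group, while this sketch uses the default NCCL group):

# Sketch only: run with `torchrun --nproc-per-node=2 symm_mem_demo.py`.
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator


def main() -> None:
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    comm = SymmMemCommunicator(group=dist.group.WORLD, device=rank)

    # Each rank contributes (rank + 1); a successful all-reduce yields the
    # sum over ranks, e.g. 3.0 for world size 2.
    x = torch.full((1024,), float(rank + 1),
                   dtype=torch.bfloat16, device=f"cuda:{rank}")
    out = comm.all_reduce(x)
    if out is None:
        # Dtype, size, or hardware constraints not met; the caller is
        # expected to fall back (in vLLM: custom allreduce or PyNccl).
        print(f"rank {rank}: symmetric-memory path unavailable")
    else:
        print(f"rank {rank}: reduced[0] = {out[0].item()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()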
7 changes: 6 additions & 1 deletion vllm/envs.py
@@ -139,6 +139,7 @@
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False


def get_default_cache_root():
@@ -961,7 +962,11 @@ def get_vllm_port() -> Optional[int]:
# consumer. This is only applicable when using NixlConnector in a
# disaggregated decode-prefill setup.
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),

# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
}

# --8<-- [end:env-vars-definition]
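
For completeness, a minimal way to opt in once this change lands (a sketch: the model name and tensor-parallel size are placeholders, and the flag is read when the engine sets up its communicators):

import os

# Opt in to the PyTorch symmetric-memory allreduce path (defaults to "0").
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "1"

from vllm import LLM

# Placeholder model and TP size; the symmetric-memory path only registers
# buffers for world sizes 2, 4, 6, and 8.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
print(llm.generate("Hello, world!")[0].outputs[0].text)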