[PERF] Symmetric memory allreduce #20759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft · wants to merge 5 commits into base: main

2 changes: 1 addition & 1 deletion docs/design/v1/multiprocessing.md
@@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.

There are other miscellaneous places hard-coding the use of `spawn`:

-- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
+- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>

Related PRs:
2 changes: 1 addition & 1 deletion tools/check_pickle_imports.py
@@ -38,7 +38,7 @@
'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py',
'vllm/engine/multiprocessing/client.py',
-'vllm/distributed/device_communicators/custom_all_reduce_utils.py',
+'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/engine/multiprocessing/engine.py',
'benchmarks/kernels/graph_machete_bench.py',
vllm/distributed/device_communicators/all_reduce_utils.py
@@ -18,11 +18,48 @@
import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.platforms import DeviceCapability
from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)

logger = init_logger(__name__)

MiB = 1024 * 1024

# Maximum input sizes (in bytes) per world size when symmetric memory is
# available, keyed by SM architecture (device capability).

# TODO(ilia): update max sizes for 6, 8 for sm90
CUSTOM_ALL_REDUCE_MAX_SIZES = {
DeviceCapability(9, 0): {
2: 64 * MiB, # 64 MB
4: MiB, # 1 MB
6: MiB, # 1 MB
8: MiB // 2, # 512 KB
},
DeviceCapability(10, 0): {
2: 2 * MiB, # 2 MB
4: 2 * MiB, # 2 MB
6: 8 * MiB, # 8 MB
8: 8 * MiB, # 8 MB
}
}

SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
DeviceCapability(9, 0): {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
},
DeviceCapability(10, 0): {
2: 8 * MiB, # 8 MB
4: 32 * MiB, # 32 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
}
}


def producer(batch_src: Sequence[int],
producer_queue,
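
As a quick illustration of how these tables are meant to be consulted (a sketch, not part of this PR; `symm_mem_cap_bytes` is a hypothetical helper), each table is indexed first by device capability and then by world size:

```python
from typing import Optional

from vllm.distributed.device_communicators.all_reduce_utils import (
    SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.platforms import current_platform


def symm_mem_cap_bytes(world_size: int) -> Optional[int]:
    """Hypothetical helper: symmetric-memory all-reduce size cap in bytes
    for the current GPU architecture and world size, or None if unsupported."""
    capability = current_platform.get_device_capability()
    return SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(capability, {}).get(world_size)


# On an SM90 (Hopper) GPU with tensor-parallel size 4 this returns
# 32 * 1024 * 1024 (32 MiB); on unsupported hardware it returns None.
```
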
15 changes: 15 additions & 0 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -44,6 +44,8 @@ def __init__(self,
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)

self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
@@ -54,6 +56,7 @@

self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
@@ -69,6 +72,12 @@
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)

if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
@@ -105,6 +114,12 @@ def all_reduce(self, input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and not symm_mem_comm.disabled and \
symm_mem_comm.should_use_symm_mem(input_):
Review comment on lines +118 to +119 (Contributor, medium):

The check `not symm_mem_comm.disabled` is redundant because `should_use_symm_mem` already performs this check. Removing the redundant check will make the code more concise.

Suggested change:
-if symm_mem_comm is not None and not symm_mem_comm.disabled and \
-        symm_mem_comm.should_use_symm_mem(input_):
+if symm_mem_comm is not None and \
+        symm_mem_comm.should_use_symm_mem(input_):

out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
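
For context, the effective selection order in `all_reduce` after this change is roughly as follows (a simplified sketch, not the literal vLLM code; it omits the ROCm QuickAllReduce branch and assumes the existing `should_custom_ar` helper and `disabled` flag):

```python
import torch


def dispatch_all_reduce(input_: torch.Tensor, ca_comm, symm_mem_comm,
                        pynccl_comm) -> torch.Tensor:
    # 1. Custom all-reduce, if it is enabled and accepts this input.
    if ca_comm is not None and not ca_comm.disabled and \
            ca_comm.should_custom_ar(input_):
        return ca_comm.custom_all_reduce(input_)
    # 2. Symmetric-memory all-reduce, if enabled and the input qualifies
    #    (bf16 input no larger than the per-world-size cap).
    if symm_mem_comm is not None and \
            symm_mem_comm.should_use_symm_mem(input_):
        return symm_mem_comm.all_reduce(input_)
    # 3. PyNCCL as the general fallback.
    return pynccl_comm.all_reduce(input_)
```
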
10 changes: 8 additions & 2 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -10,8 +10,8 @@

import vllm.envs as envs
from vllm import _custom_ops as ops
-from vllm.distributed.device_communicators.custom_all_reduce_utils import (
-    gpu_p2p_access_check)
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -109,6 +109,12 @@ def __init__(self,
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)

cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
108 changes: 108 additions & 0 deletions vllm/distributed/device_communicators/symm_mem.py
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform

try:
import torch.distributed._symmetric_memory as torch_symm_mem

symm_mem_available = True
except ImportError:
symm_mem_available = False

logger = init_logger(__name__)


class SymmMemCommunicator:

def __init__(self, group: ProcessGroup, device: Union[int, str,
torch.device]):
self.disabled = True

if not symm_mem_available:
return

if not current_platform.is_cuda():
logger.warning("SymmMemCommunicator: symmetric "
"memory is not available.")
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
torch.cuda.set_device(device)
self.dtype = torch.bfloat16
Review comment (Contributor, high):

The `SymmMemCommunicator` is hardcoded to use `torch.bfloat16`, limiting its use with models using other dtypes. Consider initializing buffers based on the input tensor's dtype during the first `all_reduce` call to increase flexibility.
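
One way to act on this suggestion (a minimal sketch, not part of this PR, assuming lazy collective re-allocation is acceptable; `_lazy_init_buffer` is a hypothetical helper that `all_reduce` would call with `inp.dtype` before copying into the buffer):

```python
def _lazy_init_buffer(self, dtype: torch.dtype) -> None:
    # Hypothetical helper: (re)allocate the symmetric buffer for the requested
    # dtype on first use instead of fixing bfloat16 in __init__.
    # rendezvous() is a collective, so every rank must reach this together.
    if getattr(self, "buffer", None) is not None and self.dtype == dtype:
        return
    self.dtype = dtype
    self.buffer = torch_symm_mem.empty(
        self.max_size // dtype.itemsize,
        device=self.device,
        dtype=dtype,
    )
    torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
```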

self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
device_capability = current_platform.get_device_capability()

if device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
"communicator is not available.",
device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
device_capability]:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
self.world_size,
)
return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[device_capability][
self.world_size]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=self.device,
dtype=self.dtype,
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.disabled = False

def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled:
return False
if inp.dtype != self.dtype:
return False
inp_size = inp.numel() * inp.element_size()
if inp_size % 4 != 0:
return False
return inp_size <= self.max_size

def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in [2, 4]:
# Use two-shot all-reduce for 2 and 4 GPUs
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
else:
# Use multi-mem all-reduce for 6 and 8 GPUs
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
return out
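
For reference, a minimal standalone sketch of the PyTorch symmetric-memory workflow this communicator builds on: allocate with `empty`, exchange handles with `rendezvous`, then run the fused all-reduce op in place. This is not part of the PR; it assumes a recent PyTorch build with symmetric-memory support and NVLink/NVSwitch multicast for `multimem_all_reduce_`, and would be launched with `torchrun --nproc-per-node=<num_gpus>`.

```python
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem


def main() -> None:
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    group_name = dist.group.WORLD.group_name

    # Allocate a symmetric buffer and exchange handles across all ranks.
    buf = symm_mem.empty(4096, device=device, dtype=torch.bfloat16)
    symm_mem.rendezvous(buf, group_name)

    # In-place all-reduce over the symmetric buffer (multicast-based).
    buf.fill_(float(rank))
    torch.ops.symm_mem.multimem_all_reduce_(buf, "sum", group_name)

    expected = float(sum(range(world_size)))
    torch.testing.assert_close(buf, torch.full_like(buf, expected))

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
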
9 changes: 7 additions & 2 deletions vllm/envs.py
@@ -139,6 +139,7 @@
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False


def get_default_cache_root():
@@ -620,7 +621,7 @@ def get_vllm_port() -> Optional[int]:
("1", "true")),

# By default, vLLM will check the peer-to-peer capability itself,
-# in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa
+# in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/all_reduce_utils.py#L101-L108 for details. # noqa
# If this env var is set to 1, vLLM will skip the peer-to-peer check,
# and trust the driver's peer-to-peer capability report.
"VLLM_SKIP_P2P_CHECK":
@@ -961,7 +962,11 @@ def get_vllm_port() -> Optional[int]:
# consumer. This is only applicable when using NixlConnector in a
# disaggregated decode-prefill setup.
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),

# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
}

# --8<-- [end:env-vars-definition]
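
As a usage note, the flag is parsed lazily through `vllm.envs`, so enabling the symmetric-memory path only requires setting the environment variable before vLLM initializes its communicators (minimal sketch):

```python
import os

# Opt in to the symmetric-memory all-reduce path (off by default).
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "1"

import vllm.envs as envs

# envs.py parses the variable as bool(int(...)), so "1" enables it.
assert envs.VLLM_ALLREDUCE_USE_SYMM_MEM is True
```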