diff --git a/docs/design/v1/multiprocessing.md b/docs/design/v1/multiprocessing.md index 06ebd772585..247072d1cb2 100644 --- a/docs/design/v1/multiprocessing.md +++ b/docs/design/v1/multiprocessing.md @@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`. There are other miscellaneous places hard-coding the use of `spawn`: -- +- - Related PRs: diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py new file mode 100644 index 00000000000..1a653c66f18 --- /dev/null +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +import typing + +import pytest +import ray +import torch +import torch.distributed as dist + +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_group) +from vllm.platforms import current_platform + +from ..utils import (ensure_model_parallel_initialized, + init_test_distributed_environment, multi_process_parallel) + +torch.manual_seed(42) +random.seed(44) + +test_size_elements = 4 * 1024 * 1024 + + +@ray.remote(num_gpus=1, max_calls=1) +def symm_mem_allreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + ensure_model_parallel_initialized(tp_size, pp_size) + + dtype = torch.bfloat16 + + cuda_communicator = typing.cast(CudaCommunicator, + get_tp_group().device_communicator) + symm_mem_comm = cuda_communicator.symm_mem_comm + if symm_mem_comm is None or symm_mem_comm.disabled: + pytest.skip("SymmMemCommunicator is not available or disabled.") + + inp_direct_symm_mem = torch.randint(1, + 23, (test_size_elements, ), + dtype=dtype, + device=device) + if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): + pytest.skip( + "SymmMemCommunicator isn't used for this world and input size." + ) + + original_inp_direct_symm_mem = inp_direct_symm_mem.clone() + out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) + assert out_direct_symm_mem is not None + + group = get_tensor_model_parallel_group().device_group + dist.all_reduce(original_inp_direct_symm_mem, group=group) + torch.testing.assert_close(out_direct_symm_mem, + original_inp_direct_symm_mem, + atol=2.5, + rtol=0.1) + + # Test tensor_model_parallel_all_reduce which should use symm_mem + inp_tensor_parallel = torch.randint(-23, + 1, (test_size_elements, ), + dtype=dtype, + device=device) + original_inp_tensor_parallel = inp_tensor_parallel.clone() + out_tensor_parallel = tensor_model_parallel_all_reduce( + inp_tensor_parallel) + dist.all_reduce(original_inp_tensor_parallel, group=group) + torch.testing.assert_close(out_tensor_parallel, + original_inp_tensor_parallel, + atol=2.5, + rtol=0.1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="SymmMemAllreduce is only available for CUDA platforms.") +@pytest.mark.parametrize("tp_size", [2, 4]) +@pytest.mark.parametrize("pipeline_parallel_size", [1]) +@pytest.mark.parametrize("test_target", [symm_mem_allreduce]) +def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size, test_target): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + # Enable SymmMemCommunicator + monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1") + + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, + test_target) diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py index ef197d1fbac..ab2ddc09547 100644 --- a/tools/check_pickle_imports.py +++ b/tools/check_pickle_imports.py @@ -38,7 +38,7 @@ 'vllm/distributed/utils.py', 'vllm/distributed/parallel_state.py', 'vllm/engine/multiprocessing/client.py', - 'vllm/distributed/device_communicators/custom_all_reduce_utils.py', + 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py similarity index 93% rename from vllm/distributed/device_communicators/custom_all_reduce_utils.py rename to vllm/distributed/device_communicators/all_reduce_utils.py index 7c6001e8703..59f56560c5f 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -23,6 +23,40 @@ logger = init_logger(__name__) +MiB = 1024 * 1024 + +# Max size for each world size in case symmetric memory is available +# For different SM architectures +CUSTOM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 1 * MiB, # 1 MB + 6: MiB // 2, # 512 KB + 8: MiB // 4, # 256 KB + }, + "10.0": { + 2: 2 * MiB, # 2 MB + 4: 2 * MiB, # 2 MB + 6: 2 * MiB, # 2 MB + 8: MiB, # 1 MB + } +} + +SYMM_MEM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: 64 * MiB, # 64 MB + 8: 64 * MiB, # 64 MB + }, + "10.0": { + 2: 8 * MiB, # 8 MB + 4: 32 * MiB, # 32 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB + } +} + def producer(batch_src: Sequence[int], producer_queue, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index e4804691f0f..aa588b99265 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -44,6 +44,8 @@ def __init__(self, PyNcclCommunicator) from vllm.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce) + from vllm.distributed.device_communicators.symm_mem import ( + SymmMemCommunicator) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -54,6 +56,7 @@ def __init__(self, self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None + self.symm_mem_comm: Optional[SymmMemCommunicator] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -69,6 +72,12 @@ def __init__(self, # currently be an MI300 series. self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) + if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -105,6 +114,12 @@ def all_reduce(self, input_): out = ca_comm.custom_all_reduce(input_) assert out is not None return out + symm_mem_comm = self.symm_mem_comm + if symm_mem_comm is not None and not symm_mem_comm.disabled and \ + symm_mem_comm.should_use_symm_mem(input_): + out = symm_mem_comm.all_reduce(input_) + assert out is not None + return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None out = pynccl_comm.all_reduce(input_) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 7dd104a4fcc..8a678adb481 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -10,8 +10,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.device_communicators.custom_all_reduce_utils import ( - gpu_p2p_access_check) +from vllm.distributed.device_communicators.all_reduce_utils import ( + CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform @@ -109,6 +109,13 @@ def __init__(self, # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device + device_capability = current_platform.get_device_capability( + ).as_version_str() + if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM + and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): + max_size = min( + CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], + max_size) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py new file mode 100644 index 00000000000..f337f3de071 --- /dev/null +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.distributed.device_communicators.all_reduce_utils import ( + SYMM_MEM_ALL_REDUCE_MAX_SIZES) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +try: + import torch.distributed._symmetric_memory as torch_symm_mem + + symm_mem_available = True +except ImportError: + symm_mem_available = False + +logger = init_logger(__name__) + + +class SymmMemCommunicator: + + # World sizes where multi-mem all-reduce performs the best + _WORLD_SIZES_MULTIMEM = { + "9.0": [4, 6, 8], + "10.0": [6, 8], + } + + def __init__(self, group: ProcessGroup, device: Union[int, str, + torch.device]): + self.disabled = True + + if not symm_mem_available: + return + + if not current_platform.is_cuda(): + logger.warning("SymmMemCommunicator: symmetric " + "memory is not available.") + return + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + torch.cuda.set_device(device) + self.dtype = torch.bfloat16 + self.device = device + self.group = group + self.world_size = dist.get_world_size(self.group) + self.device_capability = current_platform.get_device_capability( + ).as_version_str() + + if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES \ + or self.device_capability not in self._WORLD_SIZES_MULTIMEM: + logger.warning( + "SymmMemCommunicator: Device capability %s not supported, " + "communicator is not available.", + self.device_capability, + ) + return + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[ + self.device_capability]: + logger.warning( + "SymmMemCommunicator: World size %d not supported, " + "communicator is not available.", + self.world_size, + ) + return + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ + self.world_size] + self.buffer = torch_symm_mem.empty( + self.max_size // self.dtype.itemsize, + device=self.device, + dtype=self.dtype, + ) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + if handle.multicast_ptr == 0: + logger.warning("SymmMemCommunicator: symmetric memory " + "multicast operations are not supported.") + return + self.disabled = False + + def should_use_symm_mem(self, inp: torch.Tensor): + if self.disabled: + return False + if inp.dtype != self.dtype: + return False + inp_size = inp.numel() * inp.element_size() + if inp_size % 4 != 0: + return False + return inp_size < self.max_size + + def all_reduce( + self, + inp: torch.Tensor, + *, + out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + if not self.should_use_symm_mem(inp): + return None + if out is None: + out = torch.empty_like(inp) + self.buffer[:inp.numel()].copy_(inp.view(-1)) + if self.world_size in self._WORLD_SIZES_MULTIMEM[ + self.device_capability]: + torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + else: + torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + out.copy_(self.buffer[:inp.numel()].view(out.shape)) + return out diff --git a/vllm/envs.py b/vllm/envs.py index 7bff6ade815..d8738463a2d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -139,6 +139,7 @@ VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120 + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False def get_default_cache_root(): @@ -620,7 +621,7 @@ def get_vllm_port() -> Optional[int]: ("1", "true")), # By default, vLLM will check the peer-to-peer capability itself, - # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa + # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/all_reduce_utils.py#L101-L108 for details. # noqa # If this env var is set to 1, vLLM will skip the peer-to-peer check, # and trust the driver's peer-to-peer capability report. "VLLM_SKIP_P2P_CHECK": @@ -964,6 +965,9 @@ def get_vllm_port() -> Optional[int]: # If set to 1, use the TRTLLM Decode Attention backend in flashinfer. "VLLM_USE_TRTLLM_DECODE_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None), + # Whether to use pytorch symmetric memory for allreduce + "VLLM_ALLREDUCE_USE_SYMM_MEM": + lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), } # --8<-- [end:env-vars-definition]