[wip] add nccl allocator and symm memory and enable TP all reduce for nccl symm #21383

Draft · wants to merge 14 commits into base: main

103 changes: 103 additions & 0 deletions tests/distributed/test_pynccl_symm_memory.py
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import multiprocessing
import os

import numpy as np
import pytest
import torch
import torch.distributed

from vllm.distributed.communication_op import ( # noqa
tensor_model_parallel_all_reduce,
)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.device_communicators.pynccl_allocator import (
get_nccl_mem_pool,
)

from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_world_group,
graph_capture,
init_distributed_environment,
)
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
number_of_processes = world_size
processes: list[multiprocessing.Process] = []
for i in range(number_of_processes):
env: dict[str, str] = {}
env["RANK"] = str(i)
env["LOCAL_RANK"] = str(i)
env["WORLD_SIZE"] = str(number_of_processes)
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env,))
processes.append(p)
p.start()

for p in processes:
p.join()

for p in processes:
assert p.exitcode == 0


def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapped_fn(env):
update_environment_variables(env)
local_rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
init_distributed_environment()
fn()

return wrapped_fn


@worker_fn_wrapper
def multiple_allreduce_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
with torch.cuda.use_mem_pool(get_nccl_mem_pool()):
symm_tensor = torch.ones(
16, 1024, 1024, dtype=torch.float32, device=device
)
win = pynccl_comm.register_comm_window(symm_tensor)
stream = torch.cuda.default_stream()
# two groups can communicate independently; reduce in place on the
# registered symmetric tensor so the second all-reduce sees the first result
if torch.distributed.get_rank() in [0, 1]:
    tensor = pynccl_comm.all_reduce(symm_tensor, symm_tensor, stream=stream)
    tensor = pynccl_comm.all_reduce(symm_tensor, symm_tensor, stream=stream)
    torch.cuda.synchronize()
    assert torch.all(tensor == 4).cpu().item()
else:
    tensor = pynccl_comm.all_reduce(symm_tensor, symm_tensor, stream=stream)
    torch.cuda.synchronize()
    assert torch.all(tensor == 2).cpu().item()
pynccl_comm.deregister_comm_window(win)



@pytest.mark.skipif(
torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.",
)
def test_pynccl_multiple_allreduce():
# this tests pynccl for multiple tp groups, in a standalone way
# i.e. call `pynccl_comm.all_reduce` directly
distributed_run(multiple_allreduce_worker_fn, 4)
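For local verification, this new test can be run directly with pytest on a host with at least 4 GPUs and an NCCL build that provides the window-registration API relied on elsewhere in this PR (the runtime gate checks nccl_version >= 22703):

pytest -s tests/distributed/test_pynccl_symm_memory.py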
16 changes: 13 additions & 3 deletions vllm/distributed/communication_op.py
@@ -5,13 +5,17 @@

import torch
import torch.distributed

from contextlib import nullcontext
from vllm.distributed.device_communicators.pynccl_allocator import (
get_nccl_mem_pool, use_symmetric_memory)
from .parallel_state import get_tp_group


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
def tensor_model_parallel_all_reduce(
input_: torch.Tensor, output_: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
return get_tp_group().all_reduce(input_, output_)


def tensor_model_parallel_all_gather(input_: torch.Tensor,
@@ -39,3 +43,9 @@ def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor,
if not torch.distributed.is_initialized():
return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)


def tensor_model_parallel_use_symmetric_memory():
# if torch.compiler.is_compiling():
# return nullcontext()
return use_symmetric_memory(get_tp_group())
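As a usage illustration for the helper above, the sketch below shows the call pattern a TP layer could follow: allocate the buffer to be reduced inside tensor_model_parallel_use_symmetric_memory() so it lands in the NCCL mem pool (and gets window-registered on context exit), then hand it to tensor_model_parallel_all_reduce. This is a sketch only; the function and variable names are illustrative, and whether the fast path actually triggers depends on runtime checks in CudaCommunicator.all_reduce outside this snippet.

import torch

from vllm.distributed.communication_op import (
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_use_symmetric_memory,
)


def reduce_partial_output(partial: torch.Tensor) -> torch.Tensor:
    # Allocate the reduce buffer from the NCCL symmetric-memory pool; on context
    # exit the backing segment is window-registered with the TP communicator.
    with tensor_model_parallel_use_symmetric_memory():
        symm_out = torch.empty_like(partial)
    symm_out.copy_(partial)
    # The TP all-reduce may now take the NCCL symmetric-memory path.
    return tensor_model_parallel_all_reduce(symm_out)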
vllm/distributed/device_communicators/base_device_communicator.py
@@ -108,7 +108,10 @@ def __init__(self,
self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
def all_reduce(
self, input_: torch.Tensor, output_: Optional[torch.Tensor] = None
) -> torch.Tensor:
assert output_ is None, "output_ is not supported in the base class"
dist.all_reduce(input_, group=self.device_group)
return input_

18 changes: 15 additions & 3 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -9,7 +9,8 @@
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

from vllm.distributed.device_communicators.pynccl_allocator import (
get_nccl_mem_pool)
from .base_device_communicator import DeviceCommunicatorBase

logger = init_logger(__name__)
@@ -47,6 +48,8 @@ def __init__(self,

self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
# initialize the mem pool to avoid torch dynamo error
get_nccl_mem_pool()
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
@@ -90,7 +93,16 @@ def __init__(self,
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")

def all_reduce(self, input_):
def all_reduce(self, input_, output_=None):
if (
self.pynccl_comm is not None
and self.pynccl_comm.nccl_version >= 22703
and hasattr(input_, "symmetric_memory")
and input_.symmetric_memory
):
# TODO(asamani): this is under change_state in sglang, double check!
input_ = self.pynccl_comm.all_reduce(input_, input_)
return input_
# always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
@@ -107,7 +119,7 @@ def all_reduce(self, input_):
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
out = pynccl_comm.all_reduce(input_, output_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
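For clarity, the new branch above only fires when three things hold: a pynccl communicator exists, the loaded NCCL reports raw version >= 22703 (2.27.3), and the input tensor carries a truthy symmetric_memory attribute. Nothing in this diff sets that attribute yet, so the sketch below only spells out the condition and a hypothetical producer-side tag; the attribute name is taken from the check above, while would_take_symm_path and mark_symmetric are purely illustrative helpers.

import torch


def would_take_symm_path(pynccl_comm, t: torch.Tensor) -> bool:
    # Mirrors the branch condition added to CudaCommunicator.all_reduce above.
    return (
        pynccl_comm is not None
        and pynccl_comm.nccl_version >= 22703
        and getattr(t, "symmetric_memory", False)
    )


def mark_symmetric(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical producer-side tag; this PR does not set the attribute itself yet.
    t.symmetric_memory = True
    return t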
29 changes: 24 additions & 5 deletions vllm/distributed/device_communicators/pynccl.py
@@ -67,6 +67,7 @@ def __init__(
self.available = True
self.disabled = False

self.nccl_version = self.nccl.ncclGetRawVersion()
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

if self.rank == 0:
@@ -109,6 +110,7 @@ def __init__(

def all_reduce(self,
in_tensor: torch.Tensor,
out_tensor: Optional[torch.Tensor] = None,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
if self.disabled:
@@ -119,18 +121,19 @@ def all_reduce(self,
assert in_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")

out_tensor = torch.empty_like(in_tensor)

if out_tensor is None:
output = torch.empty_like(in_tensor)
else:
output = out_tensor
if stream is None:
stream = current_stream()
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
buffer_type(output.data_ptr()),
in_tensor.numel(),
ncclDataTypeEnum.from_torch(in_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
return out_tensor
return output

def all_gather(self,
output_tensor: torch.Tensor,
@@ -288,3 +291,19 @@ def group_start(self):

def group_end(self):
self.nccl.ncclGroupEnd()

def register_comm_window(self, tensor: torch.Tensor):
return self.nccl.ncclCommWindowRegister(
self.comm,
buffer_type(tensor.data_ptr()),
tensor.numel() * tensor.element_size(),
1,
)

def register_comm_window_raw(self, ptr: int, size: int):
return self.nccl.ncclCommWindowRegister(
self.comm, buffer_type(ptr), size, 1
)

def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
92 changes: 92 additions & 0 deletions vllm/distributed/device_communicators/pynccl_allocator.py
@@ -0,0 +1,92 @@
import tempfile

import torch
import torch.utils.cpp_extension
from torch.cuda.memory import CUDAPluggableAllocator

from vllm.distributed.parallel_state import GroupCoordinator

nccl_allocator_source = """
#include <nccl.h>
#include <c10/cuda/CUDAGuard.h>
extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
void* ptr;
at::cuda::OptionalCUDAGuard gpuGuard(device);
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;

}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
at::cuda::OptionalCUDAGuard gpuGuard(device);
ncclResult_t err = ncclMemFree(ptr);
}

}
"""

_allocator = None
_mem_pool = None
_registered_base_addrs = set()
_graph_pool_id = None


def get_nccl_mem_pool():
global _allocator, _mem_pool
if _mem_pool is None:
out_dir = tempfile.gettempdir()
nccl_allocator_libname = "nccl_allocator"
torch.utils.cpp_extension.load_inline(
name=nccl_allocator_libname,
cpp_sources=nccl_allocator_source,
with_cuda=True,
extra_ldflags=["-lnccl"],
verbose=True,
is_python_module=False,
build_directory=out_dir,
)

_allocator = CUDAPluggableAllocator(
f"{out_dir}/{nccl_allocator_libname}.so",
"nccl_alloc_plug",
"nccl_free_plug",
).allocator()
_mem_pool = torch.cuda.MemPool(_allocator)

return _mem_pool


class use_symmetric_memory:
def __init__(self, group_coordinator: GroupCoordinator):
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()

def __enter__(self):
if self.is_graph_capture:
assert (
_graph_pool_id is not None
), "graph_pool_id is not set under graph capture"
torch._C._cuda_endAllocateCurrentStreamToPool(
self.device, _graph_pool_id
)
self._mem_pool_ctx.__enter__()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
global _registered_base_addrs
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
for segment in get_nccl_mem_pool().snapshot():
if segment["address"] not in _registered_base_addrs:
# TODO: check that symmetric allocation is maintained across all ranks
self.group_coordinator.pynccl_comm.register_comm_window_raw(
segment["address"], segment["total_size"]
)
_registered_base_addrs.add(segment["address"])

if self.is_graph_capture:
assert (
_graph_pool_id is not None
), "graph_pool_id is not set under graph capture"
torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
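As a quick, single-process sanity check of the allocator path (a sketch, not part of the PR), the snippet below allocates a tensor while the NCCL mem pool is active and inspects the pool's snapshot using the same keys ("address", "total_size") that __exit__ reads above. It assumes a CUDA device plus an NCCL installation that the inline extension can compile and link against; no window registration happens here because no communicator is involved.

import torch

from vllm.distributed.device_communicators.pynccl_allocator import get_nccl_mem_pool


def check_nccl_pool_allocation() -> None:
    pool = get_nccl_mem_pool()  # first call builds and loads the inline allocator
    with torch.cuda.use_mem_pool(pool):
        t = torch.empty(1 << 20, dtype=torch.float32, device="cuda")
    segments = pool.snapshot()
    assert segments, "expected at least one ncclMemAlloc-backed segment"
    for seg in segments:
        print(hex(seg["address"]), seg["total_size"])
    del t  # returns the block to the pool's caching allocator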