
Commit a859323

trevor-m and mgoin authored
Add pynccl all-gatherv and reducescatterv (#20154)
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
1 parent fc0f41d commit a859323

File tree

6 files changed: +284 −2 lines changed

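For orientation, here is a minimal usage sketch of the new variable-size collectives, distilled from the tests added in the diff below. It is not a standalone script: it assumes a multi-GPU distributed worker set up the way the tests do (via distributed_run and worker_fn_wrapper), and the get_world_group import path is an assumption based on the existing pynccl tests rather than something shown in this diff.

# Sketch only: mirrors tests/distributed/test_pynccl.py from this commit.
import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group  # assumed path

pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                 device=get_world_group().device)
rank, world_size = pynccl_comm.rank, pynccl_comm.world_size
device = f'cuda:{rank}'

# Each rank may contribute a different number of elements.
sizes = [81, 20][:world_size]

# all_gatherv: every rank receives the concatenation of all chunks.
local = torch.arange(sizes[rank], dtype=torch.float32, device=device)
gathered = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
pynccl_comm.all_gatherv(gathered, local, sizes=sizes)

# reduce_scatterv: inputs are summed across ranks and rank r keeps the
# slice of length sizes[r] at its offset.
full = torch.arange(sum(sizes), dtype=torch.float32, device=device)
reduced = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
pynccl_comm.reduce_scatterv(reduced, full, sizes=sizes)
torch.cuda.synchronize()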

tests/distributed/test_pynccl.py

Lines changed: 70 additions & 0 deletions
@@ -4,6 +4,7 @@
 import multiprocessing
 import os
 
+import numpy as np
 import pytest
 import torch
 import torch.distributed
@@ -177,6 +178,38 @@ def test_pynccl_all_gather():
     distributed_run(all_gather_worker_fn, 2)
 
 
+@worker_fn_wrapper
+def all_gatherv_worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+
+    rank = pynccl_comm.rank
+    world_size = pynccl_comm.world_size
+    device = f'cuda:{pynccl_comm.rank}'
+
+    assert world_size <= 8
+    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
+    num_elems = sizes[rank]
+    tensor = torch.arange(num_elems, dtype=torch.float32,
+                          device=device) + rank * 100
+    result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
+
+    expected = torch.cat([
+        torch.arange(sizes[r], dtype=torch.float32) + r * 100
+        for r in range(world_size)
+    ]).to(device)
+
+    pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_all_gatherv():
+    distributed_run(all_gatherv_worker_fn, 2)
+
+
 @worker_fn_wrapper
 def reduce_scatter_worker_fn():
     pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
@@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
     distributed_run(reduce_scatter_worker_fn, 2)
 
 
+@worker_fn_wrapper
+def reduce_scatterv_worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+
+    rank = pynccl_comm.rank
+    world_size = pynccl_comm.world_size
+    device = f'cuda:{pynccl_comm.rank}'
+
+    assert world_size <= 8
+    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
+    num_elems = sum(sizes)
+    tensor = torch.arange(num_elems, dtype=torch.float32,
+                          device=device) + rank * 100
+    result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
+
+    # Calculate expected result for this rank's chunk
+    all_tensors = [
+        torch.arange(num_elems, dtype=torch.float32) + r * 100
+        for r in range(world_size)
+    ]
+    sizes_cumsum = np.cumsum(sizes)
+    start = 0 if rank == 0 else sizes_cumsum[rank - 1]
+    end = sizes_cumsum[rank]
+    expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
+
+    pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_reduce_scatterv():
+    distributed_run(reduce_scatterv_worker_fn, 2)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 def test_pynccl_with_cudagraph():

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 15 additions & 1 deletion
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import threading
-from typing import Optional
+from typing import Optional, Union
 from weakref import WeakValueDictionary
 
 import torch
@@ -138,6 +138,14 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
                                               input_size[dim + 1:])
         return output_tensor
 
+    def all_gatherv(
+            self,
+            input_: Union[torch.Tensor, list[torch.Tensor]],
+            dim: int = 0,
+            sizes: Optional[list[int]] = None
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        raise NotImplementedError
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
                        dim: int = -1) -> torch.Tensor:
@@ -172,6 +180,12 @@ def reduce_scatter(self,
         # Reshape before returning
         return output_tensor.movedim(0, dim).contiguous()
 
+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None) -> torch.Tensor:
+        raise NotImplementedError
+
     def gather(self,
                input_: torch.Tensor,
                dst: int = 0,

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 82 additions & 1 deletion
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -142,6 +142,42 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
         # Reshape before returning
         return output.movedim(0, dim).contiguous()
 
+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None):
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Note: This will produce an incorrect answer if we don't make
+        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
+        input_tensor = input_.movedim(0, dim).contiguous()
+
+        if sizes is not None:
+            assert len(sizes) == world_size
+            assert input_tensor.shape[0] == sum(sizes)
+            chunk_size = sizes[self.rank_in_group]
+        else:
+            assert input_tensor.shape[0] % world_size == 0
+            chunk_size = input_tensor.shape[0] // world_size
+        output_shape = (chunk_size, ) + input_tensor.shape[1:]
+
+        output = torch.empty(output_shape,
+                             dtype=input_tensor.dtype,
+                             device=input_tensor.device)
+
+        if sizes is not None:
+            pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
+        else:
+            pynccl_comm.reduce_scatter(output, input_)
+
+        # Reshape before returning
+        return output.movedim(0, dim).contiguous()
+
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
@@ -180,6 +216,51 @@ def destroy(self):
             self.all2all_manager.destroy()
             self.all2all_manager = None
 
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, list[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[list[int]] = None):
+        if dim != 0:
+            raise NotImplementedError("only dim 0 all-gatherv is supported")
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None and not pynccl_comm.disabled
+
+        # 'sizes' is not needed if all inputs in the same group have the same
+        # shape
+        if sizes is not None and all(s == sizes[0] for s in sizes):
+            sizes = None
+
+        def _all_gather_single(input_: torch.Tensor,
+                               sizes: Optional[list[int]] = None):
+            input_size = input_.size()
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[dim] == sizes[self.rank_in_group]
+                output_size = (sum(sizes), ) + input_size[1:]
+            else:
+                output_size = (input_size[0] * world_size, ) + input_size[1:]
+            # Allocate output tensor.
+            output_tensor = torch.empty(output_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
+            if sizes is not None:
+                pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
+            else:
+                pynccl_comm.all_gather(output_tensor, input_)
+            return output_tensor
+
+        if isinstance(input_, torch.Tensor):
+            return _all_gather_single(input_, sizes)
+
+        output_list = []
+        pynccl_comm.group_start()
+        for inp in input_:
+            output_list.append(_all_gather_single(inp, sizes=sizes))
+        pynccl_comm.group_end()
+
+        return output_list
+
     def dispatch(
             self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
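To make the size bookkeeping in the CudaCommunicator methods above concrete, here is a small CPU-only sketch of how `sizes` maps to per-rank offsets. The values are illustrative and no NCCL is involved.

import numpy as np

# Hypothetical per-rank chunk lengths; len(sizes) must equal world_size.
sizes = [81, 20, 57, 52]
offsets = np.concatenate(([0], np.cumsum(sizes)))

# all_gatherv: rank r's sizes[r]-row input lands at rows
# offsets[r]:offsets[r + 1] of every rank's sum(sizes)-row output.
# reduce_scatterv: rank r keeps the sum over ranks of rows
# offsets[r]:offsets[r + 1] of the sum(sizes)-row input.
for r, n in enumerate(sizes):
    print(f"rank {r}: rows {offsets[r]}:{offsets[r + 1]} ({n} rows)")

Note the fast path in the diff: when every entry of `sizes` is equal, all_gatherv drops `sizes` entirely and falls back to the fixed-size pynccl all_gather/reduce_scatter calls.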

vllm/distributed/device_communicators/pynccl.py

Lines changed: 72 additions & 0 deletions
@@ -152,6 +152,40 @@ def all_gather(self,
             ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
             cudaStream_t(stream.cuda_stream))
 
+    def all_gatherv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+        assert output_tensor.shape[0] == sum(sizes)
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            dst_slice = output_tensor[split_offset:split_offset + split_size]
+            self.nccl.ncclBroadcast(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(dst_slice.data_ptr()),
+                dst_slice.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                root,
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
+
     def reduce_scatter(self,
                        output_tensor: torch.Tensor,
                        input_tensor: torch.Tensor,
@@ -174,6 +208,38 @@ def reduce_scatter(self,
             ncclRedOpTypeEnum.from_torch(op), self.comm,
             cudaStream_t(stream.cuda_stream))
 
+    def reduce_scatterv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        op: ReduceOp = ReduceOp.SUM,
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            chunk = input_tensor[split_offset:split_offset + split_size, ...]
+            self.nccl.ncclReduce(
+                buffer_type(chunk.data_ptr()),
+                buffer_type(output_tensor.data_ptr()), chunk.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op), root, self.comm,
+                cudaStream_t(stream.cuda_stream))
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
+
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
@@ -216,3 +282,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()
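For intuition: the loops above implement all_gatherv as one ncclBroadcast per rank (rank r's chunk is broadcast into the output slice at rank r's offset) and reduce_scatterv as one ncclReduce per root (all ranks contribute their r-th slice and only rank r receives the reduced result), with the calls batched between ncclGroupStart/ncclGroupEnd. Below is a CPU-only torch sketch of the same semantics, using hypothetical sizes and no NCCL.

import torch

world_size = 2
sizes = [3, 2]  # hypothetical per-rank chunk lengths

# Per-rank inputs: all_gatherv takes a sizes[r]-element tensor from rank r;
# reduce_scatterv takes a sum(sizes)-element tensor from every rank.
gather_in = [torch.arange(sizes[r], dtype=torch.float32) + 100 * r
             for r in range(world_size)]
scatter_in = [torch.arange(sum(sizes), dtype=torch.float32) + 100 * r
              for r in range(world_size)]

offsets = [0]
for s in sizes:
    offsets.append(offsets[-1] + s)

# all_gatherv result on every rank: chunks concatenated in rank order,
# which is what the per-root broadcast loop writes slice by slice.
gathered = torch.cat(gather_in)

# reduce_scatterv result on rank r: element-wise sum of slice r across ranks,
# which is what the per-root ncclReduce delivers to root r.
reduced = [sum(t[offsets[r]:offsets[r + 1]] for t in scatter_in)
           for r in range(world_size)]

print(gathered)  # tensor([  0.,   1.,   2., 100., 101.])
print(reduced)   # [tensor([100., 102., 104.]), tensor([106., 108.])]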

vllm/distributed/device_communicators/pynccl_wrapper.py

Lines changed: 33 additions & 0 deletions
@@ -154,6 +154,17 @@ class NCCLLibrary:
             ncclRedOp_t, ncclComm_t, cudaStream_t
         ]),
 
+        # ncclResult_t ncclReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #   ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
+        ]),
+
         # ncclResult_t ncclAllGather(
         #   const void* sendbuff, void* recvbuff, size_t count,
         #   ncclDataType_t datatype, ncclComm_t comm,
@@ -207,6 +218,10 @@ class NCCLLibrary:
        # it is better not to call it at all.
        # ncclResult_t ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+        # ncclResult_t ncclGroupStart();
+        Function("ncclGroupStart", ncclResult_t, []),
+        # ncclResult_t ncclGroupEnd();
+        Function("ncclGroupEnd", ncclResult_t, []),
     ]
 
     # class attribute to store the mapping from the path to the library
@@ -300,6 +315,18 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                                                      datatype, op, comm,
                                                      stream))
 
+    def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                   count: int, datatype: int, op: int, root: int,
+                   comm: ncclComm_t, stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
+                                                  datatype, op, root, comm,
+                                                  stream))
+
     def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                           count: int, datatype: int, op: int, comm: ncclComm_t,
                           stream: cudaStream_t) -> None:
@@ -342,6 +369,12 @@ def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
+    def ncclGroupStart(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
+    def ncclGroupEnd(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
+
 
 __all__ = [
     "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
