Commit 6636a97

Update with Hopper results, move consts to a single place

Signed-off-by: ilmarkov <markovilya197@gmail.com>
1 parent: 916a77b

File tree: 6 files changed, +65 −26 lines

docs/design/v1/multiprocessing.md (1 addition, 1 deletion)

@@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.
 
 There are other miscellaneous places hard-coding the use of `spawn`:
 
-- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
+- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
 - <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>
 
 Related PRs:

tools/check_pickle_imports.py (1 addition, 1 deletion)

@@ -38,7 +38,7 @@
     'vllm/distributed/utils.py',
     'vllm/distributed/parallel_state.py',
     'vllm/engine/multiprocessing/client.py',
-    'vllm/distributed/device_communicators/custom_all_reduce_utils.py',
+    'vllm/distributed/device_communicators/all_reduce_utils.py',
     'vllm/distributed/device_communicators/shm_broadcast.py',
     'vllm/engine/multiprocessing/engine.py',
     'benchmarks/kernels/graph_machete_bench.py',

vllm/distributed/device_communicators/custom_all_reduce_utils.py renamed to vllm/distributed/device_communicators/all_reduce_utils.py (37 additions, 0 deletions)

@@ -18,11 +18,48 @@
 import vllm.envs as envs
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
+from vllm.platforms import DeviceCapability
 from vllm.utils import (cuda_device_count_stateless,
                         update_environment_variables)
 
 logger = init_logger(__name__)
 
+MiB = 1024 * 1024
+
+# Max size for each world size in case symmetric memory is available
+# For different SM architectures
+
+# TODO(ilia): update max sizes for 6, 8 for sm90
+CUSTOM_ALL_REDUCE_MAX_SIZES = {
+    DeviceCapability(9, 0): {
+        2: 64 * MiB,  # 64 MB
+        4: MiB,  # 1 MB
+        6: MiB,  # 1 MB
+        8: MiB // 2,  # 512 KB
+    },
+    DeviceCapability(10, 0): {
+        2: 2 * MiB,  # 2 MB
+        4: 2 * MiB,  # 2 MB
+        6: 8 * MiB,  # 8 MB
+        8: 8 * MiB,  # 8 MB
+    }
+}
+
+SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+    DeviceCapability(9, 0): {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
+    },
+    DeviceCapability(10, 0): {
+        2: 8 * MiB,  # 8 MB
+        4: 32 * MiB,  # 32 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
+    }
+}
+
 
 def producer(batch_src: Sequence[int],
              producer_queue,
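
Both tables above are keyed first by `DeviceCapability` (9, 0 is Hopper/sm90; 10, 0 is Blackwell/sm100) and then by world size, with values in bytes. As a reading aid, here is a minimal sketch of a lookup against these tables; the helper name `lookup_symm_mem_max_size` is hypothetical and not part of the commit:

from typing import Optional

from vllm.distributed.device_communicators.all_reduce_utils import (
    SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.platforms import current_platform


def lookup_symm_mem_max_size(world_size: int) -> Optional[int]:
    # Hypothetical helper: return the symmetric-memory all-reduce size
    # cap for the current GPU generation and world size, or None if
    # either dimension has no table entry.
    capability = current_platform.get_device_capability()
    per_world_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(capability)
    if per_world_size is None:
        return None  # GPU generation not in the table (e.g. pre-Hopper)
    return per_world_size.get(world_size)  # None for unlisted world sizes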

vllm/distributed/device_communicators/custom_all_reduce.py (8 additions, 12 deletions)

@@ -10,8 +10,8 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
-from vllm.distributed.device_communicators.custom_all_reduce_utils import (
-    gpu_p2p_access_check)
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -49,14 +49,6 @@ def is_weak_contiguous(inp: torch.Tensor):
 class CustomAllreduce:
 
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
-    MiB = 1024 * 1024
-    # Max sizes for each world size in case symmetric memory is available
-    _MAX_SIZES = {
-        2: 2 * MiB,  # 1 MB
-        4: 2 * MiB,  # 1 MB
-        6: MiB,  # 512 KB
-        8: MiB // 2,  # 512 KB
-    }
 
     # max_size: max supported allreduce size
     def __init__(self,
@@ -117,8 +109,12 @@ def __init__(self,
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
-        if current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM:
-            max_size = CustomAllreduce._MAX_SIZES[world_size]
+        device_capability = current_platform.get_device_capability()
+        if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
+                and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
+            max_size = min(
+                CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
+                max_size)
 
         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
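
In the rewritten `__init__`, the symmetric-memory branch can only tighten the caller-supplied `max_size`, never raise it, and it only fires on architectures present in the table. A standalone sketch of that clamping arithmetic, using the sm90 row shown earlier (values copied from the diff, variable names illustrative):

MiB = 1024 * 1024

# Mirrors the DeviceCapability(9, 0) row of CUSTOM_ALL_REDUCE_MAX_SIZES.
sm90_caps = {2: 64 * MiB, 4: MiB, 6: MiB, 8: MiB // 2}

max_size = 8 * MiB  # whatever the caller passed to CustomAllreduce
world_size = 4
max_size = min(sm90_caps[world_size], max_size)
assert max_size == MiB  # the table clamps the caller's 8 MiB down to 1 MiB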

vllm/distributed/device_communicators/symm_mem.py (17 additions, 11 deletions)

@@ -6,6 +6,8 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    SYMM_MEM_ALL_REDUCE_MAX_SIZES)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
@@ -20,14 +22,6 @@
 
 
 class SymmMemCommunicator:
-    MiB = 1024 * 1024
-    # Max sizes for each world size
-    _MAX_SIZES = {
-        2: 8 * MiB,
-        4: 32 * MiB,
-        6: 128 * MiB,
-        8: 128 * MiB,
-    }
 
     def __init__(self, group: ProcessGroup, device: Union[int, str,
                                                           torch.device]):
@@ -49,15 +43,27 @@ def __init__(self, group: ProcessGroup, device: Union[int, str,
         self.device = device
         self.group = group
         self.world_size = dist.get_world_size(self.group)
-        if self.world_size not in self._MAX_SIZES:
+        device_capability = current_platform.get_device_capability()
+
+        if device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
+            logger.warning(
+                "SymmMemCommunicator: Device capability %s not supported, "
+                "communicator is not available.",
+                device_capability,
+            )
+            return
+        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
+                device_capability]:
             logger.warning(
                 "SymmMemCommunicator: World size %d not supported, "
                 "communicator is not available.",
                 self.world_size,
             )
             return
+        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[device_capability][
+            self.world_size]
         self.buffer = torch_symm_mem.empty(
-            self._MAX_SIZES[self.world_size] // self.dtype.itemsize,
+            self.max_size // self.dtype.itemsize,
             device=self.device,
             dtype=self.dtype,
         )
@@ -76,7 +82,7 @@ def should_use_symm_mem(self, inp: torch.Tensor):
         inp_size = inp.numel() * inp.element_size()
         if inp_size % 4 != 0:
             return False
-        return inp_size <= self._MAX_SIZES[self.world_size]
+        return inp_size <= self.max_size
 
     def all_reduce(
         self,
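
Construction now bails out in two stages, unsupported GPU generation first and unsupported world size second, leaving the communicator unavailable instead of raising. With `self.max_size` resolved per capability, `should_use_symm_mem` reduces to an alignment check plus the size cap. An illustrative standalone version of that predicate (the 32 MiB cap is an assumed example matching the sm90, world-size-4 entry; the tensor plumbing is omitted):

MiB = 1024 * 1024
max_size = 32 * MiB  # e.g. SYMM_MEM_ALL_REDUCE_MAX_SIZES[sm90][4]


def should_use_symm_mem(num_elements: int, element_size: int) -> bool:
    # Same predicate as the diff above: total byte size must be
    # 4-byte aligned and must fit under the per-capability cap.
    inp_size = num_elements * element_size
    if inp_size % 4 != 0:
        return False
    return inp_size <= max_size


assert should_use_symm_mem(4 * MiB, 2)       # 8 MiB of fp16 fits
assert not should_use_symm_mem(32 * MiB, 2)  # 64 MiB exceeds the 32 MiB cap
assert not should_use_symm_mem(3, 1)         # 3 bytes: not 4-byte aligned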

vllm/envs.py (1 addition, 1 deletion)

@@ -621,7 +621,7 @@ def get_vllm_port() -> Optional[int]:
                      ("1", "true")),
 
     # By default, vLLM will check the peer-to-peer capability itself,
-    # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa
+    # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/all_reduce_utils.py#L101-L108 for details. # noqa
     # If this env var is set to 1, vLLM will skip the peer-to-peer check,
     # and trust the driver's peer-to-peer capability report.
     "VLLM_SKIP_P2P_CHECK":
