From c528147a4fb6ef421e4af9b1307f2f640c55558a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 3 Jun 2025 22:02:49 +0000 Subject: [PATCH 1/2] [Bugfix][EP+DP] Fix internode check Signed-off-by: Tyler Michael Smith --- vllm/distributed/device_communicators/all2all.py | 2 -- .../device_communicators/base_device_communicator.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 2ab3779ece05..e38b59d18e63 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -178,7 +178,6 @@ def _make_all2all_kwargs(self) -> dict[Any, Any]: num_rdma_bytes = 1024 * 1024 * 1024 num_qps_per_rank = self.num_sms // 2 else: - assert self.intranode num_rdma_bytes = 0 num_qps_per_rank = 1 @@ -243,7 +242,6 @@ def _make_all2all_kwargs( if self.internode: num_rdma_bytes = 1024 * 1024 * 1024 else: - assert self.intranode num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, hidden=token_hidden_size, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 38370d4dc2b5..1bc2d8e0281c 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -49,8 +49,7 @@ def __init__(self, cpu_group): # all2all communication often has separate implementations for # intra-node and inter-node communication - self.intranode = in_the_same_node_as(cpu_group, source_rank=0) - self.internode = not self.intranode + self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) def get_handle(self, kwargs): # get a handle for the all2all communication, From 38a7be44bdfa39ab9e78772674ef32b6660d3c53 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 3 Jun 2025 22:05:37 +0000 Subject: [PATCH 2/2] rm hack from #19034 Signed-off-by: Tyler Michael Smith --- vllm/distributed/device_communicators/all2all.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index e38b59d18e63..cab2496bfba7 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -84,10 +84,6 @@ def __init__(self, cpu_group): assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa super().__init__(cpu_group) - # TODO(tms): Disable pplx-a2a intranode as it fails with the error: - # failed: cuda error /app/pplx/csrc/all_to_all/intranode.cpp:84 'invalid resource handle' # noqa - self.internode = True - if self.internode: # inter-node communication needs nvshmem, # intra-node communication uses p2p mapping directly