
Commit 5c53cba

Angazenn and angazenn authored
[BugFix] Fix bugs when initializing communication groups with DP on 300I Duo (#1478)
### What this PR does / why we need it?
This PR fixes a bug where the patched broadcast was also used for the cpu_group when running DP. The `broadcast310p` patch takes effect for both the cpu_group and the device group, but it is only needed for the device group. A wrapper is therefore added so that the cpu_group keeps using the native torch broadcast, which resolves the bug.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
With this PR, DP on 310P runs normally and generates reasonable answers.

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
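The fix described above is essentially a device-based dispatch wrapped around the patched broadcast. A minimal sketch of that shape, assuming hypothetical names (`device_dispatch_wrapper`, `_native_broadcast`) that are not part of the patch:

```python
import torch


def device_dispatch_wrapper(fn):
    """Keep the native path for CPU tensors; only device tensors get the 310P workaround."""

    def patched_broadcast(tensor, src, group=None, async_op=False):
        if tensor.device == torch.device("cpu"):
            # cpu_group traffic (e.g. DP coordination) falls through to the original broadcast
            return fn(tensor, src, group, async_op)
        # an NPU tensor would take the all_gather-based emulation shown in the diff below
        raise NotImplementedError("device path omitted in this sketch")

    return patched_broadcast


def _native_broadcast(tensor, src, group=None, async_op=False):
    # stand-in for torch.distributed.broadcast so the sketch runs without a process group
    return None


broadcast = device_dispatch_wrapper(_native_broadcast)
broadcast(torch.zeros(4), src=0)  # CPU tensor -> handled by _native_broadcast
```

Because the wrapper closes over the original function, CPU-group collectives never hit the 310P workaround, while device-group broadcasts keep the patched behaviour.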
1 parent 2cf9c4c commit 5c53cba

File tree

1 file changed: +22 -14 lines changed

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 22 additions & 14 deletions
```diff
@@ -77,20 +77,28 @@ def wait(self):
 
 def communication_adaptation_310p():
 
-    def broadcast310p(tensor, src, group=None, async_op=False):
-        rank = torch.distributed.get_rank(group)
-        world_size = torch.distributed.get_world_size(group)
-        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
-        tensor_list[rank] = tensor
-        torch.distributed.all_gather(tensor_list, tensor, group=group)
-        tensor[...] = tensor_list[src]
-        if async_op:
-            return NullHandle()
-        else:
-            return None
-
-    torch.distributed.broadcast = broadcast310p
-    torch.distributed.distributed_c10d.broadcast = broadcast310p
+    def broadcast310p_wrapper(fn):
+
+        def broadcast310p(tensor, src, group=None, async_op=False):
+            if tensor.device == torch.device('cpu'):
+                return fn(tensor, src, group, async_op)
+            rank = torch.distributed.get_rank(group)
+            world_size = torch.distributed.get_world_size(group)
+            tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+            tensor_list[rank] = tensor
+            torch.distributed.all_gather(tensor_list, tensor, group=group)
+            tensor[...] = tensor_list[src]
+            if async_op:
+                return NullHandle()
+            else:
+                return None
+
+        return broadcast310p
+
+    torch.distributed.broadcast = broadcast310p_wrapper(
+        torch.distributed.broadcast)
+    torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper(
+        torch.distributed.distributed_c10d.broadcast)
 
     def all_reduce_wrapper_310p(fn):
 
```
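For reference, the all_gather-based emulation that `broadcast310p` applies on the device group can be exercised on its own: every rank gathers all tensors, then copies `tensor_list[src]` back into its local tensor, which reproduces broadcast semantics. A CPU-only sketch with the gloo backend and two processes (the helper names, port, and test values are illustrative, and the `async_op`/`NullHandle` handling from the patch is omitted):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def emulated_broadcast(tensor, src, group=None):
    """Broadcast built from all_gather, mirroring the approach of broadcast310p."""
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    tensor_list[rank] = tensor
    dist.all_gather(tensor_list, tensor, group=group)
    tensor[...] = tensor_list[src]  # every rank keeps the source rank's data


def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"  # arbitrary free port for the sketch
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    t = torch.full((2,), float(rank))
    emulated_broadcast(t, src=0)
    assert torch.equal(t, torch.zeros(2))  # all ranks now hold rank 0's tensor
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```

This only mirrors the collective itself; in the actual patch the same logic is additionally guarded by the CPU check and returns a `NullHandle` when `async_op` is requested.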