
Commit 89c1a0f

[Bugfix] Fix memory-leak caused by dist._functional_collectives.reduce_scatter_tensor (#1380)
### What this PR does / why we need it?

In some cases, `dist._functional_collectives.reduce_scatter_tensor` causes its input tensor not to be released immediately after the current layer finishes. Instead, the input is only released once the GPU memory usage of the current process reaches a certain threshold (roughly once every 15 layers).

**Before Fix**

<img width="1441" alt="Screenshot 2025-06-24 01 26 13" src="https://github.com/user-attachments/assets/72d5dbb3-c8c8-4778-bf64-8db7bab8aff0" />

**After Fix**

<img width="1475" alt="Screenshot 2025-06-24 01 23 43" src="https://github.com/user-attachments/assets/6c69cfcd-a469-4ee5-b8c6-210aeb3a5bdf" />

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- vLLM version: v0.9.1
- vLLM main: vllm-project/vllm@9ff2af6

---------

Signed-off-by: ApsarasX <apsarax@outlook.com>
1 parent b1c66b2 · commit 89c1a0f
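As a reading aid, here is a minimal standalone sketch (not part of the commit) contrasting the problematic call described above with its replacement. It assumes an already-initialized distributed environment and vLLM's data-parallel group; the wrapper names `reduce_scatter_via_functional_collective` and `reduce_scatter_via_dp_group` are illustrative only.

```python
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group


def reduce_scatter_via_functional_collective(e_hidden_states: torch.Tensor) -> torch.Tensor:
    # Old call site: per the description above, the input tensor may not be
    # freed right after the layer ends; it lingers until the process memory
    # usage crosses a threshold (roughly once every 15 layers).
    return dist._functional_collectives.reduce_scatter_tensor(
        e_hidden_states, "sum", scatter_dim=0,
        group=get_dp_group().device_group)


def reduce_scatter_via_dp_group(e_hidden_states: torch.Tensor) -> torch.Tensor:
    # New call site: go through the DP group's reduce_scatter helper so the
    # input can be disposed as soon as the caller is done with it.
    return get_dp_group().reduce_scatter(e_hidden_states, dim=0)
```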

2 files changed: +29 −5 lines changed
vllm_ascend/distributed/communication_op.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import torch
+from vllm.distributed.parallel_state import get_dp_group
+
+
+def data_parallel_reduce_scatter(input_: torch.Tensor,
+                                 dim: int = -1) -> torch.Tensor:
+    """Reduce-Scatter the input tensor across data parallel group."""
+    return get_dp_group().reduce_scatter(input_, dim)
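For reference, the following single-process sketch (not part of the commit) emulates the sum-reduce-scatter semantics that this new helper relies on. The function `emulate_dp_reduce_scatter` and the tensor shapes are illustrative only; a real call goes through an initialized data-parallel process group.

```python
import torch


def emulate_dp_reduce_scatter(per_rank_inputs: list, dim: int = 0) -> list:
    """Emulate reduce_scatter(sum) over a DP group of size len(per_rank_inputs)."""
    world_size = len(per_rank_inputs)
    summed = torch.stack(per_rank_inputs).sum(dim=0)   # "reduce": element-wise sum
    return list(summed.chunk(world_size, dim=dim))     # "scatter": one shard per rank


# Example: a DP group of 2 ranks, each holding a [4, 8] activation tensor.
shards = emulate_dp_reduce_scatter([torch.ones(4, 8), torch.full((4, 8), 2.0)])
assert shards[0].shape == (2, 8) and torch.all(shards[0] == 3.0)
```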

vllm_ascend/ops/fused_moe.py

Lines changed: 4 additions & 5 deletions
@@ -39,6 +39,8 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.communication_op import \
+    data_parallel_reduce_scatter
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
@@ -1342,11 +1344,8 @@ def forward(self,
             final_hidden_states = final_hidden_states[start:end, :]
             dispose_tensor(e_hidden_states)
         elif fused_moe_state == FusedMoEState.AllGather:
-            final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                e_hidden_states,
-                "sum",
-                scatter_dim=0,
-                group=get_dp_group().device_group)
+            final_hidden_states = data_parallel_reduce_scatter(
+                e_hidden_states, dim=0)
             final_hidden_states = final_hidden_states[:num_tokens]
             dispose_tensor(e_hidden_states)
         else: