
Commit dc31ffa

ApsarasX authored and David9857 committed
[Perf] Refactor tensor disposal logic to reduce memory usage (#966)
### What this PR does / why we need it?

1. In previous PRs #580 and #784, I saved GPU memory by promptly deleting tensors that were no longer needed. For tensors passed in from upper-layer functions, I used a list container to transfer ownership of the parameter and then popped the tensor from the list inside the inner function to delete it. Recently, I discovered a better implementation in sglang: the `dispose_tensor` function. I recommend adopting this approach.
2. Dispose `hidden_states` and `residual` from the previous layer once they are no longer used.
3. Avoid allocating `self.inputs_embeds` in `ModelRunnerV1` in non-multimodal scenarios.

With the aforementioned optimizations, running the DeepSeek-R1-W8A8 model with `TP=16` and `max-model-len=32768` saves 1.3 GB of NPU memory.

**Reference**: sgl-project/sglang#6147

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

---------

Signed-off-by: ApsarasX <apsarax@outlook.com>
1 parent d7f8be5 · commit dc31ffa
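For context, the core idea behind `dispose_tensor` is to shrink a tensor's storage in place, so the memory is reclaimed even while a Python reference to the tensor is still alive. A minimal sketch of the idea (same code as added to `vllm_ascend/utils.py` below; shown CPU-side for illustration):

```python
import torch

def dispose_tensor(x: torch.Tensor):
    # Rebind x to a 0-element storage; the original allocation is
    # released as soon as no other tensor still references it.
    x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))

x = torch.empty(1 << 20)  # ~4 MiB of fp32 storage
dispose_tensor(x)         # storage freed even though `x` stays in scope
```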

File tree

4 files changed (+30, -23 lines)

vllm_ascend/models/deepseek_v2.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -69,6 +69,7 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.utils import dispose_tensor

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2

@@ -558,8 +559,14 @@ def forward(
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
+            previous_hidden_states, previous_residual = hidden_states, residual
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
+            # Dispose hidden_states and residual from the previous layer
+            # to save npu memory because they're no longer used.
+            dispose_tensor(previous_hidden_states)
+            dispose_tensor(previous_residual)
+
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
```
vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 14 additions & 19 deletions

```diff
@@ -15,7 +15,7 @@
 # limitations under the License.
 #

-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, Optional

 import torch
 import torch.distributed as dist
@@ -26,11 +26,12 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import select_experts
+from vllm_ascend.utils import dispose_tensor

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


-def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
+def apply_mlp(hidden_states: torch.Tensor,
              w1: torch.Tensor,
              w1_scale: torch.Tensor,
              w2: torch.Tensor,
@@ -43,7 +44,7 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     apply MLP: gate_up_proj -> swiglu -> down_proj

     Args:
-        hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
+        hidden_states: input hidden states with shape (num_tokens, hidden_size).
         w1: expert weights1 with shape
             (num_experts, hidden_size, intermediate_size * 2)
         w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
@@ -62,11 +63,13 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
         hidden_states: output hidden states after MLP.
     """

-    assert len(hidden_states_wrapper) == 1
-    hidden_states = hidden_states_wrapper.pop()
     if dynamic_scale is None:
+        unquantized_hidden_states = hidden_states
         hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
             hidden_states)
+        # Dispose the original unquantized hidden states
+        # to save npu memory because they're no longer used.
+        dispose_tensor(unquantized_hidden_states)
     else:
         pertoken_scale = dynamic_scale

@@ -188,11 +191,8 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     if quant_mode == 0:
         dynamic_scale = None

-    # place hidden_states in a list to transfer its ownership into the `apply_mlp` function
-    hidden_states_wrapper = [expand_x]
-    del expand_x
-
-    down_out_list = apply_mlp(hidden_states_wrapper,
+    # `expand_x` will be disposed in the `apply_mlp` function
+    down_out_list = apply_mlp(expand_x,
                               w1,
                               w1_scale,
                               w2,
@@ -321,10 +321,8 @@ def fused_experts_with_all2all(
     expert_tokens = expert_tokens.to(torch.int64)
     group_list_type = 0

-    hidden_states_wrapper = [hidden_states]
-    del hidden_states
-
-    hidden_states = apply_mlp(hidden_states_wrapper,
+    # `hidden_states` will be disposed in the `apply_mlp` function
+    hidden_states = apply_mlp(hidden_states,
                               w1,
                               w1_scale,
                               w2,
@@ -439,11 +437,8 @@ def fused_experts(hidden_states: torch.Tensor,
     expert_tokens = expert_tokens.to(torch.int64)
     group_list_type = 0

-    # place hidden_states in a list to transfer its ownership into the `apply_mlp` function
-    hidden_states_wrapper = [hidden_states]
-    del hidden_states
-
-    hidden_states = apply_mlp(hidden_states_wrapper,
+    # `hidden_states` will be disposed in the `apply_mlp` function
+    hidden_states = apply_mlp(hidden_states,
                               w1,
                               w1_scale,
                               w2,
```
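The refactor in this file boils down to the following pattern change, reduced to a toy example (hypothetical `f_old`/`f_new`, not the real `apply_mlp`):

```python
import torch

def dispose_tensor(x: torch.Tensor):
    x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))

# Before (#580/#784): move the only reference into a list so the callee
# can pop it; refcounting then frees the storage once the callee's local
# goes out of scope.
def f_old(wrapper: list) -> torch.Tensor:
    x = wrapper.pop()  # the caller's list no longer holds a reference
    return x.relu()

# After (this PR): pass the tensor directly and dispose it once consumed;
# the caller's surviving reference no longer pins the memory.
def f_new(x: torch.Tensor) -> torch.Tensor:
    out = x.relu()
    dispose_tensor(x)
    return out
```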

vllm_ascend/utils.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -169,3 +169,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
         vllm_config.model_config.architectures[0], num_hidden_layers,
         len(original_sizes))
+
+
+def dispose_tensor(x: torch.Tensor):
+    x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
```
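One caveat: `Tensor.set_` only rebinds the tensor it is called on, so the old storage is reclaimed only when no other tensor (for example, a view) still shares it. A quick check:

```python
import torch
from vllm_ascend.utils import dispose_tensor

x = torch.empty(1024, 1024)  # 4 MiB of fp32
v = x[:1]                    # a view sharing x's storage
dispose_tensor(x)
print(x.numel())                     # 0
print(v.untyped_storage().nbytes())  # 4194304 -- the view still pins it
```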

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 4 deletions

```diff
@@ -240,10 +240,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             device="cpu",
             pin_memory=True)

-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=self.device)
+        if self.is_multimodal_model:
+            self.inputs_embeds = torch.zeros(
+                (self.max_num_tokens, self.hidden_size),
+                dtype=self.dtype,
+                device=self.device)

         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np: npt.NDArray[np.int32] = np.arange(max(
```
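Back-of-envelope estimate of what skipping this buffer saves for a text-only model (illustrative values; `max_num_tokens` depends on the scheduler config, and 7168 is DeepSeek-R1's hidden size):

```python
max_num_tokens = 32768  # assumed scheduler limit
hidden_size = 7168      # DeepSeek-R1
bytes_per_elem = 2      # bf16/fp16
saved = max_num_tokens * hidden_size * bytes_per_elem
print(f"{saved / 2**20:.0f} MiB")  # 448 MiB
```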
