Commit d6bfae8

support 32K model len on deepseek r1 W8A8 (#728)
### What this PR does / why we need it?

Optimize NPU memory usage. Fixes #723: with vllm v0.8.4.rc2, DeepSeek R1 can only support a model length of 16K; attempting to run with a model length of 32K raises an "Out of Memory" (OOM) error. Replacing the out-of-place tensor operations in the W8A8 dynamic MoE path with in-place ones removes two activation-sized temporary buffers.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

CI passed

Signed-off-by: sunbaosong <13793883820@163.com>
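For context, both hunks below swap an out-of-place tensor operation for its in-place counterpart, so the MoE down-projection output is reweighted without allocating a second activation-sized buffer. A minimal sketch of the allocation difference (shapes are illustrative, not taken from the model config):

```python
import torch

# Hypothetical sizes, for illustration only.
tokens, hidden = 8192, 7168
down_out_list = torch.randn(tokens, hidden)
sorted_weights = torch.rand(tokens)

# Out-of-place multiply: allocates a second (tokens, hidden) buffer,
# so the input and the product are live on the device at the same time.
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)

# In-place multiply: reuses down_out_list's storage; no extra buffer.
down_out_list.mul_(sorted_weights.unsqueeze(1))
```

At a 32K model length these temporaries are large, and eliminating them is presumably what brings the workload back under the NPU memory budget.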
1 parent 79538b5 · commit d6bfae8

File tree

1 file changed (+3, -6 lines)

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -276,8 +276,7 @@ def fused_experts(hidden_states: torch.Tensor,
                 group_list_type=group_list_type)
 
     if expert_map is not None:
-        weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
-
+        down_out_list.mul_(sorted_weights.unsqueeze(1))
         final_hidden_states = torch.zeros(*original_shape,
                                           device=hidden_states.device,
                                           dtype=dtype)
@@ -286,10 +285,8 @@ def fused_experts(hidden_states: torch.Tensor,
         valid_token_mask = torch.arange(
             0, sorted_token_indices.shape[0],
             device=device).unsqueeze(1) < num_valid_tokens
-        valid_output = torch.where(
-            valid_token_mask, weighted_down_out,
-            torch.zeros_like(weighted_down_out)).to(dtype)
-        final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
+        down_out_list.mul_(valid_token_mask)
+        final_hidden_states.index_add_(0, sorted_token_indices, down_out_list)
     else:
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
```
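The second hunk applies the same idea to validity masking: instead of materializing a masked copy via `torch.where`, the boolean mask (which promotes to 0/1 under multiplication) is applied in place before `index_add_`. A small self-contained sketch, with hypothetical shapes, checking that the two forms agree:

```python
import torch

torch.manual_seed(0)
num_sorted, hidden, num_tokens = 16, 8, 4   # hypothetical sizes
down_out = torch.randn(num_sorted, hidden)
sorted_token_indices = torch.randint(0, num_tokens, (num_sorted,))
num_valid_tokens = 10
valid_token_mask = torch.arange(num_sorted).unsqueeze(1) < num_valid_tokens

# Old form: torch.where materializes a full-size masked copy.
old = torch.zeros(num_tokens, hidden)
valid_output = torch.where(valid_token_mask, down_out,
                           torch.zeros_like(down_out))
old.index_add_(0, sorted_token_indices, valid_output)

# New form: bool mask applied in place (promotes to 0./1.), no extra copy.
new = torch.zeros(num_tokens, hidden)
down_out.mul_(valid_token_mask)
new.index_add_(0, sorted_token_indices, down_out)

assert torch.allclose(old, new)
```

The in-place form clobbers `down_out_list`, which appears safe here since the diff shows the buffer is only consumed by the following `index_add_`.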
