Skip to content

Commit a878a6d

Browse files
authored
license & fix dsk dbo.
* ut test * license & fix dsk dbo.
1 parent d470373 commit a878a6d

File tree

4 files changed

+25
-68
lines changed

4 files changed

+25
-68
lines changed

vllm_ascend/distributed/tensor_parallel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
12
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
23
# Copyright 2023 The vLLM team.
34
#

vllm_ascend/models/deepseek_dbo.py

Lines changed: 22 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
from vllm.sequence import IntermediateTensors
5656

5757
import vllm_ascend.envs as envs_ascend
58-
from vllm_ascend.ascend_config import get_ascend_config
5958
from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region
6059
from vllm_ascend.ascend_forward_context import FusedMoEState
6160
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer,
@@ -72,8 +71,7 @@
7271
make_multistream_metadata_ds)
7372
from vllm_ascend.quantization.w8a8_dynamic import (
7473
AscendW8A8DynamicLinearMethod, apply_mlp)
75-
from vllm_ascend.ops.fused_moe import AscendFusedMoE, apply_mlp, select_experts
76-
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
74+
from vllm_ascend.ops.fused_moe import apply_mlp, select_experts
7775
from vllm_ascend.utils import dispose_tensor
7876

7977
VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
@@ -94,7 +92,8 @@ def __init__(
9492
intermediate_size=intermediate_size,
9593
hidden_act=hidden_act,
9694
quant_config=quant_config,
97-
prefix=prefix)
95+
prefix=prefix,
96+
reduce_results=reduce_results)
9897
self.is_dynamic_quant = not isinstance(
9998
self.gate_up_proj.quant_method,
10099
UnquantizedLinearMethod) and isinstance(
@@ -152,19 +151,6 @@ def __init__(
152151
prefix=f"{prefix}.shared_experts",
153152
)
154153
CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
155-
156-
self.dp_size = get_dp_group().world_size
157-
158-
self.tp_group = get_tp_group().device_group
159-
self.tp_rank = get_tp_group().rank_in_group
160-
self.kv_consumer = None
161-
transfer_config = get_current_vllm_config().kv_transfer_config
162-
if transfer_config is not None:
163-
self.kv_consumer = transfer_config.kv_role = "kv_consumer"
164-
self.params_dtype = torch.get_default_dtype()
165-
166-
ascend_config = get_ascend_config()
167-
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
168154
self.config = config
169155

170156
def forward(
@@ -196,9 +182,13 @@ def forward(
196182
enable_force_load_balance=enable_force_load_balance,
197183
shared_experts=self.shared_experts)
198184

185+
shared_experts_hidden = experts_hidden_states[1]
186+
if not (self.shared_experts.down_proj.reduce_results and self.shared_experts.down_proj.tp_size > 1):
187+
shared_experts_hidden = tensor_model_parallel_all_reduce(shared_experts_hidden)
188+
199189
hidden_states = (
200190
experts_hidden_states[0] * self.routed_scaling_factor +
201-
experts_hidden_states[1])
191+
shared_experts_hidden)
202192

203193
return hidden_states
204194

@@ -225,18 +215,10 @@ def _forward_op_gating(
225215
) -> torch.Tensor:
226216
if attn_metadata is None:
227217
attn_metadata = get_forward_context().attn_metadata
228-
# when profile runs, force experts to load balanced tokens
229-
# to avoid high memory consumption on a single rank.
230-
# TODO: need a better flag to indicate whether in profile run or not.
231-
if attn_metadata is None:
232-
# for profile run
233-
self.is_prefill = True
234-
self.enable_force_load_balance = True
235-
else:
236-
is_prefill = attn_metadata.num_prefills > 0
237-
self.enable_force_load_balance = False
238-
if hasattr(attn_metadata, 'with_prefill_across_dp'):
239-
self.is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
218+
# when profile runs, force experts to load balanced tokens
219+
# to avoid high memory consumption on a single rank.
220+
# TODO: need a better flag to indicate whether in profile run or not.
221+
enable_force_load_balance = get_forward_context().in_profile_run
240222

241223
num_tokens, hidden_dim = hidden_states.shape
242224

@@ -291,17 +273,11 @@ def _forward_op_gating(
291273
# this is a naive implementation for experts load balance so as
292274
# to avoid accumulating too much tokens on a single rank.
293275
# currently it is only activated when doing profile runs.
294-
if self.enable_force_load_balance:
276+
if enable_force_load_balance:
295277
topk_ids = torch.randint_like(topk_ids, 0, self.config.n_routed_experts)
296278

297279
return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes
298280

299-
def _forward_dispatch_comm(
300-
self, hidden_states, topk_weights, topk_ids, microbatch_id
301-
):
302-
token_dispatcher = self.experts.token_dispatchers[microbatch_id]
303-
_, hidden_states, tokens_per_expert = token_dispatcher.token_permutation(hidden_states, topk_weights, topk_ids)
304-
return hidden_states, tokens_per_expert
305281

306282
def _forward_op_shared_experts(
307283
self, hidden_states
@@ -315,7 +291,7 @@ def _forward_op_grouped_mlp(
315291
self, dispatched_input, tokens_per_expert
316292
):
317293
return apply_mlp(
318-
[dispatched_input],
294+
dispatched_input,
319295
self.experts.w13_weight,
320296
self.experts.w2_weight,
321297
tokens_per_expert
@@ -325,8 +301,9 @@ def _forward_combine_comm(
325301
self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes
326302
):
327303
token_dispatcher = self.experts.token_dispatchers[microbatch_id]
328-
token_dispatcher.combine_alltoall()
329-
final_hidden_states = token_dispatcher.unpermute2() * self.routed_scaling_factor
304+
final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states)
305+
if hasattr(self, 'routed_scaling_factor'):
306+
final_hidden_states = final_hidden_states * self.routed_scaling_factor
330307

331308
if self.tp_size > 1:
332309
final_hidden_states = gather_from_sequence_parallel_region(final_hidden_states, self.tp_group,
@@ -794,17 +771,12 @@ def _forward_ms_layer_alltoallv_finegrained(
794771
chunked_hidden_states_sizes = [None] * num_micro_batchs
795772
token_dispatchers = self.mlp.experts.token_dispatchers
796773

797-
def print_with_sync(*args, **kwargs):
798-
torch.npu.synchronize()
799-
print(*args, **kwargs)
800-
801774
def discard_tensor(tensor):
802775
if isinstance(tensor, torch.Tensor):
803776
tensor = [tensor]
804777
for t in tensor:
805778
t.untyped_storage().resize_(0)
806779

807-
# print_with_sync('begin layer...', torch.distributed.get_rank())
808780

809781
# block 1 : attention
810782
# block 2 : Router Gating
@@ -814,12 +786,11 @@ def discard_tensor(tensor):
814786
# can be overlapped with the attn communication of microbatch 1
815787
for i in range(num_micro_batchs):
816788
# wait last layer moe finishing communication
817-
ms_metadata.try_wait_event(layer_index - 1, i,
818-
MSEventKey.MOE_AFTER_COMM)
819789

820790
forward_context = get_forward_context()
821791
layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
822792
)
793+
ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH)
823794
forward_context.attn_metadata = attn_metadata[i]
824795

825796
# input layernorm
@@ -856,9 +827,10 @@ def discard_tensor(tensor):
856827
with torch.npu.stream(dispatch_context.comm_stream):
857828
dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event)
858829
token_dispatchers[i].dispatch_alltoall()
830+
dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
859831
dispatch_context.after_comm_event.record()
860832

861-
if self.mlp.n_shared_experts:
833+
if self.mlp.n_shared_experts and self.tp_size > 1:
862834
token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce(
863835
token_dispatchers[i].cached_shared_expert_output
864836
)
@@ -872,20 +844,16 @@ def discard_tensor(tensor):
872844
ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM)
873845
discard_tensor(hidden_states[i])
874846

875-
dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
876847
router_expert_output[i] = self.mlp._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i])
877848
discard_tensor(dispatched_input[i])
878-
token_dispatchers[i].unpermute1(router_expert_output[i])
879-
if router_expert_output[i].shape[0] > 0 and token_dispatchers[i].num_local_experts > 1:
880-
discard_tensor(router_expert_output[i])
881849

882850
# Launch Combine Comm in a New Stream.
883851
combine_context = MultiStreamStepMetadata(
884852
comm_stream=ms_metadata.communicate_stream,
885853
before_comm_event=ms_metadata.ms_events[layer_index][i][
886-
MSEventKey.MOE_BEFORE_COMM],
854+
MSEventKey.FFN_COM_FINISH],
887855
after_comm_event=ms_metadata.ms_events[layer_index][i][
888-
MSEventKey.MOE_AFTER_COMM],
856+
MSEventKey.FFN_AR_FINISH],
889857
)
890858
combine_context.before_comm_event.record()
891859
ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH)
@@ -1032,7 +1000,6 @@ def forward(
10321000
if VLLM_ASCEND_ENABLE_DBO and not graph_enable
10331001
and self.can_run_ms() else self.end_layer -
10341002
self.start_layer)
1035-
10361003
moe_start_layer = self.start_layer + num_normal_layers
10371004
for i in range(self.start_layer, min(moe_start_layer, self.end_layer)):
10381005
layer = self.layers[i]
@@ -1068,16 +1035,6 @@ def can_run_ms(self):
10681035
return False
10691036
return True
10701037

1071-
def all_can_run_ms(self):
1072-
can_run_ms_local = self.can_run_ms()
1073-
ep_group = get_ep_group().cpu_group
1074-
flag = torch.ones(1, dtype=torch.int) if can_run_ms_local else torch.zeros(1, dtype=torch.int)
1075-
torch.distributed.all_reduce(flag, group=ep_group)
1076-
if flag.item() == torch.distributed.get_world_size(ep_group):
1077-
return True
1078-
else:
1079-
return False
1080-
10811038
def _forward_ms_layers(self,
10821039
positions: torch.Tensor,
10831040
hidden_states: torch.Tensor,
@@ -1098,9 +1055,7 @@ def _forward_ms_layers(self,
10981055
layer = self.layers[i]
10991056
ms_layer_forward_func = layer._forward_ms_layer
11001057
if fused_moe_state == FusedMoEState.All2AllSeq:
1101-
# ms_layer_forward_func = layer._forward_ms_layer_alltoallv
11021058
ms_layer_forward_func = layer._forward_ms_layer_alltoallv_finegrained
1103-
# print("get_called......")
11041059
hidden_states, residual = ms_layer_forward_func(
11051060
positions=positions,
11061061
hidden_states=hidden_states,

vllm_ascend/ops/moe_dispatcher/moe_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#
1+
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
22
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");

vllm_ascend/ops/moe_dispatcher/token_dispatcher.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
2+
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
23
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
34
# Copyright 2023 The vLLM team.
45
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.

0 commit comments

Comments
 (0)