
Commit 854c149

harygo22 and weijinqian_v1 authored and committed
fix bug
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
1 parent 3f88769 commit 854c149

File tree

5 files changed: +125, -44 lines


vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ class FusedMoEState(Enum):
 def get_fused_moe_state(ep_size: int, with_prefill: bool):
     if ep_size == 1:
         return FusedMoEState.AllGather
-    elif with_prefill and envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
+    elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
         return FusedMoEState.All2AllSeq
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
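
With this change the All2AllSeq path is selected purely from the VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ flag rather than only for prefill batches (the matching prefill guard is also dropped in vllm_ascend/ops/fused_moe.py below). A self-contained sketch of the resulting selection logic; the enum values, the module-level flag, and the last two branches are stand-ins inferred from the context lines, not taken verbatim from the file:

    from enum import Enum

    class FusedMoEState(Enum):
        AllGather = 0
        All2AllSeq = 1
        All2All = 2   # assumed name for the generic all2all state
        MC2 = 3       # assumed name for the mc2 state

    # stands in for envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ
    VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ = True

    def get_fused_moe_state(ep_size: int, with_prefill: bool) -> FusedMoEState:
        if ep_size == 1:
            return FusedMoEState.AllGather
        # changed in this commit: no longer requires with_prefill
        elif VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
            return FusedMoEState.All2AllSeq
        # mc2 needs ep_size >= 16; all2all can't be used in the torchair graph
        elif ep_size < 16 or with_prefill:
            return FusedMoEState.All2All
        else:
            return FusedMoEState.MC2

    # With the flag set, decode-only batches now also pick All2AllSeq:
    # get_fused_moe_state(ep_size=8, with_prefill=False) -> FusedMoEState.All2AllSeq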

vllm_ascend/models/qwen3_dbo.py

Lines changed: 22 additions & 40 deletions
@@ -6,7 +6,6 @@
 import torch_npu
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.compilation.decorators import support_torch_compile
 
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel
 from vllm.config import CacheConfig, VllmConfig

@@ -22,6 +21,7 @@
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.compilation.decorators import support_torch_compile
 
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,

@@ -35,6 +35,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE, select_experts, apply_mlp
 from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
 
 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
 

@@ -197,7 +198,7 @@ def _forward_op_grouped_mlp(
         self, dispatched_input, tokens_per_expert
     ):
         return apply_mlp(
-            [dispatched_input],
+            dispatched_input,
             self.mlp.experts.w13_weight,
             self.mlp.experts.w2_weight,
             tokens_per_expert

@@ -207,8 +208,7 @@ def _forward_combine_comm(
         self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes
     ):
         token_dispatcher = self.mlp.experts.token_dispatchers[microbatch_id]
-        token_dispatcher.combine_alltoall()
-        final_hidden_states = token_dispatcher.unpermute2()
+        final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states)
         if hasattr(self.mlp, 'routed_scaling_factor'):
             final_hidden_states = final_hidden_states * self.mlp.routed_scaling_factor
 

@@ -267,13 +267,10 @@ def discard_tensor(tensor):
         # communication in the previous layer, and the attn computation of microbatch 2
         # can be overlapped with the attn communication of microbatch 1
         for i in range(num_micro_batchs):
-            # wait last layer moe finishing communication
-            ms_metadata.try_wait_event(layer_index - 1, i,
-                                       MSEventKey.MOE_AFTER_COMM)
-
             forward_context = get_forward_context()
             layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
             )
+            ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH)
             forward_context.attn_metadata = attn_metadata[i]
 
             # input layernorm

@@ -309,36 +306,25 @@ def discard_tensor(tensor):
             with torch.npu.stream(dispatch_context.comm_stream):
                 dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event)
                 token_dispatchers[i].dispatch_alltoall()
+                dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
                 dispatch_context.after_comm_event.record()
 
-            if has_shared_expert:
-                token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce(
-                    token_dispatchers[i].cached_shared_expert_output
-                )
-                ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH].record()
-
         # print_with_sync('begin experts...', torch.distributed.get_rank())
         # block 4 : Router Experts Computation
         # block 5 : Token Combine Communication
         for i in range(num_micro_batchs):
-
             ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM)
             discard_tensor(hidden_states[i])
-
-            dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
             router_expert_output[i] = self._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i])
             discard_tensor(dispatched_input[i])
-            token_dispatchers[i].unpermute1(router_expert_output[i])
-            if router_expert_output[i].shape[0] > 0 and token_dispatchers[i].num_local_experts > 1:
-                discard_tensor(router_expert_output[i])
 
             # Launch Combine Comm in a New Stream.
             combine_context = MultiStreamStepMetadata(
                 comm_stream=ms_metadata.communicate_stream,
                 before_comm_event=ms_metadata.ms_events[layer_index][i][
-                    MSEventKey.MOE_BEFORE_COMM],
+                    MSEventKey.FFN_COM_FINISH],
                 after_comm_event=ms_metadata.ms_events[layer_index][i][
-                    MSEventKey.MOE_AFTER_COMM],
+                    MSEventKey.FFN_AR_FINISH],
             )
             combine_context.before_comm_event.record()
             ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH)

@@ -347,7 +333,7 @@ def discard_tensor(tensor):
             hidden_states[i] = self._forward_combine_comm(
                 router_expert_output[i], i, num_tokens[i], chunked_hidden_states_sizes[i]
             )
-            combine_context.after_comm_event.record()
+            ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH] = combine_context.comm_stream.record_event()
 
         return hidden_states, residual
 

@@ -443,11 +429,10 @@ def forward(
     def can_run_ms(self):
         attn_metadata = get_forward_context().attn_metadata
         # enable prefill overlap
-        with_prefill = getattr(attn_metadata, "with_prefill_across_dp", False)
+        with_prefill = get_forward_context().with_prefill
        if attn_metadata is None or not with_prefill or not attn_metadata.enable_dbo_across_dp:
             return False
-        # if torch.distributed.get_rank() == 0:
-        #     print(attn_metadata)
+
         return True
 
     def _forward_ms_layers(

@@ -465,9 +450,7 @@ def _forward_ms_layers(
         attn_metadata, [positions, hidden_states,
                         residual] = self.ms_pre_layer(
                             [positions, hidden_states, residual], )
-        # if torch.distributed.get_rank() == 0:
-        #     print(attn_metadata[0], attn_metadata[1])
-        #     exit()
+        num_micro_batch = len(attn_metadata)
         # the rest layers
         for i in range(moe_start_layer, self.end_layer):
             layer = self.layers[i]

@@ -481,6 +464,11 @@ def _forward_ms_layers(
             )
             advance_step_multistream_layer_context()
 
+        layer_index, ms_metadata, attn_metadata = get_multistream_layer_context()
+        for i in range(num_micro_batch):
+            ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH)
+
+
         [hidden_states,
          residual] = self.ms_post_layer([hidden_states, residual], )
         return hidden_states, residual

@@ -517,17 +505,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+
+    def forward(self, *args, **kwargs):
+        if "graph_enable" in kwargs:
+            kwargs.pop('graph_enable')
+        return super().forward(*args, **kwargs)
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        graph_enable: Optional[bool] = True
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
 
 
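Two things stand out in this file: the per-microbatch combine now records an FFN_AR_FINISH event on the communication stream and the following layer waits on it (replacing the MOE_BEFORE_COMM / MOE_AFTER_COMM pair), and the causal-LM forward is reduced to a thin wrapper that drops the runner-only graph_enable keyword before delegating to the parent class. A minimal, self-contained sketch of that wrapper pattern; the class names here are placeholders, not the classes from this file:

    class Parent:
        def forward(self, input_ids, positions, intermediate_tensors=None,
                    inputs_embeds=None):
            return (input_ids, positions)

    class Child(Parent):
        # accept and discard an extra runner-only keyword, then delegate,
        # mirroring the forward() override added in the last hunk above
        def forward(self, *args, **kwargs):
            if "graph_enable" in kwargs:
                kwargs.pop("graph_enable")
            return super().forward(*args, **kwargs)

    # the wrapper tolerates callers that still pass graph_enable:
    Child().forward(input_ids=[1, 2], positions=[0, 1], graph_enable=True)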
vllm_ascend/models/qwen3_moe.py

Lines changed: 99 additions & 0 deletions
@@ -15,10 +15,26 @@
 # limitations under the License.
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
+from typing import Optional
 
+import torch
+import vllm
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.attention import AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed.parallel_state import get_dp_group
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
+from vllm.distributed.parallel_state import get_ep_group
+from vllm.forward_context import get_forward_context
 
 
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
         "qkv_proj": [

@@ -33,3 +49,86 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
     }
+
+
+class AscendQwen3MoeSparseMoeBlock(nn.Module):
+    top_k: int
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe
+
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        # TODO: need a better flag to indicate whether in profile run or not.
+        if attn_metadata is None:
+            # for profile run
+            is_prefill = True
+            enable_force_load_balance = True
+        else:
+            is_prefill = get_forward_context().with_prefill
+            enable_force_load_balance = False
+            # if hasattr(attn_metadata, 'with_prefill_across_dp'):
+            #     is_prefill = attn_metadata.with_prefill_across_dp
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None)
+
+        return hidden_states
+
+
+vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
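
The final added line monkey-patches vLLM at import time: vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock is rebound to AscendQwen3MoeSparseMoeBlock, so any Qwen3 MoE model constructed after vllm_ascend.models.qwen3_moe has been imported builds its MoE layers from the Ascend block (and therefore from AscendFusedMoE). A small illustration of the ordering requirement, assuming a working vllm + vllm-ascend install; the assert is only for demonstration:

    import vllm_ascend.models.qwen3_moe  # applies the patch as an import side effect
    from vllm.model_executor.models import qwen3_moe as upstream

    # models instantiated from here on resolve the upstream symbol to the Ascend block
    assert upstream.Qwen3MoeSparseMoeBlock.__name__ == "AscendQwen3MoeSparseMoeBlock"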

vllm_ascend/multistream/ms_split.py

Lines changed: 2 additions & 2 deletions
@@ -324,13 +324,13 @@ def model_input_split_v1_attn(
         query_start_loc=query_start_loc_pre,
         query_lens=query_lens_pre,
         seq_lens=seq_lens_pre,
+        seq_lens_list=seq_lens_pre.tolist(),
         max_query_len=max_query_len_pre,
         slot_mapping=slot_mapping_pre,
         is_only_prefill=is_only_prefill_pre,
         attn_state=attn_state_pre,
         attn_mask=attn_mask_pre,
         num_input_tokens=token_index,
-        with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
         enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
     )
 

@@ -340,13 +340,13 @@ def model_input_split_v1_attn(
         query_start_loc=query_start_loc_post,
         query_lens=query_lens_post,
         seq_lens=seq_lens_post,
+        seq_lens_list=seq_lens_post.tolist(),
         max_query_len=max_query_len_post,
         slot_mapping=slot_mapping_post,
         is_only_prefill=is_only_prefill_post,
         attn_state=attn_state_post,
         attn_mask=attn_mask_post,
         num_input_tokens=attn_metadata.num_input_tokens - token_index,
-        with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
         enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
     )
 

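Both halves of the split attention metadata now also carry seq_lens_list, the plain-Python counterpart of the seq_lens tensor, while the per-metadata with_prefill_across_dp flag is gone (the prefill signal now comes from the forward context, as in qwen3_dbo.py above). A one-liner showing what the new field holds, assuming seq_lens_pre is a 1-D integer tensor as elsewhere in this function:

    import torch

    seq_lens_pre = torch.tensor([5, 3, 7])
    seq_lens_list = seq_lens_pre.tolist()  # [5, 3, 7], consumed as a Python list downstream
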
vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -989,7 +989,7 @@ def apply(
                 global_batch_size=self.global_batch_size,
                 expert_map=expert_map,
                 ep_group=get_ep_group())
-        elif fused_moe_state == FusedMoEState.All2AllSeq and is_prefill:
+        elif fused_moe_state == FusedMoEState.All2AllSeq:
             token_dispatcher = kwargs.get('token_dispatcher')
             return fused_experts_with_all2allv(token_dispatcher=token_dispatcher,
                                                probs=topk_weights,
0 commit comments

Comments
 (0)