
Commit fbe0c59

Merge pull request #4 from harygo22/dbo_091_v2
fix bug after merge with 091
2 parents: d556d49 + 2e9dd59

File tree

5 files changed, +29 -48 lines changed

vllm_ascend/ascend_forward_context.py
vllm_ascend/models/moe_block.py
vllm_ascend/models/qwen3_dbo.py
vllm_ascend/multistream/ms_split.py
vllm_ascend/ops/fused_moe.py

vllm_ascend/ascend_forward_context.py

Lines changed: 3 additions & 2 deletions
@@ -21,8 +21,9 @@ class FusedMoEState(Enum):
 def get_fused_moe_state(ep_size: int, with_prefill: bool):
     if ep_size == 1:
         return FusedMoEState.AllGather
-    elif with_prefill and envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
-        return FusedMoEState.All2AllSeq
+    elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
+        # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
+        return FusedMoEState.All2AllSeq if (ep_size < 16 or with_prefill) else FusedMoEState.MC2
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
         return FusedMoEState.All2All
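For reference, a self-contained sketch of the selection logic this hunk introduces. The enum members, the env flag, and the final fallback branch are local stand-ins here, not the real vllm_ascend definitions:

# Standalone sketch of the new state selection; FusedMoEState members, the
# ALL2ALL_SEQ flag, and the trailing MC2 fallback are assumptions for this sketch.
from enum import Enum, auto


class FusedMoEState(Enum):
    AllGather = auto()
    All2All = auto()
    All2AllSeq = auto()
    MC2 = auto()


VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ = True  # stand-in for the envs_ascend flag


def get_fused_moe_state(ep_size: int, with_prefill: bool) -> FusedMoEState:
    if ep_size == 1:
        return FusedMoEState.AllGather
    elif VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
        # MC2 dispatch/combine is preferred for decode once ep_size >= 16.
        return (FusedMoEState.All2AllSeq
                if (ep_size < 16 or with_prefill) else FusedMoEState.MC2)
    elif ep_size < 16 or with_prefill:
        # mc2 needs ep_size >= 16; all2all cannot be used in the torchair graph.
        return FusedMoEState.All2All
    return FusedMoEState.MC2  # assumed default for large-EP decode with the flag off


assert get_fused_moe_state(1, False) is FusedMoEState.AllGather
assert get_fused_moe_state(8, True) is FusedMoEState.All2AllSeq
assert get_fused_moe_state(32, False) is FusedMoEState.MC2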

vllm_ascend/models/moe_block.py

Lines changed: 1 addition & 3 deletions
@@ -98,10 +98,8 @@ def forward(
             is_prefill = True
             enable_force_load_balance = True
         else:
-            is_prefill = False
+            is_prefill = get_forward_context().with_prefill
             enable_force_load_balance = False
-            if hasattr(attn_metadata, 'with_prefill_across_dp'):
-                is_prefill = attn_metadata.with_prefill_across_dp

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
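The fix reads the prefill flag from the shared per-step forward context instead of probing attn_metadata. A minimal sketch of that pattern, with the context object and the dummy-run guard simplified (both are assumptions, not the real vllm implementation):

# Sketch only: a stubbed forward context carrying with_prefill, mirroring how
# the MoE block now resolves its prefill flag.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _ForwardContext:
    with_prefill: bool = False
    attn_metadata: Optional[Any] = None


_CURRENT_CONTEXT = _ForwardContext()


def get_forward_context() -> _ForwardContext:
    return _CURRENT_CONTEXT


def resolve_prefill_flag(attn_metadata: Optional[Any]) -> bool:
    # The guard for the dummy-run branch is not shown in the diff; treating a
    # missing attn_metadata as prefill is an assumption for this sketch.
    if attn_metadata is None:
        return True
    return get_forward_context().with_prefill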

vllm_ascend/models/qwen3_dbo.py

Lines changed: 22 additions & 40 deletions
@@ -6,7 +6,6 @@
 import torch_npu
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.compilation.decorators import support_torch_compile

 from vllm.model_executor.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel
 from vllm.config import CacheConfig, VllmConfig
@@ -22,6 +21,7 @@
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.compilation.decorators import support_torch_compile

 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -35,6 +35,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE, select_experts, apply_mlp
 from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM

 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO

@@ -197,7 +198,7 @@ def _forward_op_grouped_mlp(
         self, dispatched_input, tokens_per_expert
     ):
         return apply_mlp(
-            [dispatched_input],
+            dispatched_input,
             self.mlp.experts.w13_weight,
             self.mlp.experts.w2_weight,
             tokens_per_expert
@@ -207,8 +208,7 @@ def _forward_combine_comm(
         self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes
     ):
         token_dispatcher = self.mlp.experts.token_dispatchers[microbatch_id]
-        token_dispatcher.combine_alltoall()
-        final_hidden_states = token_dispatcher.unpermute2()
+        final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states)
         if hasattr(self.mlp, 'routed_scaling_factor'):
             final_hidden_states = final_hidden_states * self.mlp.routed_scaling_factor

@@ -267,13 +267,10 @@ def discard_tensor(tensor):
         # communication in the previous layer, and the attn computation of microbatch 2
         # can be overlapped with the attn communication of microbatch 1
         for i in range(num_micro_batchs):
-            # wait last layer moe finishing communication
-            ms_metadata.try_wait_event(layer_index - 1, i,
-                                       MSEventKey.MOE_AFTER_COMM)
-
             forward_context = get_forward_context()
             layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
             )
+            ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH)
             forward_context.attn_metadata = attn_metadata[i]

             # input layernorm
@@ -309,36 +306,25 @@ def discard_tensor(tensor):
             with torch.npu.stream(dispatch_context.comm_stream):
                 dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event)
                 token_dispatchers[i].dispatch_alltoall()
+                dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
                 dispatch_context.after_comm_event.record()

-            if has_shared_expert:
-                token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce(
-                    token_dispatchers[i].cached_shared_expert_output
-                )
-                ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH].record()
-
         # print_with_sync('begin experts...', torch.distributed.get_rank())
         # block 4 : Router Experts Computation
         # block 5 : Token Combine Communication
         for i in range(num_micro_batchs):
-
             ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM)
             discard_tensor(hidden_states[i])
-
-            dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
             router_expert_output[i] = self._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i])
             discard_tensor(dispatched_input[i])
-            token_dispatchers[i].unpermute1(router_expert_output[i])
-            if router_expert_output[i].shape[0] > 0 and token_dispatchers[i].num_local_experts > 1:
-                discard_tensor(router_expert_output[i])

             # Launch Combine Comm in a New Stream.
             combine_context = MultiStreamStepMetadata(
                 comm_stream=ms_metadata.communicate_stream,
                 before_comm_event=ms_metadata.ms_events[layer_index][i][
-                    MSEventKey.MOE_BEFORE_COMM],
+                    MSEventKey.FFN_COM_FINISH],
                 after_comm_event=ms_metadata.ms_events[layer_index][i][
-                    MSEventKey.MOE_AFTER_COMM],
+                    MSEventKey.FFN_AR_FINISH],
             )
             combine_context.before_comm_event.record()
             ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH)
@@ -347,7 +333,7 @@ def discard_tensor(tensor):
                 hidden_states[i] = self._forward_combine_comm(
                     router_expert_output[i], i, num_tokens[i], chunked_hidden_states_sizes[i]
                 )
-                combine_context.after_comm_event.record()
+                ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH] = combine_context.comm_stream.record_event()

         return hidden_states, residual

@@ -443,11 +429,10 @@ def forward(
     def can_run_ms(self):
         attn_metadata = get_forward_context().attn_metadata
         # enable prefill overlap
-        with_prefill = getattr(attn_metadata, "with_prefill_across_dp", False)
+        with_prefill = get_forward_context().with_prefill
         if attn_metadata is None or not with_prefill or not attn_metadata.enable_dbo_across_dp:
             return False
-        # if torch.distributed.get_rank() == 0:
-        #     print(attn_metadata)
+
         return True

     def _forward_ms_layers(
@@ -465,9 +450,7 @@ def _forward_ms_layers(
         attn_metadata, [positions, hidden_states,
                         residual] = self.ms_pre_layer(
                             [positions, hidden_states, residual], )
-        # if torch.distributed.get_rank() == 0:
-        #     print(attn_metadata[0], attn_metadata[1])
-        #     exit()
+        num_micro_batch = len(attn_metadata)
         # the rest layers
         for i in range(moe_start_layer, self.end_layer):
             layer = self.layers[i]
@@ -481,6 +464,11 @@ def _forward_ms_layers(
             )
             advance_step_multistream_layer_context()

+        layer_index, ms_metadata, attn_metadata = get_multistream_layer_context()
+        for i in range(num_micro_batch):
+            ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH)
+
+
         [hidden_states,
          residual] = self.ms_post_layer([hidden_states, residual], )
         return hidden_states, residual
@@ -517,17 +505,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+
+    def forward(self, *args, **kwargs):
+        if "graph_enable" in kwargs:
+            kwargs.pop('graph_enable')
+        return super().forward(*args, **kwargs)

-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        graph_enable: Optional[bool] = True
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
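The last hunk replaces the explicit forward signature with a thin wrapper that strips the graph_enable kwarg before delegating to the parent class. A standalone sketch of that pattern follows; _Parent and _DBOModel are hypothetical stand-ins, not the real CustomQwen3MoeForCausalLM classes:

# Sketch of the kwarg-stripping forward wrapper introduced above.
class _Parent:
    def forward(self, *args, **kwargs):
        return ("forward", args, tuple(sorted(kwargs)))


class _DBOModel(_Parent):
    def forward(self, *args, **kwargs):
        # graph_enable is only meaningful to the caller; the parent forward
        # does not accept it, so drop it before delegating.
        if "graph_enable" in kwargs:
            kwargs.pop("graph_enable")
        return super().forward(*args, **kwargs)


print(_DBOModel().forward("input_ids", positions="positions", graph_enable=True))
# -> ('forward', ('input_ids',), ('positions',))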

vllm_ascend/multistream/ms_split.py

Lines changed: 2 additions & 2 deletions
@@ -324,13 +324,13 @@ def model_input_split_v1_attn(
         query_start_loc=query_start_loc_pre,
         query_lens=query_lens_pre,
         seq_lens=seq_lens_pre,
+        seq_lens_list=seq_lens_pre.tolist(),
         max_query_len=max_query_len_pre,
         slot_mapping=slot_mapping_pre,
         is_only_prefill=is_only_prefill_pre,
         attn_state=attn_state_pre,
         attn_mask=attn_mask_pre,
         num_input_tokens=token_index,
-        with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
         enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
     )

@@ -340,13 +340,13 @@ def model_input_split_v1_attn(
         query_start_loc=query_start_loc_post,
         query_lens=query_lens_post,
         seq_lens=seq_lens_post,
+        seq_lens_list=seq_lens_post.tolist(),
         max_query_len=max_query_len_post,
         slot_mapping=slot_mapping_post,
         is_only_prefill=is_only_prefill_post,
         attn_state=attn_state_post,
         attn_mask=attn_mask_post,
         num_input_tokens=attn_metadata.num_input_tokens - token_index,
-        with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
         enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp,
     )
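Both split halves now also carry a plain-Python list view of their sequence lengths, and the dropped with_prefill_across_dp field is no longer forwarded. A tiny sketch of the list-view pattern; _SplitMeta and split_seq_lens are simplified stand-ins, not the real metadata class or split routine:

# Sketch: each metadata split materializes seq_lens_list alongside the tensor.
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class _SplitMeta:
    seq_lens: torch.Tensor
    seq_lens_list: List[int]


def split_seq_lens(seq_lens: torch.Tensor, split_idx: int):
    pre, post = seq_lens[:split_idx], seq_lens[split_idx:]
    return (_SplitMeta(seq_lens=pre, seq_lens_list=pre.tolist()),
            _SplitMeta(seq_lens=post, seq_lens_list=post.tolist()))


pre, post = split_seq_lens(torch.tensor([3, 5, 2, 7]), 2)
print(pre.seq_lens_list, post.seq_lens_list)  # [3, 5] [2, 7]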

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -1003,7 +1003,7 @@ def apply(
                 global_batch_size=self.global_batch_size,
                 expert_map=expert_map,
                 ep_group=get_ep_group())
-        elif fused_moe_state == FusedMoEState.All2AllSeq and is_prefill:
+        elif fused_moe_state == FusedMoEState.All2AllSeq:
            token_dispatcher = kwargs.get('token_dispatcher')
            return fused_experts_with_all2allv(token_dispatcher=token_dispatcher,
                                               probs=topk_weights,
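The is_prefill guard could be dropped because the new get_fused_moe_state only produces All2AllSeq when ep_size < 16 or the step contains prefill. A self-contained consistency check of that invariant, with the selection logic copied in compressed form and the ALL2ALL_SEQ flag assumed enabled:

# Sketch: verify that All2AllSeq never reaches apply() for a decode-only step
# with ep_size >= 16, which is what made the old `and is_prefill` redundant.
from enum import Enum, auto


class FusedMoEState(Enum):
    AllGather = auto()
    All2All = auto()
    All2AllSeq = auto()
    MC2 = auto()


def get_fused_moe_state(ep_size: int, with_prefill: bool) -> FusedMoEState:
    if ep_size == 1:
        return FusedMoEState.AllGather
    # ALL2ALL_SEQ flag assumed on for this check.
    return (FusedMoEState.All2AllSeq
            if (ep_size < 16 or with_prefill) else FusedMoEState.MC2)


for ep_size in (1, 2, 8, 16, 32):
    for with_prefill in (False, True):
        if get_fused_moe_state(ep_size, with_prefill) is FusedMoEState.All2AllSeq:
            assert ep_size < 16 or with_prefill
print("All2AllSeq only occurs when ep_size < 16 or with_prefill")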
