# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from typing import Optional

import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              get_tp_group)
from vllm.distributed.parallel_state import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import AscendFusedMoE


class AscendSparseMoeBlock(nn.Module):
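    """Sparse Mixture-of-Experts block adapted for Ascend NPUs.

    A replicated gating layer computes per-token router logits and
    AscendFusedMoE routes each token to its top-k experts.
    """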
    top_k: int

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
        self.enable_multistream_moe = \
            ascend_config.torchair_graph_config.enable_multistream_moe

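        # Router ("gate"): a replicated, unquantized linear layer that maps
        # each token's hidden state to per-expert logits.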
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.num_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

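        # Fused expert layers; reduce_results=False, so this block does not
        # combine the expert output across ranks itself (assumption: the
        # surrounding model code performs that reduction).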
        self.experts = AscendFusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts")

        self.top_k = config.num_experts_per_tok

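        # Record data-, tensor- and expert-parallel group information.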
        self.dp_size = get_dp_group().world_size

        self.tp_group = get_tp_group().device_group
        self.tp_rank = get_tp_group().rank_in_group
        self.ep_group = get_ep_group()

        self.params_dtype = torch.get_default_dtype()

    def forward(
            self,
            hidden_states: torch.Tensor,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
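        """Route each token through its top-k experts and return the MoE output."""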
        if attn_metadata is None:
            attn_metadata = get_forward_context().attn_metadata
        # when profile runs, force experts to load balanced tokens
        # to avoid high memory consumption on a single rank.
        is_prefill = True
        if attn_metadata is None:
            # for profile run
            is_prefill = True
            enable_force_load_balance = True
        else:
            # is_prefill = is_prefill or attn_metadata.num_prefills > 0
            enable_force_load_balance = False
            if hasattr(attn_metadata, 'with_prefill_across_dp'):
                is_prefill = attn_metadata.with_prefill_across_dp

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

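        # Dispatch tokens to the fused experts; is_prefill and
        # enable_force_load_balance adjust dispatch for prefill batches and
        # for the profiling run, respectively.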
        hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            is_prefill=is_prefill,
            top_k=self.top_k,
            enable_force_load_balance=enable_force_load_balance,
            shared_experts=None,
        )

        return hidden_states