
Commit b02ad40

harygo22 authored and weijinqian_v1 committed
revert
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
1 parent eb54e22 commit b02ad40

1 file changed: +0 −99 lines


vllm_ascend/models/qwen3_moe.py

Lines changed: 0 additions & 99 deletions
@@ -15,26 +15,10 @@
 # limitations under the License.
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
-from typing import Optional
 
-import torch
-import vllm
-from torch import nn
-from transformers import PretrainedConfig
-from vllm.attention import AttentionMetadata
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
-from vllm.distributed.parallel_state import get_dp_group
-from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.linear import ReplicatedLinear
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
-from vllm.distributed.parallel_state import get_ep_group
-from vllm.forward_context import get_forward_context
 
 
-from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ops.fused_moe import AscendFusedMoE
-
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -49,86 +33,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
     }
-
-
-class AscendQwen3MoeSparseMoeBlock(nn.Module):
-    top_k: int
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe
-
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.num_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.num_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts")
-
-        self.top_k = config.num_experts_per_tok
-
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.ep_group = get_ep_group()
-
-        self.params_dtype = torch.get_default_dtype()
-
-    def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        if attn_metadata is None:
-            attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        # TODO: need a better flag to indicate whether in profile run or not.
-        if attn_metadata is None:
-            # for profile run
-            is_prefill = True
-            enable_force_load_balance = True
-        else:
-            is_prefill = get_forward_context().with_prefill
-            enable_force_load_balance = False
-            # if hasattr(attn_metadata, 'with_prefill_across_dp'):
-            #     is_prefill = attn_metadata.with_prefill_across_dp
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=self.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=None)
-
-        return hidden_states
-
-
-vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
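The last deleted line shows the pattern this revert removes: the Ascend block was installed by reassigning a class attribute on the upstream vllm module so that later lookups through that module construct the NPU variant instead. Below is a minimal, self-contained sketch of that module-attribute monkey-patching pattern; `upstream_moe`, `SparseMoeBlock`, and `AscendSparseMoeBlock` are hypothetical stand-ins, not the real vllm or vllm-ascend classes.

```python
# Sketch of module-attribute monkey-patching (hypothetical names, not the
# real vllm modules).
import types

# Stand-in for the upstream module, e.g. vllm.model_executor.models.qwen3_moe.
upstream_moe = types.ModuleType("upstream_moe")


class SparseMoeBlock:
    """Stand-in for the implementation shipped upstream."""

    def forward(self, x):
        return f"upstream forward({x})"


upstream_moe.SparseMoeBlock = SparseMoeBlock


class AscendSparseMoeBlock(SparseMoeBlock):
    """Stand-in for a hardware-specific override."""

    def forward(self, x):
        return f"ascend forward({x})"


# The patch: code that resolves the class through the module attribute
# (upstream_moe.SparseMoeBlock) now constructs the override.
upstream_moe.SparseMoeBlock = AscendSparseMoeBlock

print(upstream_moe.SparseMoeBlock().forward("tokens"))  # ascend forward(tokens)
```

Note the usual caveat with this pattern: the swap only affects call sites that look the class up via the module attribute after the assignment; any code that bound the original class object directly beforehand keeps using it.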
