16
16
# Adapted from vllm/model_executor/models/qwen3_moe.py
17
17
# This file is a part of the vllm-ascend project.
18
18
19
- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
19
+ from typing import Optional
20
20
21
21
import torch
22
- import torch .distributed as dist
23
- import torch_npu
24
22
import vllm
25
- import vllm .envs as envs
26
23
from torch import nn
27
24
from transformers import PretrainedConfig
28
25
from vllm .attention import AttentionMetadata
29
- from vllm .distributed import (get_tensor_model_parallel_world_size ,
30
- get_tp_group )
26
+ from vllm .distributed import get_tensor_model_parallel_world_size , get_tp_group
31
27
from vllm .distributed .parallel_state import get_dp_group
32
28
from vllm .forward_context import get_forward_context
33
29
from vllm .model_executor .layers .linear import ReplicatedLinear
34
-
35
30
from vllm .model_executor .layers .quantization import QuantizationConfig
36
-
31
+ from vllm . model_executor . models . qwen3_moe import Qwen3MoeForCausalLM
37
32
from vllm_ascend .ascend_config import get_ascend_config
38
33
from vllm_ascend .distributed .parallel_state import get_ep_group
39
34
from vllm_ascend .ops .fused_moe import AscendFusedMoE
40
35
41
- from vllm .model_executor .models .qwen3_moe import Qwen3MoeForCausalLM
42
- from transformers import PretrainedConfig
43
- from vllm .model_executor .layers .quantization import QuantizationConfig
44
-
45
36
46
37
class CustomQwen3MoeForCausalLM (Qwen3MoeForCausalLM ):
47
38
packed_modules_mapping = {
@@ -55,19 +46,18 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
55
46
"up_proj" ,
56
47
],
57
48
"experts" :
58
- ["experts.0.gate_proj" , "experts.0.up_proj" , "experts.0.down_proj" ],
49
+ ["experts.0.gate_proj" , "experts.0.up_proj" , "experts.0.down_proj" ],
59
50
}
60
51
61
52
62
53
class AscendQwen3MoeSparseMoeBlock (nn .Module ):
63
-
64
54
top_k : int
65
55
66
56
def __init__ (
67
- self ,
68
- config : PretrainedConfig ,
69
- quant_config : Optional [QuantizationConfig ] = None ,
70
- prefix : str = "" ,
57
+ self ,
58
+ config : PretrainedConfig ,
59
+ quant_config : Optional [QuantizationConfig ] = None ,
60
+ prefix : str = "" ,
71
61
):
72
62
super ().__init__ ()
73
63
self .tp_size = get_tensor_model_parallel_world_size ()
@@ -97,7 +87,6 @@ def __init__(
97
87
quant_config = quant_config ,
98
88
prefix = f"{ prefix } .experts" )
99
89
100
-
101
90
self .top_k = config .num_experts_per_tok
102
91
103
92
self .dp_size = get_dp_group ().world_size
@@ -122,7 +111,7 @@ def forward(
122
111
is_prefill = True
123
112
enable_force_load_balance = True
124
113
else :
125
- # is_prefill = attn_metadata.num_prefills > 0 is_prefill or
114
+ # is_prefill = attn_metadata.num_prefills > 0
126
115
enable_force_load_balance = False
127
116
if hasattr (attn_metadata , 'with_prefill_across_dp' ):
128
117
is_prefill = attn_metadata .with_prefill_across_dp
0 commit comments