# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.

- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import Optional

import torch
- import torch.distributed as dist
- import torch_npu
- import vllm
- import vllm.envs as envs
from torch import nn
from transformers import PretrainedConfig
+ from vllm_ascend.ascend_config import get_ascend_config
+ from vllm_ascend.distributed.parallel_state import get_ep_group
+ from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
+ import vllm
from vllm.attention import AttentionMetadata
- from vllm.distributed import (get_tensor_model_parallel_world_size,
-                               get_tp_group)
+ from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
-
from vllm.model_executor.layers.quantization import QuantizationConfig
-
- from vllm_ascend.ascend_config import get_ascend_config
- from vllm_ascend.distributed.parallel_state import get_ep_group
- from vllm_ascend.ops.fused_moe import AscendFusedMoE
-
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
- from transformers import PretrainedConfig
- from vllm.model_executor.layers.quantization import QuantizationConfig


class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
@@ -55,19 +47,18 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
            "up_proj",
        ],
        "experts":
-         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
+         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
    }


class AscendQwen3MoeSparseMoeBlock(nn.Module):
-
    top_k: int

    def __init__(
-         self,
-         config: PretrainedConfig,
-         quant_config: Optional[QuantizationConfig] = None,
-         prefix: str = "",
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
@@ -97,7 +88,6 @@ def __init__(
            quant_config=quant_config,
            prefix=f"{prefix}.experts")

-
        self.top_k = config.num_experts_per_tok

        self.dp_size = get_dp_group().world_size
@@ -122,7 +112,7 @@ def forward(
            is_prefill = True
            enable_force_load_balance = True
        else:
-             # is_prefill = attn_metadata.num_prefills > 0 is_prefill or
+             # is_prefill = attn_metadata.num_prefills > 0
            enable_force_load_balance = False
            if hasattr(attn_metadata, 'with_prefill_across_dp'):
                is_prefill = attn_metadata.with_prefill_across_dp
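
For context, the prefill/load-balance flag logic that the last hunk touches can be summarized as the sketch below. This is illustrative only and not part of the diff: the `attn_metadata is None` guard and the helper name `_moe_forward_flags` are assumptions reconstructed from the visible context lines of `AscendQwen3MoeSparseMoeBlock.forward`.

def _moe_forward_flags(attn_metadata, is_prefill: bool) -> tuple[bool, bool]:
    # Illustrative helper mirroring the hunk above; the helper name and the
    # None guard are assumptions, not code taken from this diff.
    if attn_metadata is None:
        # Assumed profile/dummy-run path: no attention metadata is available,
        # so treat the step as prefill and force expert load balancing.
        return True, True
    # is_prefill = attn_metadata.num_prefills > 0   # old heuristic, kept commented out
    enable_force_load_balance = False
    if hasattr(attn_metadata, 'with_prefill_across_dp'):
        # Data-parallel-aware prefill flag, used when the backend provides it.
        is_prefill = attn_metadata.with_prefill_across_dp
    return is_prefill, enable_force_load_balance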