from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)

from vllm_ascend.ops.layernorm import AddRMSNormQuant


class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    """Qwen3 decoder layer that, when an Ascend quantization config is
    present, swaps the two RMSNorm layers for AddRMSNormQuant so the residual
    add, normalization and quantization of the following projection's input
    are fused."""

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        if quant_config is not None:
            from vllm_ascend.quantization.quant_config import AscendQuantConfig
            assert isinstance(quant_config, AscendQuantConfig)
            self.input_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.self_attn.qkv_proj,
                eps=config.rms_norm_eps)
            self.post_attention_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.mlp.gate_up_proj,
                eps=config.rms_norm_eps)


ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
}


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):
    """Qwen3 model body: reuses Qwen2Model but builds it with
    CustomQwen3DecoderLayer."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # Same as the upstream Qwen3ForCausalLM, but uses `CustomQwen3Model`
    # to init self.model.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
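

# --- Usage sketch (not part of the original diff) ---------------------------
# A minimal sketch of how a custom class like this is typically exposed to
# vLLM as an out-of-tree model, so that requests for the "Qwen3ForCausalLM"
# architecture resolve to the Ascend-optimized implementation. In vllm-ascend
# this registration lives in the plugin's model-registration module rather
# than in this file, and the module path "vllm_ascend.models.qwen3" below is
# an assumption used only for illustration.
#
#     from vllm import ModelRegistry
#
#     ModelRegistry.register_model(
#         "Qwen3ForCausalLM",  # HF architecture name to override
#         "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")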