from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.layernorm import AddRMSNormQuant
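
# Ascend-customized Qwen3 for vLLM. The classes below reuse the upstream
# Qwen2/Qwen3 implementations; the functional change is that, when an Ascend
# quantization config is present, the decoder-layer norms are replaced with
# AddRMSNormQuant, which ties the residual add, RMSNorm and the activation
# quantization for the following quantized linear layer together in one op.
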
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        if quant_config is not None:
            from vllm_ascend.quantization.quant_config import AscendQuantConfig
            assert isinstance(quant_config, AscendQuantConfig)
            self.input_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.self_attn.qkv_proj,
                eps=config.rms_norm_eps)
            self.post_attention_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.mlp.gate_up_proj,
                eps=config.rms_norm_eps)
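
# For orientation only: a plain-PyTorch sketch of the residual-add + RMSNorm
# portion that AddRMSNormQuant covers. The real op additionally quantizes the
# normalized output using the scale/offset of the quantized `layer` passed in
# above and dispatches to a fused NPU kernel; this reference is not used at
# runtime.
def _reference_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        eps: float) -> tuple[torch.Tensor, torch.Tensor]:
    # Residual add first, then RMSNorm over the hidden dimension.
    residual = x + residual
    variance = residual.pow(2).mean(dim=-1, keepdim=True)
    hidden = residual * torch.rsqrt(variance + eps) * weight
    return hidden, residual
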
ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
}
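
# Roughly: `support_torch_compile` hooks the model into vLLM's torch.compile
# integration, and `dynamic_arg_dims` marks which dimension of each forward
# argument is dynamic (batch/sequence length) so compiled graphs can be reused
# across shapes instead of recompiling.
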
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # Mirrors the upstream Qwen3ForCausalLM, but builds self.model from
    # `CustomQwen3Model` above.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
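    # `packed_modules_mapping` records which per-projection weights from the
    # HF checkpoint are packed into the fused qkv_proj / gate_up_proj modules;
    # vLLM uses this, e.g., for LoRA support on the packed layers.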

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
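
# A minimal usage sketch (module path assumed): on the vllm-ascend side this
# class is typically registered against the stock architecture name so vLLM
# picks it up transparently, e.g.
#
#     from vllm import ModelRegistry
#     ModelRegistry.register_model(
#         "Qwen3ForCausalLM",
#         "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")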