Skip to content

Commit 3384ed1

Browse files
committed
[V0.9.1] Use AddRmsNormQuant ops in the custom model to optimize Qwen3's
performance Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 9acc082 commit 3384ed1

File tree

4 files changed

+194
-5
lines changed

4 files changed

+194
-5
lines changed

vllm_ascend/models/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def register_model():
1111
from .qwen2_5_vl import \
1212
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
1313
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
14+
from .qwen3 import CustomQwen3ForCausalLM # noqa: F401
1415

1516
ModelRegistry.register_model(
1617
"DeepSeekMTPModel",
@@ -52,3 +53,7 @@ def register_model():
5253
ModelRegistry.register_model(
5354
"Qwen3MoeForCausalLM",
5455
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
56+
57+
ModelRegistry.register_model(
58+
"Qwen3ForCausalLM",
59+
"vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")

vllm_ascend/models/qwen3.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from collections.abc import Iterable
2+
from typing import Optional, Union
3+
4+
import torch
5+
from torch import nn
6+
from transformers import Qwen3Config
7+
from vllm.compilation.decorators import support_torch_compile
8+
from vllm.config import CacheConfig, VllmConfig
9+
from vllm.distributed import get_pp_group
10+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
11+
from vllm.model_executor.layers.quantization import QuantizationConfig
12+
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
13+
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
14+
from vllm.model_executor.models.qwen2 import Qwen2Model
15+
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
16+
from vllm.model_executor.models.utils import (AutoWeightsLoader,
17+
PPMissingLayer, maybe_prefix)
18+
from vllm.model_executor.sampling_metadata import SamplingMetadata
19+
from vllm.sequence import IntermediateTensors
20+
from vllm_ascend.ops.layernorm import AddRMSNormQuant
21+
22+
23+
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
24+
25+
def __init__(
26+
self,
27+
config: Qwen3Config,
28+
cache_config: Optional[CacheConfig] = None,
29+
quant_config: Optional[QuantizationConfig] = None,
30+
prefix: str = "",
31+
) -> None:
32+
super().__init__(config=config,
33+
cache_config=cache_config,
34+
quant_config=quant_config,
35+
prefix=prefix)
36+
if quant_config is not None:
37+
from vllm_ascend.quantization.quant_config import AscendQuantConfig
38+
assert isinstance(quant_config, AscendQuantConfig)
39+
self.input_layernorm = AddRMSNormQuant(config.hidden_size,
40+
layer=self.self_attn.qkv_proj,
41+
eps=config.rms_norm_eps)
42+
self.post_attention_layernorm = AddRMSNormQuant(config.hidden_size,
43+
layer=self.mlp.gate_up_proj,
44+
eps=config.rms_norm_eps)
45+
46+
47+
ALL_DECODER_LAYER_TYPES = {
48+
"attention": CustomQwen3DecoderLayer,
49+
}
50+
51+
52+
@support_torch_compile(
53+
dynamic_arg_dims={
54+
"input_ids": 0,
55+
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
56+
# otherwise (seq_len, ).
57+
"positions": -1,
58+
"intermediate_tensors": 0,
59+
"inputs_embeds": 0,
60+
})
61+
class CustomQwen3Model(Qwen2Model):
62+
63+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
64+
super().__init__(vllm_config=vllm_config,
65+
prefix=prefix,
66+
decoder_layer_type=CustomQwen3DecoderLayer)
67+
68+
69+
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
70+
# add `CustomQwen3Model` to init self.model
71+
packed_modules_mapping = {
72+
"qkv_proj": [
73+
"q_proj",
74+
"k_proj",
75+
"v_proj",
76+
],
77+
"gate_up_proj": [
78+
"gate_proj",
79+
"up_proj",
80+
],
81+
}
82+
83+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
84+
super().__init__()
85+
config = vllm_config.model_config.hf_config
86+
quant_config = vllm_config.quant_config
87+
lora_config = vllm_config.lora_config
88+
89+
self.config = config
90+
self.lora_config = lora_config
91+
92+
self.quant_config = quant_config
93+
self.model = CustomQwen3Model(vllm_config=vllm_config,
94+
prefix=maybe_prefix(prefix, "model"))
95+
96+
if get_pp_group().is_last_rank:
97+
if config.tie_word_embeddings:
98+
self.lm_head = self.model.embed_tokens
99+
else:
100+
self.lm_head = ParallelLMHead(config.vocab_size,
101+
config.hidden_size,
102+
quant_config=quant_config,
103+
prefix=maybe_prefix(
104+
prefix, "lm_head"))
105+
else:
106+
self.lm_head = PPMissingLayer()
107+
108+
self.logits_processor = LogitsProcessor(config.vocab_size)
109+
110+
self.make_empty_intermediate_tensors = (
111+
self.model.make_empty_intermediate_tensors)
112+
113+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
114+
return self.model.get_input_embeddings(input_ids)
115+
116+
def forward(
117+
self,
118+
input_ids: torch.Tensor,
119+
positions: torch.Tensor,
120+
intermediate_tensors: Optional[IntermediateTensors] = None,
121+
inputs_embeds: Optional[torch.Tensor] = None,
122+
) -> Union[torch.Tensor, IntermediateTensors]:
123+
hidden_states = self.model(input_ids, positions, intermediate_tensors,
124+
inputs_embeds)
125+
return hidden_states
126+
127+
def compute_logits(
128+
self,
129+
hidden_states: torch.Tensor,
130+
sampling_metadata: SamplingMetadata,
131+
) -> Optional[torch.Tensor]:
132+
logits = self.logits_processor(self.lm_head, hidden_states,
133+
sampling_metadata)
134+
return logits
135+
136+
def load_weights(self, weights: Iterable[tuple[str,
137+
torch.Tensor]]) -> set[str]:
138+
loader = AutoWeightsLoader(
139+
self,
140+
skip_prefixes=(["lm_head."]
141+
if self.config.tie_word_embeddings else None),
142+
)
143+
return loader.load_weights(weights)

vllm_ascend/ops/layernorm.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,45 @@
2121
from vllm.model_executor.layers.layernorm import RMSNorm
2222

2323

24+
class AddRMSNormQuant(RMSNorm):
25+
"""Root mean square normalization.
26+
27+
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
28+
Refer to https://arxiv.org/abs/1910.07467
29+
"""
30+
def __init__(
31+
self,
32+
hidden_size: int,
33+
layer: torch.nn.Module,
34+
eps: float = 1e-6,
35+
var_hidden_size: Optional[int] = None,
36+
has_weight: bool = True,
37+
dtype: Optional[torch.dtype] = None,
38+
) -> None:
39+
super().__init__(hidden_size,
40+
eps,
41+
var_hidden_size,
42+
has_weight,
43+
dtype)
44+
self.layer = layer
45+
46+
def forward(
47+
self,
48+
x: torch.Tensor,
49+
residual: Optional[torch.Tensor] = None,
50+
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
51+
import torch_npu
52+
53+
if residual is not None:
54+
x, _, residual = torch_npu.npu_add_rms_norm_quant(x, residual, self.weight,
55+
self.layer.aclnn_input_scale,
56+
self.layer.aclnn_input_offset,
57+
epsilon=self.variance_epsilon)
58+
return x, residual
59+
60+
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
61+
return x
62+
2463
def forward_oot(
2564
self,
2665
x: torch.Tensor,

vllm_ascend/quantization/w8a8.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,10 @@ def apply(
9393
bias: Optional[torch.Tensor] = None,
9494
tp_rank: Optional[int] = 0,
9595
) -> torch.Tensor:
96-
original_dtype = x.dtype
97-
if original_dtype != torch.int8:
96+
if x.dtype != torch.int8:
9897
x = quant_per_tensor(
9998
x,
100-
layer.aclnn_input_scale,
99+
layer.aclnn_input_scale_reciprocal,
101100
layer.aclnn_input_offset,
102101
)
103102
quant_bias = layer.quant_bias if tp_rank == 0 else None
@@ -106,12 +105,15 @@ def apply(
106105
layer.weight,
107106
layer.deq_scale,
108107
bias=quant_bias,
109-
output_dtype=original_dtype,
108+
output_dtype=layer.params_dtype,
110109
)
111110

112111
def process_weights_after_loading(self, layer):
113112
expanding_factor = layer.weight.data.shape[1]
114-
layer.aclnn_input_scale = 1 / torch.nn.Parameter(
113+
layer.aclnn_input_scale = torch.nn.Parameter(
114+
layer.input_scale.data.repeat(expanding_factor),
115+
requires_grad=False)
116+
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
115117
layer.input_scale.data.repeat(expanding_factor),
116118
requires_grad=False)
117119
layer.aclnn_input_offset = torch.nn.Parameter(

0 commit comments

Comments
 (0)