Commit 07736ef (1 parent: 9acc082)

[V0.9.1] Use AddRMSNormQuant ops in the custom model to optimize Qwen3's performance

Signed-off-by: rjg-lyh <1318825571@qq.com>
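In short: the commit registers a custom `CustomQwen3ForCausalLM` whose decoder layers replace `input_layernorm` and `post_attention_layernorm` with `AddRMSNormQuant`, a fused residual-add + RMSNorm + quantize op backed by `torch_npu.npu_add_rms_norm_quant`; the W8A8 linear path is then adjusted so activations that arrive already quantized to int8 skip the separate `quant_per_tensor` call.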

File tree: 4 files changed (+198, −5 lines)

vllm_ascend/models/__init__.py (4 additions, 0 deletions)

```diff
@@ -11,6 +11,7 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401
 
     ModelRegistry.register_model(
         "DeepSeekMTPModel",
@@ -52,3 +53,6 @@ def register_model():
     ModelRegistry.register_model(
         "Qwen3MoeForCausalLM",
         "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
```

vllm_ascend/models/qwen3.py (new file, 146 additions)

```python
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.layernorm import AddRMSNormQuant


class CustomQwen3DecoderLayer(Qwen3DecoderLayer):

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        if quant_config is not None:
            from vllm_ascend.quantization.quant_config import AscendQuantConfig
            assert isinstance(quant_config, AscendQuantConfig)
            self.input_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.self_attn.qkv_proj,
                eps=config.rms_norm_eps)
            self.post_attention_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.mlp.gate_up_proj,
                eps=config.rms_norm_eps)


ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
}


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # add `CustomQwen3Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
```
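Note that the swap happens only when an `AscendQuantConfig` is present; unquantized models keep vLLM's stock `RMSNorm`. A quick, hypothetical way to confirm the wiring on a loaded quantized model:

```python
from vllm_ascend.ops.layernorm import AddRMSNormQuant

def check_norm_swap(model) -> None:
    # Hypothetical inspection helper: every decoder layer's input_layernorm
    # and post_attention_layernorm should be an AddRMSNormQuant instance
    # when the model was loaded with Ascend W8A8 quantization.
    for name, module in model.named_modules():
        if name.endswith(("input_layernorm", "post_attention_layernorm")):
            status = "fused" if isinstance(module, AddRMSNormQuant) else "stock"
            print(f"{name}: {type(module).__name__} ({status})")
```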

vllm_ascend/ops/layernorm.py (41 additions, 0 deletions)

```diff
@@ -21,6 +21,47 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 
+class AddRMSNormQuant(RMSNorm):
+    """Root mean square normalization.
+
+    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
+    Refer to https://arxiv.org/abs/1910.07467
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        layer: torch.nn.Module,
+        eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        self.layer = layer
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        if residual is not None:
+            x, _, residual = torch_npu.npu_add_rms_norm_quant(
+                x,
+                residual,
+                self.weight,
+                self.layer.aclnn_input_scale,
+                self.layer.aclnn_input_offset,
+                epsilon=self.variance_epsilon)
+            return x, residual
+
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
+        return x
+
+
 def forward_oot(
     self,
     x: torch.Tensor,
```
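`npu_add_rms_norm_quant` fuses three kernels (residual add, RMSNorm, per-tensor int8 quantization) into one launch, which is where the performance win comes from. Below is a plain-PyTorch reference of the assumed computation; the exact rounding mode and whether the op divides by the supplied scale are assumptions (the commit passes the non-reciprocal scale here and the reciprocal to `quant_per_tensor`, which suggests division):

```python
import torch

def add_rms_norm_quant_ref(x, residual, weight, scale, offset, eps=1e-6):
    """CPU reference for the fused NPU op (a sketch under assumptions).

    Assumes the op adds the residual, applies RMSNorm, then quantizes as
    round(y / scale + offset) clamped to int8, returning the int8 output
    and the updated residual, mirroring how AddRMSNormQuant.forward uses it.
    """
    residual = x + residual                                   # fused residual add
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    y = residual.float() * torch.rsqrt(variance + eps) * weight.float()
    y_q = torch.clamp(torch.round(y / scale + offset), -128, 127).to(torch.int8)
    return y_q, residual
```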

vllm_ascend/quantization/w8a8.py (7 additions, 5 deletions)

```diff
@@ -93,11 +93,10 @@ def apply(
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        if original_dtype != torch.int8:
+        if x.dtype != torch.int8:
             x = quant_per_tensor(
                 x,
-                layer.aclnn_input_scale,
+                layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
             )
         quant_bias = layer.quant_bias if tp_rank == 0 else None
@@ -106,12 +105,15 @@ def apply(
             layer.weight,
             layer.deq_scale,
             bias=quant_bias,
-            output_dtype=original_dtype,
+            output_dtype=layer.params_dtype,
         )
 
     def process_weights_after_loading(self, layer):
         expanding_factor = layer.weight.data.shape[1]
-        layer.aclnn_input_scale = 1 / torch.nn.Parameter(
+        layer.aclnn_input_scale = torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor),
+            requires_grad=False)
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
             layer.input_scale.data.repeat(expanding_factor),
             requires_grad=False)
         layer.aclnn_input_offset = torch.nn.Parameter(
```
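Two details are easy to miss here. First, `output_dtype` switches from `x.dtype` to `layer.params_dtype` because `x` may now arrive already quantized by `AddRMSNormQuant`, in which case `x.dtype` would be `torch.int8` and the wrong output dtype. Second, both the plain scale (consumed by the fused norm op) and its reciprocal (consumed by `quant_per_tensor`) are materialized once after weight loading so neither path pays a per-call expansion. A standalone sketch of that bookkeeping, with hypothetical shapes:

```python
import torch

# Hypothetical shapes: a per-tensor input scale from the checkpoint,
# broadcast once to the weight's inner dimension (expanding_factor).
input_scale = torch.tensor([0.02])
expanding_factor = 4096  # layer.weight.data.shape[1] in the real code

aclnn_input_scale = input_scale.repeat(expanding_factor)
# Passed to npu_add_rms_norm_quant, which (assumed) divides by the scale.
aclnn_input_scale_reciprocal = 1 / aclnn_input_scale
# Passed to quant_per_tensor, which (assumed) multiplies by its scale arg:
x = torch.randn(2, expanding_factor)
x_q = torch.clamp(torch.round(x * aclnn_input_scale_reciprocal),
                  -128, 127).to(torch.int8)  # i.e. round(x / input_scale)
```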
