
Commit bcbc024 (parent: bf25498)

[main] Use AddRmsNormQuant ops in the custom model to optimize Qwen3's performance

Signed-off-by: rjg-lyh <1318825571@qq.com>

5 files changed: +227 −8 lines

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 17 additions & 0 deletions
@@ -167,3 +167,20 @@ def test_models_distributed_topk() -> None:
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
+
+
+def test_models_distributed_Qwen3_W8A8():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download("vllm-ascend/Qwen3-8B-W8A8"),
+            max_model_len=8192,
+            enforce_eager=True,
+            dtype="auto",
+            tensor_parallel_size=4,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
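
For context, the new test drives the quantized checkpoint through vllm-ascend's VllmRunner test helper. A rough stand-alone equivalent using vLLM's offline API is sketched below; it assumes vllm.LLM accepts the same keyword arguments the test passes to VllmRunner and that the checkpoint is reachable under the same name.

from vllm import LLM, SamplingParams

# Greedy decoding mirrors generate_greedy(max_tokens=5) in the test above.
llm = LLM(model="vllm-ascend/Qwen3-8B-W8A8",
          max_model_len=8192,
          enforce_eager=True,
          dtype="auto",
          tensor_parallel_size=4,
          quantization="ascend")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)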

vllm_ascend/models/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -11,6 +11,7 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401
 
     ModelRegistry.register_model(
         "DeepSeekMTPModel",
@@ -53,6 +54,9 @@ def register_model():
         "Qwen3MoeForCausalLM",
         "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
 
+    ModelRegistry.register_model(
+        "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
+
     ModelRegistry.register_model(
         "PanguProMoEForCausalLM",
-        "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
+        "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")

vllm_ascend/models/qwen3.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@ (new file)
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


class CustomQwen3DecoderLayer(Qwen3DecoderLayer):

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        if quant_config is None:
            return

        from vllm_ascend.quantization.quant_config import AscendQuantConfig
        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod

        assert isinstance(quant_config, AscendQuantConfig), \
            "Expected quant_config to be an instance of AscendQuantConfig"

        if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
                      AscendW8A8LinearMethod):
            self.input_layernorm = AddRMSNormW8A8Quant(
                config.hidden_size,
                layer=self.self_attn.qkv_proj,
                eps=config.rms_norm_eps)
        if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
                      AscendW8A8LinearMethod):
            self.post_attention_layernorm = AddRMSNormW8A8Quant(
                config.hidden_size,
                layer=self.mlp.gate_up_proj,
                eps=config.rms_norm_eps)


ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
}


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # add `CustomQwen3Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
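
Note that the norm swap only happens when the consuming projection (qkv_proj or gate_up_proj) is quantized with AscendW8A8LinearMethod; otherwise the layer keeps the stock RMSNorm from the upstream Qwen3DecoderLayer. A quick sanity check is sketched below (an assumption-laden helper, not part of this commit; it presumes you can reach the instantiated torch module):

from torch import nn
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant

def count_fused_norms(model: nn.Module) -> int:
    # For a W8A8 checkpoint, two fused norms per decoder layer are expected.
    return sum(isinstance(m, AddRMSNormW8A8Quant) for m in model.modules())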

vllm_ascend/ops/layernorm.py

Lines changed: 37 additions & 0 deletions
@@ -23,6 +23,43 @@
 from vllm_ascend.utils import is_310p
 
 
+class AddRMSNormW8A8Quant(RMSNorm):
+    # Fuse AddRmsNorm and W8A8 quantization ops together
+
+    def __init__(
+        self,
+        hidden_size: int,
+        layer: torch.nn.Module,
+        eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        self.layer = layer
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        if residual is not None:
+            x, _, residual = torch_npu.npu_add_rms_norm_quant(
+                x,
+                residual,
+                self.weight,
+                self.layer.aclnn_input_scale,
+                self.layer.aclnn_input_offset,
+                epsilon=self.variance_epsilon)
+            return x, residual
+
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
+        return x
+
+
 def forward_oot(
     self,
     x: torch.Tensor,
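
As a reading aid, the snippet below spells out what the fused call is expected to compute, based on the unfused path it replaces (residual add, RMSNorm, then per-tensor int8 quantization with the downstream layer's input scale and offset). It is a plain-PyTorch sketch, not the kernel specification; the exact rounding and saturation behaviour belong to torch_npu.

import torch

def add_rms_norm_quant_reference(x, residual, weight, input_scale,
                                 input_offset, epsilon=1e-6):
    # Residual add, then RMSNorm in floating point.
    added = x + residual
    normed = added * torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + epsilon)
    normed = normed * weight
    # Per-tensor quantization to int8 using the consumer layer's parameters
    # (assumed q = round(x / scale + offset); clamping shown for illustration).
    quantized = torch.clamp(torch.round(normed / input_scale + input_offset),
                            -128, 127).to(torch.int8)
    # Mirrors the (int8 activations, new residual) pair returned above.
    return quantized, added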

vllm_ascend/quantization/w8a8.py

Lines changed: 12 additions & 7 deletions
@@ -91,10 +91,12 @@ def apply(
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        if original_dtype != torch.int8:
-            x = quant_per_tensor(x, layer.aclnn_input_scale,
-                                 layer.aclnn_input_offset)
+        if x.dtype != torch.int8:
+            x = quant_per_tensor(
+                x,
+                layer.aclnn_input_scale_reciprocal,
+                layer.aclnn_input_offset,
+            )
         quant_bias = layer.quant_bias if tp_rank == 0 else None
         if is_310p():
             # On 300I Duo platform, we need transpose again if
@@ -104,21 +106,24 @@ def apply(
                 layer.weight.data.transpose(1, 0),
                 layer.deq_scale,
                 bias=quant_bias,
-                output_dtype=original_dtype,
+                output_dtype=layer.params_dtype,
             )
         else:
             output = torch_npu.npu_quant_matmul(
                 x,
                 layer.weight,
                 layer.deq_scale,
                 bias=quant_bias,
-                output_dtype=original_dtype,
+                output_dtype=layer.params_dtype,
             )
         return output
 
     def process_weights_after_loading(self, layer):
         expanding_factor = layer.weight.data.shape[1]
-        layer.aclnn_input_scale = 1 / torch.nn.Parameter(
+        layer.aclnn_input_scale = torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor),
+            requires_grad=False)
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
             layer.input_scale.data.repeat(expanding_factor),
             requires_grad=False)
         layer.aclnn_input_offset = torch.nn.Parameter(
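
Why two copies of the scale now exist (as read from this diff, not an authoritative statement): the fused npu_add_rms_norm_quant call in AddRMSNormW8A8Quant is handed aclnn_input_scale directly, while the unfused quant_per_tensor fallback keeps the reciprocal form it used before this change. The two forms describe the same per-tensor quantization step:

import torch

x = torch.tensor([0.50, -1.25, 2.00])
input_scale = torch.tensor(0.25)
offset = torch.tensor(0.0)

# Dividing by the scale and multiplying by its reciprocal yield the same code.
q_scale = torch.round(x / input_scale + offset)
q_reciprocal = torch.round(x * (1.0 / input_scale) + offset)
assert torch.equal(q_scale, q_reciprocal)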
