Commit 8e42f71

[0.9.1][BugFix] Fix the failure to recognize the actual type of quantization (#1721)
### What this PR does / why we need it?
Fix the failure to recognize the actual type of quantization in layernorm, which caused the expected fused branch not to be executed.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

Signed-off-by: rjg-lyh <1318825571@qq.com>
Parent: 2ab9b16 · Commit: 8e42f71
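The root cause, as the diff below shows, is that a quantized linear layer's `quant_method` attribute is an Ascend wrapper object, and the concrete per-layer scheme lives one level deeper in that wrapper's own `quant_method` attribute. The following is a minimal, self-contained sketch of that layout using toy stand-in classes (`AscendLinearMethodWrapper` is a hypothetical name for illustration, not the real vllm-ascend type); it only shows why the original `isinstance` check could never match, so the fused branch was silently skipped.

```python
class AscendW8A8LinearMethod:          # stand-in for the real W8A8 linear scheme
    pass


class AscendLinearMethodWrapper:       # hypothetical stand-in for the outer wrapper
    def __init__(self, scheme):
        self.quant_method = scheme     # the actual scheme lives one level deeper


class QKVProj:                         # stand-in for a quantized linear layer
    def __init__(self):
        self.quant_method = AscendLinearMethodWrapper(AscendW8A8LinearMethod())


qkv_proj = QKVProj()

# Buggy check (pre-PR): compares the wrapper itself, so it never matches
# and the fused AddRMSNorm + quant branch is never taken.
assert not isinstance(qkv_proj.quant_method, AscendW8A8LinearMethod)

# Fixed check (this PR): unwrap one level to reach the real scheme.
assert isinstance(qkv_proj.quant_method.quant_method, AscendW8A8LinearMethod)
```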

File tree

2 files changed (+7 −11 lines):

  vllm_ascend/models/qwen3.py
  vllm_ascend/ops/layernorm.py

vllm_ascend/models/qwen3.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -18,7 +18,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from vllm_ascend.ops.layernorm import AddRMSNormQuant
+from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
 
 
 class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
@@ -43,15 +43,15 @@ def __init__(
         assert isinstance(quant_config, AscendQuantConfig), \
             "Expected quant_config to be an instance of AscendQuantConfig"
 
-        if isinstance(self.self_attn.qkv_proj.quant_method,
+        if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
                       AscendW8A8LinearMethod):
-            self.input_layernorm = AddRMSNormQuant(
+            self.input_layernorm = AddRMSNormW8A8Quant(
                 config.hidden_size,
                 layer=self.self_attn.qkv_proj,
                 eps=config.rms_norm_eps)
-        if isinstance(self.mlp.gate_up_proj.quant_method,
+        if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
                       AscendW8A8LinearMethod):
-            self.post_attention_layernorm = AddRMSNormQuant(
+            self.post_attention_layernorm = AddRMSNormW8A8Quant(
                 config.hidden_size,
                 layer=self.mlp.gate_up_proj,
                 eps=config.rms_norm_eps)
```

vllm_ascend/ops/layernorm.py

Lines changed: 2 additions & 6 deletions
```diff
@@ -21,12 +21,8 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 
-class AddRMSNormQuant(RMSNorm):
-    """Root mean square normalization.
-
-    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
-    Refer to https://arxiv.org/abs/1910.07467
-    """
+class AddRMSNormW8A8Quant(RMSNorm):
+    # Fuse AddRmsNorm and W8A8 quantization ops together
 
     def __init__(
         self,
```

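The renamed class's comment describes it as fusing AddRMSNorm with W8A8 quantization. As a rough illustration of what that fusion computes, here is a conceptual sketch in plain PyTorch (residual add, RMS normalization, then symmetric per-tensor int8 activation quantization). It is not the actual vllm-ascend implementation, which dispatches to a fused Ascend NPU kernel and may handle scales and offsets differently.

```python
# Conceptual sketch only: plain-PyTorch illustration of an AddRMSNorm + W8A8
# activation-quantization fusion. Names and scale handling are assumptions,
# not the real vllm-ascend kernel.
import torch


def add_rms_norm_w8a8(x: torch.Tensor,
                      residual: torch.Tensor,
                      weight: torch.Tensor,
                      input_scale: torch.Tensor,
                      eps: float = 1e-6):
    # 1. Residual add (the "Add" in AddRMSNorm).
    hidden = x + residual
    # 2. RMS normalization: h * w / sqrt(mean(h^2) + eps).
    variance = hidden.pow(2).mean(dim=-1, keepdim=True)
    normed = hidden * torch.rsqrt(variance + eps) * weight
    # 3. Static per-tensor int8 activation quantization (the activation side of W8A8).
    quantized = torch.clamp(torch.round(normed / input_scale), -128, 127).to(torch.int8)
    # Return the int8 activations plus the un-quantized sum for the next residual.
    return quantized, hidden


# Example usage with toy shapes:
x = torch.randn(2, 8)
res = torch.randn(2, 8)
w = torch.ones(8)
scale = torch.tensor(0.05)
q, new_res = add_rms_norm_w8a8(x, res, w, scale)
```

Returning the un-quantized residual sum alongside the int8 activations follows the usual add-RMSNorm pattern, where the next layer's residual connection still needs the full-precision hidden states.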