[V0.9.1] Use AddRmsNormQuant ops in the custom model to optimize Qwen3's performance #1545

Merged: 1 commit merged on Jul 9, 2025
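In rough terms, the change works like this (a schematic sketch, not code taken from the diff below; function and variable names are placeholders): for W8A8-quantized Qwen3 checkpoints, each decoder layer previously ran a fused residual-add + RMSNorm kernel and then quantized the activations separately inside the following linear layer. The custom model registered here swaps those layer norms for AddRMSNormQuant, which folds the quantization into the same kernel.

import torch_npu  # Ascend PyTorch adapter (PTA); all tensors live on an NPU

def norm_unfused(x, residual, norm_weight, eps):
    # Previous path (schematic): add + RMSNorm in one kernel; quantization of
    # the projection input happens later, inside the W8A8 linear layer.
    x, _, residual = torch_npu.npu_add_rms_norm(x, residual, norm_weight, eps)
    return x, residual  # still floating point here

def norm_fused_quant(x, residual, norm_weight, eps, input_scale, input_offset):
    # Path introduced by this PR (schematic): add + RMSNorm + int8 quantization
    # in a single kernel, so the following qkv_proj / gate_up_proj can consume
    # already-quantized activations.
    x_int8, _, residual = torch_npu.npu_add_rms_norm_quant(
        x, residual, norm_weight, input_scale, input_offset, epsilon=eps)
    return x_int8, residual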
4 changes: 4 additions & 0 deletions vllm_ascend/models/__init__.py
@@ -11,6 +11,7 @@ def register_model():
from .qwen2_5_vl import \
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
from .qwen3 import CustomQwen3ForCausalLM # noqa: F401

ModelRegistry.register_model(
"DeepSeekMTPModel",
@@ -52,3 +53,6 @@ def register_model():
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")

ModelRegistry.register_model(
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
156 changes: 156 additions & 0 deletions vllm_ascend/models/qwen3.py
@@ -0,0 +1,156 @@
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.layernorm import AddRMSNormQuant


class CustomQwen3DecoderLayer(Qwen3DecoderLayer):

def __init__(
self,
config: Qwen3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix)
if quant_config is None:
return

from vllm_ascend.quantization.quant_config import AscendQuantConfig
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod

assert isinstance(quant_config, AscendQuantConfig), \
"Expected quant_config to be an instance of AscendQuantConfig"

# When the attention/MLP projections use the Ascend W8A8 linear method, swap
# the plain RMSNorm for AddRMSNormQuant so the residual add, RMS normalization
# and input quantization for the next projection run in one fused NPU kernel.
if isinstance(self.self_attn.qkv_proj.quant_method,
              AscendW8A8LinearMethod):
self.input_layernorm = AddRMSNormQuant(
config.hidden_size,
layer=self.self_attn.qkv_proj,
eps=config.rms_norm_eps)
if isinstance(self.mlp.gate_up_proj.quant_method,
AscendW8A8LinearMethod):
self.post_attention_layernorm = AddRMSNormQuant(
config.hidden_size,
layer=self.mlp.gate_up_proj,
eps=config.rms_norm_eps)


ALL_DECODER_LAYER_TYPES = {
"attention": CustomQwen3DecoderLayer,
}


@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class CustomQwen3Model(Qwen2Model):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=CustomQwen3DecoderLayer)


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# uses `CustomQwen3Model` (above) to build self.model with CustomQwen3DecoderLayer
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.model = CustomQwen3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()

self.logits_processor = LogitsProcessor(config.vocab_size)

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states

def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
41 changes: 41 additions & 0 deletions vllm_ascend/ops/layernorm.py
@@ -21,6 +21,47 @@
from vllm.model_executor.layers.layernorm import RMSNorm


class AddRMSNormQuant(RMSNorm):
"""Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""

def __init__(
self,
hidden_size: int,
layer: torch.nn.Module,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
self.layer = layer

def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
import torch_npu

if residual is not None:
x, _, residual = torch_npu.npu_add_rms_norm_quant(
                x,
                residual,
                self.weight,
                self.layer.aclnn_input_scale,
                self.layer.aclnn_input_offset,
                epsilon=self.variance_epsilon)
return x, residual

x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                     self.variance_epsilon)
return x

Review comment on the npu_add_rms_norm_quant call above: does torch_npu.npu_add_rms_norm_quant require a newer version of torch_npu?

Author reply: the current version of PTA (torch_npu) already supports it.
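On the version question above: the PR relies on the installed PTA already exposing the fused kernel. If one wanted to stay compatible with older torch_npu builds, a defensive pattern (purely a sketch, not part of this PR) would be to probe for the op and fall back to the unfused add + norm, leaving quantization to the linear layer as before:

import torch_npu

_HAS_FUSED_KERNEL = hasattr(torch_npu, "npu_add_rms_norm_quant")

def fused_or_fallback(norm, layer, x, residual):
    # Sketch only: use the fused add+norm+quant kernel when available,
    # otherwise do add + RMSNorm and let the W8A8 linear layer quantize
    # its own input as it did before this PR.
    if _HAS_FUSED_KERNEL:
        x, _, residual = torch_npu.npu_add_rms_norm_quant(
            x, residual, norm.weight,
            layer.aclnn_input_scale, layer.aclnn_input_offset,
            epsilon=norm.variance_epsilon)
    else:
        x, _, residual = torch_npu.npu_add_rms_norm(x, residual, norm.weight,
                                                    norm.variance_epsilon)
    return x, residual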


def forward_oot(
self,
x: torch.Tensor,
12 changes: 7 additions & 5 deletions vllm_ascend/quantization/w8a8.py
@@ -93,11 +93,10 @@ def apply(
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
- original_dtype = x.dtype
- if original_dtype != torch.int8:
+ if x.dtype != torch.int8:
x = quant_per_tensor(
x,
-                layer.aclnn_input_scale,
+                layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
@@ -106,12 +105,15 @@
layer.weight,
layer.deq_scale,
bias=quant_bias,
-            output_dtype=original_dtype,
+            output_dtype=layer.params_dtype,
)

def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1]
- layer.aclnn_input_scale = 1 / torch.nn.Parameter(
+ layer.aclnn_input_scale = torch.nn.Parameter(
+     layer.input_scale.data.repeat(expanding_factor),
+     requires_grad=False)
+ layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.aclnn_input_offset = torch.nn.Parameter(
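A side note on the scale handling in this hunk: quant_per_tensor now receives aclnn_input_scale_reciprocal, while the fused npu_add_rms_norm_quant call in layernorm.py receives aclnn_input_scale, presumably because one kernel multiplies by its scale argument and the other divides by it. The toy snippet below (made-up numbers, not from any checkpoint) only illustrates that the two parameterizations yield the same quantized values:

import torch

input_scale = torch.tensor(0.05)        # hypothetical per-tensor scale
offset = torch.tensor(0.0)              # hypothetical zero point
x = torch.tensor([0.10, -0.20, 0.30])   # hypothetical activations

q_via_reciprocal = torch.round(x * (1.0 / input_scale) + offset)  # multiply by 1/scale
q_via_scale = torch.round(x / input_scale + offset)               # divide by scale

assert torch.equal(q_via_reciprocal, q_via_scale)  # same quantized codes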