diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
index d85572b32c..324a31b212 100644
--- a/vllm_ascend/models/__init__.py
+++ b/vllm_ascend/models/__init__.py
@@ -11,6 +11,7 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401
 
     ModelRegistry.register_model(
         "DeepSeekMTPModel",
@@ -52,3 +53,6 @@ def register_model():
     ModelRegistry.register_model(
         "Qwen3MoeForCausalLM",
         "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
diff --git a/vllm_ascend/models/qwen3.py b/vllm_ascend/models/qwen3.py
new file mode 100644
index 0000000000..3026f455f0
--- /dev/null
+++ b/vllm_ascend/models/qwen3.py
@@ -0,0 +1,156 @@
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from transformers import Qwen3Config
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.qwen2 import Qwen2Model
+from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
+from vllm.model_executor.models.utils import (AutoWeightsLoader,
+                                              PPMissingLayer, maybe_prefix)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from vllm_ascend.ops.layernorm import AddRMSNormQuant
+
+
+class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        if quant_config is None:
+            return
+
+        from vllm_ascend.quantization.quant_config import AscendQuantConfig
+        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+
+        assert isinstance(quant_config, AscendQuantConfig), \
+            "Expected quant_config to be an instance of AscendQuantConfig"
+
+        if isinstance(self.self_attn.qkv_proj.quant_method,
+                      AscendW8A8LinearMethod):
+            self.input_layernorm = AddRMSNormQuant(
+                config.hidden_size,
+                layer=self.self_attn.qkv_proj,
+                eps=config.rms_norm_eps)
+        if isinstance(self.mlp.gate_up_proj.quant_method,
+                      AscendW8A8LinearMethod):
+            self.post_attention_layernorm = AddRMSNormQuant(
+                config.hidden_size,
+                layer=self.mlp.gate_up_proj,
+                eps=config.rms_norm_eps)
+
+
+ALL_DECODER_LAYER_TYPES = {
+    "attention": CustomQwen3DecoderLayer,
+}
+
+
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+        # otherwise (seq_len, ).
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+    })
+class CustomQwen3Model(Qwen2Model):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         decoder_layer_type=CustomQwen3DecoderLayer)
+
+
+class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    # add `CustomQwen3Model` to init self.model
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = CustomQwen3Model(vllm_config=vllm_config,
+                                      prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config,
+                                              prefix=maybe_prefix(
+                                                  prefix, "lm_head"))
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 8ff4c559e6..971ab42717 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -21,6 +21,47 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 
+class AddRMSNormQuant(RMSNorm):
+    """Root mean square normalization.
+
+    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
+    Refer to https://arxiv.org/abs/1910.07467
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        layer: torch.nn.Module,
+        eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        self.layer = layer
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        if residual is not None:
+            x, _, residual = torch_npu.npu_add_rms_norm_quant(
+                x,
+                residual,
+                self.weight,
+                self.layer.aclnn_input_scale,
+                self.layer.aclnn_input_offset,
+                epsilon=self.variance_epsilon)
+            return x, residual
+
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
+        return x
+
+
 def forward_oot(
     self,
     x: torch.Tensor,
diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index 28925034c1..9574f50c1f 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -93,11 +93,10 @@ def apply(
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        if original_dtype != torch.int8:
+        if x.dtype != torch.int8:
             x = quant_per_tensor(
                 x,
-                layer.aclnn_input_scale,
+                layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
             )
         quant_bias = layer.quant_bias if tp_rank == 0 else None
@@ -106,12 +105,15 @@ def apply(
             layer.weight,
             layer.deq_scale,
             bias=quant_bias,
-            output_dtype=original_dtype,
+            output_dtype=layer.params_dtype,
         )
 
     def process_weights_after_loading(self, layer):
         expanding_factor = layer.weight.data.shape[1]
-        layer.aclnn_input_scale = 1 / torch.nn.Parameter(
+        layer.aclnn_input_scale = torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor),
+            requires_grad=False)
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
             layer.input_scale.data.repeat(expanding_factor),
             requires_grad=False)
         layer.aclnn_input_offset = torch.nn.Parameter(
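
For reference, below is a minimal pure-PyTorch sketch of the arithmetic that the fused torch_npu.npu_add_rms_norm_quant call above stands in for: residual add, RMSNorm, then per-tensor int8 quantization using the reciprocal input scale that process_weights_after_loading now precomputes. The function name and the exact rounding/clamping convention are illustrative assumptions, not the NPU kernel's actual signature or semantics.

import torch


def add_rms_norm_quant_reference(
        x: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        input_scale_reciprocal: torch.Tensor,
        input_offset: torch.Tensor,
        eps: float = 1e-6) -> tuple[torch.Tensor, torch.Tensor]:
    # Residual add; the updated residual is returned for the next layer.
    residual = x + residual
    # RMSNorm: x -> w * x / sqrt(E[x^2] + eps).
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps) * weight.float()
    # Per-tensor int8 quantization, assuming a round(x / input_scale + offset)
    # convention, i.e. multiplication by the precomputed reciprocal scale.
    quantized = torch.round(normed * input_scale_reciprocal + input_offset)
    quantized = torch.clamp(quantized, -128, 127).to(torch.int8)
    return quantized, residual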