diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py
index 2b155383cc..3866172c1b 100644
--- a/tests/e2e/multicard/test_offline_inference_distributed.py
+++ b/tests/e2e/multicard/test_offline_inference_distributed.py
@@ -167,3 +167,20 @@ def test_models_distributed_topk() -> None:
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
+
+
+def test_models_distributed_Qwen3_W8A8():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download("vllm-ascend/Qwen3-8B-W8A8"),
+            max_model_len=8192,
+            enforce_eager=True,
+            dtype="auto",
+            tensor_parallel_size=4,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
index cae779ce2f..ec0bccbc16 100644
--- a/vllm_ascend/models/__init__.py
+++ b/vllm_ascend/models/__init__.py
@@ -11,6 +11,7 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401
 
     ModelRegistry.register_model(
         "DeepSeekMTPModel",
@@ -53,6 +54,9 @@ def register_model():
         "Qwen3MoeForCausalLM",
         "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
 
+    ModelRegistry.register_model(
+        "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
+
     ModelRegistry.register_model(
         "PanguProMoEForCausalLM",
-        "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
\ No newline at end of file
+        "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
diff --git a/vllm_ascend/models/qwen3.py b/vllm_ascend/models/qwen3.py
new file mode 100644
index 0000000000..a05106f228
--- /dev/null
+++ b/vllm_ascend/models/qwen3.py
@@ -0,0 +1,156 @@
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from transformers import Qwen3Config
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.qwen2 import Qwen2Model
+from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
+from vllm.model_executor.models.utils import (AutoWeightsLoader,
+                                              PPMissingLayer, maybe_prefix)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
+
+
+class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        if quant_config is None:
+            return
+
+        from vllm_ascend.quantization.quant_config import AscendQuantConfig
+        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+
+        assert isinstance(quant_config, AscendQuantConfig), \
+            "Expected quant_config to be an instance of AscendQuantConfig"
+
+        if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
+                      AscendW8A8LinearMethod):
+            self.input_layernorm = AddRMSNormW8A8Quant(
+                config.hidden_size,
+                layer=self.self_attn.qkv_proj,
+                eps=config.rms_norm_eps)
+        if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
+                      AscendW8A8LinearMethod):
+            self.post_attention_layernorm = AddRMSNormW8A8Quant(
+                config.hidden_size,
+                layer=self.mlp.gate_up_proj,
+                eps=config.rms_norm_eps)
+
+
+ALL_DECODER_LAYER_TYPES = {
+    "attention": CustomQwen3DecoderLayer,
+}
+
+
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+        # otherwise (seq_len, ).
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+    })
+class CustomQwen3Model(Qwen2Model):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         decoder_layer_type=CustomQwen3DecoderLayer)
+
+
+class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    # add `CustomQwen3Model` to init self.model
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = CustomQwen3Model(vllm_config=vllm_config,
+                                      prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config,
+                                              prefix=maybe_prefix(
+                                                  prefix, "lm_head"))
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
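
Note: registering "Qwen3ForCausalLM" to CustomQwen3ForCausalLM means dense Qwen3 checkpoints are now built from CustomQwen3DecoderLayer, which swaps input_layernorm and post_attention_layernorm for the fused AddRMSNormW8A8Quant only when the projection consuming the norm output uses AscendW8A8LinearMethod; unquantized runs keep the stock layers (the early return when quant_config is None). Outside the test harness, the same path can be exercised with vLLM's offline API. A rough sketch mirroring the new e2e test (model id and engine arguments are taken from the test above; this is not a definitive invocation):

    from vllm import LLM, SamplingParams

    # Assumes an Ascend NPU environment with vllm-ascend installed, so that
    # "Qwen3ForCausalLM" resolves to CustomQwen3ForCausalLM at model load.
    llm = LLM(
        model="vllm-ascend/Qwen3-8B-W8A8",
        quantization="ascend",  # select the Ascend W8A8 quantization config
        tensor_parallel_size=4,
        max_model_len=8192,
        enforce_eager=True,
    )
    # temperature=0.0 approximates the greedy decoding used by generate_greedy
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=5))
    print(outputs[0].outputs[0].text)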
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 7b839fe3d0..7506f87d88 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -23,6 +23,43 @@
 from vllm_ascend.utils import is_310p
 
 
+class AddRMSNormW8A8Quant(RMSNorm):
+    # Fuse AddRmsNorm and W8A8 quantization ops together
+
+    def __init__(
+        self,
+        hidden_size: int,
+        layer: torch.nn.Module,
+        eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        self.layer = layer
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        if residual is not None:
+            x, _, residual = torch_npu.npu_add_rms_norm_quant(
+                x,
+                residual,
+                self.weight,
+                self.layer.aclnn_input_scale,
+                self.layer.aclnn_input_offset,
+                epsilon=self.variance_epsilon)
+            return x, residual
+
+        x, residual = torch_npu.npu_rms_norm(x, self.weight,
+                                             self.variance_epsilon)
+        return x
+
+
 def forward_oot(
     self,
     x: torch.Tensor,
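
The fused forward above relies on torch_npu.npu_add_rms_norm_quant to perform the residual add, the RMSNorm and the per-tensor int8 quantization in a single kernel, emitting int8 activations that the following quantized matmul consumes directly; the self.layer reference exists only to read that linear layer's aclnn_input_scale / aclnn_input_offset. A rough pure-PyTorch reference of the data flow only; the real kernel defines the exact scale convention (scale vs. reciprocal) and rounding mode, which are not reproduced here:

    import torch

    def add_rms_norm_quant_ref(x, residual, weight, scale, offset, epsilon=1e-6):
        # fused residual add
        added = x + residual
        # RMSNorm over the last dimension
        var = added.float().pow(2).mean(dim=-1, keepdim=True)
        normed = added.float() * torch.rsqrt(var + epsilon) * weight.float()
        # per-tensor quantization to int8, so the following quantized matmul
        # can skip the separate quant_per_tensor() step in apply()
        quantized = torch.round(normed / scale + offset).clamp(-128, 127)
        return quantized.to(torch.int8), added.to(x.dtype)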
diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index edd42e53bf..ea4e8910c6 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -91,10 +91,12 @@ def apply(
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        if original_dtype != torch.int8:
-            x = quant_per_tensor(x, layer.aclnn_input_scale,
-                                 layer.aclnn_input_offset)
+        if x.dtype != torch.int8:
+            x = quant_per_tensor(
+                x,
+                layer.aclnn_input_scale_reciprocal,
+                layer.aclnn_input_offset,
+            )
         quant_bias = layer.quant_bias if tp_rank == 0 else None
         if is_310p():
             # On 300I Duo platform, we need transpose again if
@@ -104,7 +106,7 @@
                 layer.weight.data.transpose(1, 0),
                 layer.deq_scale,
                 bias=quant_bias,
-                output_dtype=original_dtype,
+                output_dtype=layer.params_dtype,
             )
         else:
             output = torch_npu.npu_quant_matmul(
@@ -112,13 +114,16 @@
                 layer.weight,
                 layer.deq_scale,
                 bias=quant_bias,
-                output_dtype=original_dtype,
+                output_dtype=layer.params_dtype,
             )
         return output
 
     def process_weights_after_loading(self, layer):
         expanding_factor = layer.weight.data.shape[1]
-        layer.aclnn_input_scale = 1 / torch.nn.Parameter(
+        layer.aclnn_input_scale = torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor),
+            requires_grad=False)
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
             layer.input_scale.data.repeat(expanding_factor),
             requires_grad=False)
         layer.aclnn_input_offset = torch.nn.Parameter(
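
process_weights_after_loading now materializes both forms of the per-tensor activation scale: the plain scale (aclnn_input_scale), read by the fused npu_add_rms_norm_quant call in AddRMSNormW8A8Quant, and its reciprocal (aclnn_input_scale_reciprocal), used by quant_per_tensor when activations still arrive in floating point. Because activations can now reach apply() already quantized to int8, x.dtype no longer encodes the original float dtype, so the matmul output dtype comes from layer.params_dtype instead. A small illustration of the two tensors and their consumers (values are made up; input_scale stands in for the per-tensor scale loaded from the quantized checkpoint):

    import torch

    input_scale = torch.tensor([0.02])   # per-tensor activation scale (made-up value)
    expanding_factor = 4                 # stands in for layer.weight.data.shape[1]

    # read by torch_npu.npu_add_rms_norm_quant inside AddRMSNormW8A8Quant
    aclnn_input_scale = input_scale.repeat(expanding_factor)

    # read by quant_per_tensor() in apply() when x is still bf16/fp16
    aclnn_input_scale_reciprocal = 1 / input_scale.repeat(expanding_factor)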