Skip to content

[Model][0.7.3] Add support for Qwen3 model #903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: v0.7.3-dev
Choose a base branch
from

Conversation

shen-shanshan
Copy link
Collaborator

@shen-shanshan shen-shanshan commented May 20, 2025

What this PR does / why we need it?

Add support for Qwen3 model for v0.7.3.

NOTE:

The qwen2.py is added for custom Qwen2Model's __init__() method to recieve param of decoder_layer_type. In main branch of vllm, Qwen3DecoderLayer param will pass to this __init__() and so that Qwen3 can reuse the methods in Qwen2 without rewrite the whole Qwen3Model class.

Part of code in qwen2.py of v0.7.3 (can only init with Qwen2DecoderLayer):

class Qwen2Model(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen2DecoderLayer(config=config,
                                             cache_config=cache_config,
                                             quant_config=quant_config,
                                             prefix=prefix),
            prefix=f"{prefix}.layers",
        )

Part of code in qwen2.py of main (can init with Qwen3DecoderLayer, so that Qwen3Model can extend this class):

class Qwen2Model(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer):
        super().__init__()

        # Use the provided decoder layer type or default to Qwen2DecoderLayer
        decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer_type(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers",
        )

Does this PR introduce any user-facing change?

no.

How was this patch tested?

All the scenarios shown below have been tested and passed:

  • vllm-ascend offline inference ✅
  • vllm-ascend online inference ✅
  • vllm-ascend + mindie-turbo offline inference ✅
  • vllm-ascend + mindie-turbo online inference ✅

TODO:

Need performance test @shen-shanshan and accuracy test @hfadzxy .

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
@shen-shanshan
Copy link
Collaborator Author

Comparing with qwen3.py in main of vllm:

diff --git a/vllm_ascend/models/qwen3.py b/vllm_ascend/models/qwen3.py
index dbe2be8..c1f681b 100644
--- a/vllm_ascend/models/qwen3.py
+++ b/vllm_ascend/models/qwen3.py
@@ -1,13 +1,9 @@
-# SPDX-License-Identifier: Apache-2.0
-
-# Copyright 2024 The Qwen team.
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
 # Copyright 2023 The vLLM team.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 The Qwen team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team.
 #
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,15 +16,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only Qwen3 model compatible with HuggingFace weights."""
+# Adapted from vllm/model_executor/models/qwen3.py
+# This file is a part of the vllm-ascend project.
+
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Iterable, List, Optional, Union
 
 import torch
 from torch import nn
 from transformers import Qwen3Config
 
-from vllm.attention import Attention, AttentionType
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -40,13 +38,14 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
-from .qwen2 import Qwen2MLP as Qwen3MLP
-from .qwen2 import Qwen2Model
-from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP
+from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
+from vllm_ascend.models.qwen2 import CustomQwen2Model
 
 logger = init_logger(__name__)
 
@@ -128,6 +127,8 @@ class Qwen3Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -141,7 +142,7 @@ class Qwen3Attention(nn.Module):
         k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -201,6 +202,8 @@ class Qwen3DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -213,6 +216,8 @@ class Qwen3DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
         )
 
         # Fully Connected
@@ -222,11 +227,6 @@ class Qwen3DecoderLayer(nn.Module):
         return hidden_states, residual
 
 
-ALL_DECODER_LAYER_TYPES = {
-    "attention": Qwen3DecoderLayer,
-}
-
-
 @support_torch_compile(
     dynamic_arg_dims={
         "input_ids": 0,
@@ -236,7 +236,7 @@ ALL_DECODER_LAYER_TYPES = {
         "intermediate_tensors": 0,
         "inputs_embeds": 0,
     })
-class Qwen3Model(Qwen2Model):
+class Qwen3Model(CustomQwen2Model):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config,
@@ -284,6 +284,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
 
+        self.sampler = get_sampler()
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -294,10 +296,13 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
 
@@ -318,3 +323,11 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                            if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens

@@ -0,0 +1,207 @@
#
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the qwen2.py here is only used for inherited by qwen3.py? why not merge the content into qwen3.py directlly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"So the qwen2.py here is only used for inherited by qwen3.py?" -- Yes
"why not merge the content into qwen3.py directlly?" -- OK, I will move these codes into qwen3.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rewrote the methods in qwen2.py into qwen3.py directly finally.

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Copy link
Collaborator

@wangxiyuan wangxiyuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make sure the model works as expect before merge.

  • Accuracy
  • Performance with Mindie-turbo

@wangxiyuan wangxiyuan changed the title [Model] Add support for Qwen3 model for v0.7.3 [Model][0.7.3] Add support for Qwen3 model May 20, 2025
@hfadzxy
Copy link
Contributor

hfadzxy commented May 20, 2025

TODO:

Need performance test @shen-shanshan and accuracy test @hfadzxy .

I use gsm8k to test the accuracy find the accuracy of Qwen3-8B is very low:

Task Filter n-shot Metric Value Stderr
gsm8k flexible-extract 5 exact_match ↑ 0.1835 ± 0.0107

but the accuracy of Qwen3-8B-Base is normal:

Task Filter n-shot Metric Value Stderr
gsm8k flexible-extract 5 exact_match ↑ 0.8324 ± 0.0103

@shen-shanshan
Copy link
Collaborator Author

Benchmark results (comparing Qwen3 inference speed without and with mindie-turbo):

Qwen3_v0 7 3_benchmark

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants