-
Notifications
You must be signed in to change notification settings - Fork 154
[Model][0.7.3] Add support for Qwen3 model #903
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# | ||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
# Copyright 2023 The vLLM team. | ||
# Copyright 2024 The Qwen team. | ||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. | ||
# | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# Adapted from vllm/model_executor/models/qwen2.py | ||
# This file is a part of the vllm-ascend project. | ||
|
||
from typing import Iterable, List, Optional, Set, Tuple, Union | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from vllm.attention import AttentionMetadata | ||
from vllm.compilation.decorators import support_torch_compile | ||
from vllm.config import VllmConfig | ||
from vllm.distributed import get_pp_group | ||
from vllm.logger import init_logger | ||
from vllm.model_executor.layers.layernorm import RMSNorm | ||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding | ||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader, maybe_remap_kv_scale_name | ||
from vllm.sequence import IntermediateTensors | ||
|
||
from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers,) | ||
from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer | ||
|
||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
@support_torch_compile( | ||
dynamic_arg_dims={ | ||
"input_ids": 0, | ||
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, | ||
# otherwise (seq_len, ). | ||
"positions": -1, | ||
"intermediate_tensors": 0, | ||
"inputs_embeds": 0, | ||
}) | ||
class CustomQwen2Model(nn.Module): | ||
# NOTE: Patch this to recieve `Qwen3DecoderLayer` param, | ||
# so that Qwen3Model can reuse the methods of Qwen2Model. | ||
|
||
def __init__(self, | ||
*, | ||
vllm_config: VllmConfig, | ||
prefix: str = "", | ||
decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): | ||
super().__init__() | ||
|
||
config = vllm_config.model_config.hf_config | ||
cache_config = vllm_config.cache_config | ||
quant_config = vllm_config.quant_config | ||
|
||
# TODO (@robertgshaw2): see if this can be moved out | ||
if (cache_config.sliding_window is not None | ||
and hasattr(config, "max_window_layers")): | ||
assert config.max_window_layers == config.num_hidden_layers, ( | ||
"Sliding window for some but all layers is not supported. " | ||
"This model uses sliding window but `max_window_layers` = {} " | ||
"is less than `num_hidden_layers` = {}. Please open an issue " | ||
"to discuss this feature.".format( | ||
config.max_window_layers, | ||
config.num_hidden_layers, | ||
)) | ||
|
||
self.config = config | ||
self.quant_config = quant_config | ||
self.vocab_size = config.vocab_size | ||
|
||
if get_pp_group().is_first_rank or (config.tie_word_embeddings | ||
and get_pp_group().is_last_rank): | ||
self.embed_tokens = VocabParallelEmbedding( | ||
config.vocab_size, | ||
config.hidden_size, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.embed_tokens", | ||
) | ||
else: | ||
self.embed_tokens = PPMissingLayer() | ||
|
||
# Use the provided decoder layer type or default to Qwen2DecoderLayer | ||
decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer | ||
self.start_layer, self.end_layer, self.layers = make_layers( | ||
config.num_hidden_layers, | ||
lambda prefix: decoder_layer_type(config=config, | ||
cache_config=cache_config, | ||
quant_config=quant_config, | ||
prefix=prefix), | ||
prefix=f"{prefix}.layers", | ||
) | ||
|
||
self.make_empty_intermediate_tensors = ( | ||
make_empty_intermediate_tensors_factory( | ||
["hidden_states", "residual"], config.hidden_size)) | ||
if get_pp_group().is_last_rank: | ||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
else: | ||
self.norm = PPMissingLayer() | ||
|
||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: | ||
return self.embed_tokens(input_ids) | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
kv_caches: List[torch.Tensor], | ||
attn_metadata: AttentionMetadata, | ||
intermediate_tensors: Optional[IntermediateTensors] = None, | ||
inputs_embeds: Optional[torch.Tensor] = None, | ||
) -> Union[torch.Tensor, IntermediateTensors]: | ||
if get_pp_group().is_first_rank: | ||
if inputs_embeds is not None: | ||
hidden_states = inputs_embeds | ||
else: | ||
hidden_states = self.get_input_embeddings(input_ids) | ||
residual = None | ||
else: | ||
assert intermediate_tensors is not None | ||
hidden_states = intermediate_tensors["hidden_states"] | ||
residual = intermediate_tensors["residual"] | ||
for i in range(self.start_layer, self.end_layer): | ||
layer = self.layers[i] | ||
hidden_states, residual = layer( | ||
positions, | ||
hidden_states, | ||
kv_caches[i - self.start_layer], | ||
attn_metadata, | ||
residual, | ||
) | ||
if not get_pp_group().is_last_rank: | ||
return IntermediateTensors({ | ||
"hidden_states": hidden_states, | ||
"residual": residual | ||
}) | ||
hidden_states, _ = self.norm(hidden_states, residual) | ||
return hidden_states | ||
|
||
def load_weights(self, weights: Iterable[Tuple[str, | ||
torch.Tensor]]) -> Set[str]: | ||
stacked_params_mapping = [ | ||
# (param_name, shard_name, shard_id) | ||
("qkv_proj", "q_proj", "q"), | ||
("qkv_proj", "k_proj", "k"), | ||
("qkv_proj", "v_proj", "v"), | ||
("gate_up_proj", "gate_proj", 0), | ||
("gate_up_proj", "up_proj", 1), | ||
] | ||
params_dict = dict(self.named_parameters(remove_duplicate=False)) | ||
loaded_params: Set[str] = set() | ||
for name, loaded_weight in weights: | ||
if "rotary_emb.inv_freq" in name: | ||
continue | ||
if (self.quant_config is not None and | ||
(scale_name := self.quant_config.get_cache_scale(name))): | ||
# Loading kv cache quantization scales | ||
param = params_dict[scale_name] | ||
weight_loader = getattr(param, "weight_loader", | ||
default_weight_loader) | ||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else | ||
loaded_weight[0]) | ||
weight_loader(param, loaded_weight) | ||
loaded_params.add(scale_name) | ||
continue | ||
for (param_name, weight_name, shard_id) in stacked_params_mapping: | ||
if weight_name not in name: | ||
continue | ||
name = name.replace(weight_name, param_name) | ||
# Skip loading extra bias for GPTQ models. | ||
if name.endswith(".bias") and name not in params_dict: | ||
continue | ||
if is_pp_missing_parameter(name, self): | ||
continue | ||
param = params_dict[name] | ||
weight_loader = param.weight_loader | ||
weight_loader(param, loaded_weight, shard_id) | ||
break | ||
else: | ||
# Skip loading extra bias for GPTQ models. | ||
if name.endswith(".bias") and name not in params_dict: | ||
continue | ||
# Remapping the name of FP8 kv-scale. | ||
name = maybe_remap_kv_scale_name(name, params_dict) | ||
if name is None: | ||
continue | ||
if is_pp_missing_parameter(name, self): | ||
continue | ||
param = params_dict[name] | ||
weight_loader = getattr(param, "weight_loader", | ||
default_weight_loader) | ||
weight_loader(param, loaded_weight) | ||
loaded_params.add(name) | ||
return loaded_params |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the
qwen2.py
here is only used for inherited byqwen3.py
? why not merge the content into qwen3.py directlly?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"So the qwen2.py here is only used for inherited by qwen3.py?" -- Yes
"why not merge the content into qwen3.py directlly?" -- OK, I will move these codes into
qwen3.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I rewrote the methods in
qwen2.py
intoqwen3.py
directly finally.