From 3e07ea107dade1a7aa7dc8e319bf59867148a23c Mon Sep 17 00:00:00 2001 From: lizexu Date: Mon, 7 Jul 2025 12:56:59 +0000 Subject: [PATCH 1/4] fix qwen3-8b --- .../model_executor/layers/embeddings.py | 2 ++ fastdeploy/model_executor/models/qwen3.py | 30 ++++++++++++------- fastdeploy/worker/gpu_model_runner.py | 2 +- fastdeploy/worker/worker_process.py | 1 + 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index bc67cb1333..3a750eb5c4 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -122,6 +122,7 @@ def load_state_dict(self, state_dict: Dict[str, Args: state_dict (dict): A dictionary containing the checkpoint weights and biases. """ + if self.tie_word_embeddings: self.word_embeddings.weight.set_value( get_tensor(state_dict[self.prefix + ".weight"]).astype( @@ -131,6 +132,7 @@ def load_state_dict(self, state_dict: Dict[str, get_tensor(state_dict.pop(self.prefix + ".weight")).astype( paddle.get_default_dtype())) + def forward(self, ids_remove_padding=None) -> paddle.Tensor: """ Defines the forward computation of the layer. diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index c1654f4144..a44e67c94c 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -164,7 +164,6 @@ def __init__( self.num_layers = fd_config.model_config.num_layers fd_config.model_config.prefix_name = "model" - fd_config.model_config.tie_word_embeddings = True self.embeddings = VocabParallelEmbedding( fd_config=fd_config, @@ -240,14 +239,22 @@ def __init__(self, fd_config: FDConfig): self.model = Qwen3Model(fd_config=fd_config) self.ori_vocab_size = fd_config.model_config.ori_vocab_size - - self.lm_head = ParallelLMHead( - fd_config=fd_config, - embedding_dim=fd_config.model_config.hidden_size, - num_embeddings=fd_config.model_config.vocab_size, - prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"), - ) self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings + if self.tie_word_embeddings: + self.lm_head = ParallelLMHead( + fd_config=fd_config, + embedding_dim=fd_config.model_config.hidden_size, + num_embeddings=fd_config.model_config.vocab_size, + prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"), + ) + else: + self.lm_head = ParallelLMHead( + fd_config=fd_config, + embedding_dim=fd_config.model_config.hidden_size, + num_embeddings=fd_config.model_config.vocab_size, + prefix="lm_head", + ) + # self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings @classmethod def name(self): @@ -266,9 +273,9 @@ def set_state_dict(self, state_dict): and values are NumPy arrays or PaddlePaddle tensors. 
""" self.model.load_state_dict(state_dict) - if self.tie_word_embeddings: - self.lm_head.out_linear.weight.set_value( - self.model.embeddings.word_embeddings.weight.transpose([1, 0])) + # if self.tie_word_embeddings: + self.lm_head.out_linear.weight.set_value( + self.model.embeddings.word_embeddings.weight.transpose([1, 0])) self.lm_head.load_state_dict(state_dict) def compute_logits(self, hidden_states: paddle.Tensor): @@ -324,6 +331,7 @@ def get_tensor_parallel_split_mappings(num_layers): base_actions = { # Row Linear + "lm_head.weight": partial(fn, is_column=True), "embed_tokens.weight": partial(fn, is_column=False), "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8d6ca79a1b..c8942f3ba2 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -207,7 +207,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]): self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array( request.prompt_token_ids) - # Use chunked prefill if self.parallel_config.enable_chunked_prefill: request.set("chunk_idx", 1) @@ -714,6 +713,7 @@ def initialize_attn_backend(self) -> None: # Get the attention backend attn_cls = get_attention_backend() + attn_backend = attn_cls(self.fd_config, kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index ba7a5541a5..dede0bb774 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -573,6 +573,7 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig: model_config = ModelConfig.from_dict(config) # TODO Set `head_dim` again. Because `ModelConfig` class doesn't support feeding head_dim at all! model_config.head_dim = config["head_dim"] + model_config.tie_word_embeddings = config["tie_word_embeddings"] paddle.set_default_dtype(args.dtype) device_config = DeviceConfig() From bf9c960dedb0e336da4c91fd486dfd9116e7dbed Mon Sep 17 00:00:00 2001 From: lizexu Date: Mon, 7 Jul 2025 13:10:31 +0000 Subject: [PATCH 2/4] fix --- fastdeploy/model_executor/models/qwen3.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index a44e67c94c..17b425185c 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -254,7 +254,6 @@ def __init__(self, fd_config: FDConfig): num_embeddings=fd_config.model_config.vocab_size, prefix="lm_head", ) - # self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings @classmethod def name(self): @@ -273,7 +272,6 @@ def set_state_dict(self, state_dict): and values are NumPy arrays or PaddlePaddle tensors. 
""" self.model.load_state_dict(state_dict) - # if self.tie_word_embeddings: self.lm_head.out_linear.weight.set_value( self.model.embeddings.word_embeddings.weight.transpose([1, 0])) self.lm_head.load_state_dict(state_dict) From bc8d296e2e1e6a8be38245a0a0856e5faff6ce89 Mon Sep 17 00:00:00 2001 From: lizexu Date: Tue, 8 Jul 2025 12:31:48 +0000 Subject: [PATCH 3/4] merge develop --- fastdeploy/demo/offline_demo.py | 19 ++++++----- fastdeploy/envs.py | 2 +- .../layers/attention/attention_selecter.py | 9 +++-- fastdeploy/model_executor/models/qwen3.py | 34 ++++++++++--------- fastdeploy/worker/worker_process.py | 7 ++-- .../Qwen3-MoE/test_Qwen3-MoE_serving.py | 2 +- test/layers/test_append_attention.py | 3 +- 7 files changed, 43 insertions(+), 33 deletions(-) diff --git a/fastdeploy/demo/offline_demo.py b/fastdeploy/demo/offline_demo.py index 856757aa00..7e56799ee6 100644 --- a/fastdeploy/demo/offline_demo.py +++ b/fastdeploy/demo/offline_demo.py @@ -17,13 +17,14 @@ from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.llm import LLM -model_name_or_path = "./models/llama-7b" +model_name_or_path = "/root/.paddlenlp/models/Qwen/Qwen3-8B" +# model_name_or_path = "/home/zexuli/Models/Qwen3-0.6B" -# 超参设置 -sampling_params = SamplingParams(temperature=0.1, max_tokens=30) -llm = LLM(model=model_name_or_path, tensor_parallel_size=1) -output = llm.generate(prompts="who are you?", - use_tqdm=True, - sampling_params=sampling_params) - -print(output) +sampling_params = SamplingParams(temperature=0.1) +llm = LLM(model=model_name_or_path, tensor_parallel_size=2,reasoning_parser="qwen3") +prompt = "北京天安门在哪里?" +messages = [{"role": "user", "content": prompt}] +output = llm.chat([messages], + sampling_params) + +print(output) \ No newline at end of file diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 8ef8a5149c..2410f5083c 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -72,7 +72,7 @@ # Set attention backend. "NATIVE_ATTN", "APPEND_ATTN" # and "MLA_ATTN" can be set currently. "FD_ATTENTION_BACKEND": - lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"), + lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN,NATIVE_ATTN").split(","), # Set sampling class. "base", "air" and "rejection" can be set currently. 
"FD_SAMPLING_CLASS": diff --git a/fastdeploy/model_executor/layers/attention/attention_selecter.py b/fastdeploy/model_executor/layers/attention/attention_selecter.py index 3db03b188e..29c183a155 100644 --- a/fastdeploy/model_executor/layers/attention/attention_selecter.py +++ b/fastdeploy/model_executor/layers/attention/attention_selecter.py @@ -34,7 +34,7 @@ def _get_attn_backend(selected_backend: str) -> object: selected_backend = backend_name_to_enum(selected_backend) attention_cls = current_platform.get_attention_backend_cls( selected_backend) - + print("attention_cls",attention_cls) if not attention_cls: raise ValueError( f"Invalid attention backend for {current_platform.device_name}") @@ -43,5 +43,8 @@ def _get_attn_backend(selected_backend: str) -> object: def get_attention_backend() -> object: """Selects which attention backend.""" - attention_backend = envs.FD_ATTENTION_BACKEND - return _get_attn_backend(attention_backend) + attention_backend,native_attention_backend = envs.FD_ATTENTION_BACKEND + if current_platform.is_cuda(): + return _get_attn_backend(attention_backend) + elif current_platform.is_cpu(): + return _get_attn_backend(native_attention_backend) diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 17b425185c..cfce6f50ab 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -240,19 +240,19 @@ def __init__(self, fd_config: FDConfig): self.ori_vocab_size = fd_config.model_config.ori_vocab_size self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings - if self.tie_word_embeddings: - self.lm_head = ParallelLMHead( - fd_config=fd_config, - embedding_dim=fd_config.model_config.hidden_size, - num_embeddings=fd_config.model_config.vocab_size, - prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"), - ) - else: - self.lm_head = ParallelLMHead( - fd_config=fd_config, - embedding_dim=fd_config.model_config.hidden_size, - num_embeddings=fd_config.model_config.vocab_size, - prefix="lm_head", + # if self.tie_word_embeddings: + # self.lm_head = ParallelLMHead( + # fd_config=fd_config, + # embedding_dim=fd_config.model_config.hidden_size, + # num_embeddings=fd_config.model_config.vocab_size, + # prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"), + # ) + # else: + self.lm_head = ParallelLMHead( + fd_config=fd_config, + embedding_dim=fd_config.model_config.hidden_size, + num_embeddings=fd_config.model_config.vocab_size, + prefix="lm_head", ) @classmethod @@ -272,9 +272,11 @@ def set_state_dict(self, state_dict): and values are NumPy arrays or PaddlePaddle tensors. 
""" self.model.load_state_dict(state_dict) - self.lm_head.out_linear.weight.set_value( - self.model.embeddings.word_embeddings.weight.transpose([1, 0])) - self.lm_head.load_state_dict(state_dict) + if self.tie_word_embeddings: + self.lm_head.out_linear.weight.set_value( + self.model.embeddings.word_embeddings.weight.transpose([1, 0])) + else: + self.lm_head.load_state_dict(state_dict) def compute_logits(self, hidden_states: paddle.Tensor): """ diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index e0c41f4aab..a2e81ff1f3 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -574,23 +574,26 @@ def initialize_fd_config(config_or_args) -> FDConfig: # Get model config from model directory model_config_dict, _ = ModelConfig.get_config_dict(config_or_args.model_name_or_path) + + # Handle MoE related configs if 'num_experts' in model_config_dict: model_config_dict['moe_num_experts'] = model_config_dict.pop('num_experts') if 'num_experts_per_tok' in model_config_dict: model_config_dict['moe_topk'] = model_config_dict.pop('num_experts_per_tok') + # Set default values for model config model_config_dict["head_dim"] = model_config_dict.get( "head_dim", model_config_dict["hidden_size"] // model_config_dict["num_attention_heads"]) model_config_dict["rope_theta"] = model_config_dict.get("rope_theta", 10000.0) - if 'tie_word_embeddings' in model_config_dict: - model_config_dict['tie_word_embeddings'] = model_config_dict.pop('tie_word_embeddings') # Create model config object model_config = ModelConfig.from_dict(model_config_dict) model_config.head_dim = model_config_dict["head_dim"] paddle.set_default_dtype(config_or_args.dtype) + if 'tie_word_embeddings' in model_config_dict: + model_config_dict['tie_word_embeddings'] = model_config_dict.pop('tie_word_embeddings') # Initialize all config components device_config = DeviceConfig() diff --git a/test/ci_use/Qwen3-MoE/test_Qwen3-MoE_serving.py b/test/ci_use/Qwen3-MoE/test_Qwen3-MoE_serving.py index 092b1282f3..704884ee6c 100644 --- a/test/ci_use/Qwen3-MoE/test_Qwen3-MoE_serving.py +++ b/test/ci_use/Qwen3-MoE/test_Qwen3-MoE_serving.py @@ -23,7 +23,7 @@ # Read ports from environment variables; use default values if not set -FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) +FD_API_PORT = int(os.getenv("FD_API_PORT", 8781)) FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) diff --git a/test/layers/test_append_attention.py b/test/layers/test_append_attention.py index 2b23566efb..abc05dbaeb 100644 --- a/test/layers/test_append_attention.py +++ b/test/layers/test_append_attention.py @@ -80,6 +80,7 @@ def _apply_rope(self, rotary_emb, q, k, v=None, causal=False): # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] if self.use_neox_rotary_style: + print("use_neox_rotary_style也有?") sin_pos = sin cos_pos = cos # NeoX Stype:前后半部分分块旋转 @@ -92,7 +93,7 @@ def _apply_rope(self, rotary_emb, q, k, v=None, causal=False): paddle.shape(k), ) else: - # import pdb;pdb.set_trace() + print("跑的这里嘛") sin_pos = paddle.reshape(paddle.stack( [sin, sin], axis=-1), [1, 1, seq, head_dim]) # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] From 79d2fbc2d235fc9de9b6c0007c8555fc7c11265a Mon Sep 17 00:00:00 2001 From: lizexu Date: Wed, 9 Jul 2025 13:12:09 +0000 Subject: [PATCH 4/4] fix --- fastdeploy/worker/worker_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index fef56089ba..c4cf13cab7 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -596,7 +596,7 @@ def initialize_fd_config(config_or_args) -> FDConfig: model_config.head_dim = model_config_dict["head_dim"] paddle.set_default_dtype(config_or_args.dtype) if 'tie_word_embeddings' in model_config_dict: - model_config_dict['tie_word_embeddings'] = model_config_dict.pop('tie_word_embeddings') + model_config.tie_word_embeddings = model_config_dict['tie_word_embeddings'] # Initialize all config components device_config = DeviceConfig()
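
Net effect of the series, for readers following the four diffs above: whether a checkpoint shares its input embedding matrix with the output projection is controlled by the tie_word_embeddings flag in the model config. After these patches the Qwen3 loader reads that flag from the config dict (worker_process.py), always builds the LM head under the "lm_head" prefix, and at weight-loading time either transposes the embedding weight into the head or loads the checkpoint's own lm_head.weight (qwen3.py, set_state_dict). The sketch below restates that logic with plain paddle.nn layers so it can be run standalone; the class name TinyCausalLM, the method name load_checkpoint and the flat state-dict keys are illustrative assumptions, not FastDeploy's VocabParallelEmbedding / ParallelLMHead API.

import paddle
import paddle.nn as nn


class TinyCausalLM(nn.Layer):
    """Minimal stand-in for the embedding / LM-head wiring patched above."""

    def __init__(self, vocab_size, hidden_size, tie_word_embeddings):
        super().__init__()
        self.tie_word_embeddings = tie_word_embeddings
        # Input embedding weight layout: [vocab_size, hidden_size]
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # Output projection weight layout: [hidden_size, vocab_size]
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias_attr=False)

    def load_checkpoint(self, state_dict):
        self.embed_tokens.weight.set_value(state_dict["embed_tokens.weight"])
        if self.tie_word_embeddings:
            # Tied case: share the embedding matrix, transposed to the linear
            # layer's [hidden_size, vocab_size] layout. This mirrors the
            # transpose([1, 0]) kept in the patched set_state_dict.
            self.lm_head.weight.set_value(
                self.embed_tokens.weight.transpose([1, 0]))
        else:
            # Untied case: the checkpoint ships a separate lm_head.weight.
            self.lm_head.weight.set_value(state_dict["lm_head.weight"])


# Usage: the flag comes straight from the HuggingFace-style config dict,
# mirroring the worker_process.py change in patch 4.
config = {"vocab_size": 8, "hidden_size": 4, "tie_word_embeddings": True}
model = TinyCausalLM(config["vocab_size"], config["hidden_size"],
                     config.get("tie_word_embeddings", False))
state = {
    "embed_tokens.weight": paddle.randn([config["vocab_size"],
                                         config["hidden_size"]]),
}
model.load_checkpoint(state)

The transpose([1, 0]) follows from Paddle's weight layouts: nn.Embedding stores [vocab_size, hidden_size] while nn.Linear stores [in_features, out_features], i.e. [hidden_size, vocab_size] for the output projection.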
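Separately, patch 3 changes attention-backend selection: FD_ATTENTION_BACKEND now defaults to the comma-separated pair "APPEND_ATTN,NATIVE_ATTN", and get_attention_backend() returns the first entry on CUDA and the second on CPU. A minimal sketch of that dispatch follows; the _Platform stub stands in for fastdeploy.platforms.current_platform and only models the two predicates the selector uses, and get_attention_backend_name() returns the backend name rather than a backend class.

import os


class _Platform:
    """Illustrative stand-in for current_platform."""

    def __init__(self, device):
        self.device = device

    def is_cuda(self):
        return self.device == "cuda"

    def is_cpu(self):
        return self.device == "cpu"


current_platform = _Platform("cpu")


def get_attention_backend_name():
    # "APPEND_ATTN,NATIVE_ATTN": CUDA backend first, CPU fallback second,
    # mirroring the envs.py change in patch 3.
    cuda_backend, native_backend = os.getenv(
        "FD_ATTENTION_BACKEND", "APPEND_ATTN,NATIVE_ATTN").split(",")
    if current_platform.is_cuda():
        return cuda_backend
    return native_backend


print(get_attention_backend_name())  # NATIVE_ATTN on the stub CPU platform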