diff --git a/custom_ops/gpu_ops/get_output_msg_with_topk.cc b/custom_ops/gpu_ops/get_output_msg_with_topk.cc index c4b6b14a4c..5da88dc1d6 100644 --- a/custom_ops/gpu_ops/get_output_msg_with_topk.cc +++ b/custom_ops/gpu_ops/get_output_msg_with_topk.cc @@ -24,16 +24,18 @@ #endif #define MAX_BSZ 512 -#define K 10 +#define K 20 struct msgdata { long mtype; int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens float mtext_f[MAX_BSZ * (K + 1)]; // score + int mtext_ranks[MAX_BSZ]; // ranks }; void GetOutputTopK(const paddle::Tensor& x, const paddle::Tensor& scores, + const paddle::Tensor& ranks, int k, int64_t rank_id, bool wait_flag) { @@ -66,17 +68,18 @@ void GetOutputTopK(const paddle::Tensor& x, int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>()); float* scores_data = const_cast<float*>(scores.data<float>()); + int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>()); int ret = -1; if (!wait_flag) { ret = msgrcv(msgid, &msg_rcv, - (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4, + (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4, 0, IPC_NOWAIT); } else { ret = msgrcv(msgid, &msg_rcv, - (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4, + (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4, 0, 0); } @@ -97,13 +100,14 @@ void GetOutputTopK(const paddle::Tensor& x, out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2]; scores_data[offset] = msg_rcv.mtext_f[offset]; } + ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i]; } return; } PD_BUILD_STATIC_OP(get_output_topk) - .Inputs({"x", "scores"}) + .Inputs({"x", "scores", "ranks"}) .Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"}) - .Outputs({"x_out", "scores_out"}) - .SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}}) + .Outputs({"x_out", "scores_out", "ranks_out"}) + .SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}, {"ranks", "ranks_out"}}) .SetKernelFn(PD_KERNEL(GetOutputTopK)); diff --git a/custom_ops/gpu_ops/save_output_msg_with_topk.cc b/custom_ops/gpu_ops/save_output_msg_with_topk.cc index ee2cf865d8..fcaaca4dac 100644 --- a/custom_ops/gpu_ops/save_output_msg_with_topk.cc +++ b/custom_ops/gpu_ops/save_output_msg_with_topk.cc @@ -23,34 +23,31 @@ #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif -#define MAX_BSZ 128 -#define K 10 +#define MAX_BSZ 512 +#define K 20 // #define SAVE_WITH_OUTPUT_DEBUG struct msgdata { long mtype; int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens float mtext_f[MAX_BSZ * (K + 1)]; // score + int mtext_ranks[MAX_BSZ]; // ranks }; void SaveOutMmsgTopK(const paddle::Tensor& x, const paddle::Tensor& scores, - const paddle::Tensor& topk_ids, - const paddle::Tensor& topk_scores, // [bsz, k] + const paddle::Tensor& ranks, const paddle::Tensor& not_need_stop, - int k, int64_t rank_id) { if (rank_id > 0) { return; } auto x_cpu = x.copy_to(paddle::CPUPlace(), false); auto scores_cpu = scores.copy_to(paddle::CPUPlace(), false); - auto topk_ids_cpu = topk_ids.copy_to(paddle::CPUPlace(), false); - auto topk_scores_cpu = topk_scores.copy_to(paddle::CPUPlace(), false); + auto ranks_cpu = ranks.copy_to(paddle::CPUPlace(), false); int64_t* x_data = x_cpu.data<int64_t>(); float* scores_data = scores_cpu.data<float>(); - int64_t* topk_ids_data = topk_ids_cpu.data<int64_t>(); - float* topk_scores_data = topk_scores_cpu.data<float>(); + int64_t* ranks_data = ranks_cpu.data<int64_t>(); static struct msgdata msg_sed; int msg_queue_id = 1; if (const char* inference_msg_queue_id_env_p = @@ -106,21 +103,16 @@ void SaveOutMmsgTopK(const paddle::Tensor& x, msg_sed.mtext[0] = not_need_stop_data ?
inference_msg_id_from_env : -inference_msg_id_from_env; int bsz = x.shape()[0]; + int token_num = x.shape()[1]; + int k = token_num - 1; msg_sed.mtext[1] = bsz; for (int i = 0; i < bsz; i++) { - for (int j = 0; j < k + 1; j++) { + for (int j = 0; j < token_num; j++) { const int64_t offset = i * (K + 1) + j; - if (j == 0) { - msg_sed.mtext[offset + 2] = (int)x_data[i]; - msg_sed.mtext_f[offset] = scores_data[i]; - } else if (j <= k + 1) { - msg_sed.mtext[offset + 2] = (int)topk_ids_data[i * k + j - 1]; - msg_sed.mtext_f[offset] = topk_scores_data[i * k + j - 1]; - } else { - msg_sed.mtext[offset + 2] = -1; - msg_sed.mtext_f[offset] = 0.0; - } + msg_sed.mtext[offset + 2] = (int)x_data[i * token_num + j]; + msg_sed.mtext_f[offset] = scores_data[i * token_num + j]; } + msg_sed.mtext_ranks[i] = (int)ranks_data[i]; } #ifdef SAVE_WITH_OUTPUT_DEBUG std::cout << "msg data: "; @@ -131,7 +123,7 @@ void SaveOutMmsgTopK(const paddle::Tensor& x, #endif if ((msgsnd(msgid, &msg_sed, - (MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4, + (MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4 + MAX_BSZ * 4, 0)) == -1) { printf("full msg buffer\n"); } @@ -139,8 +131,8 @@ void SaveOutMmsgTopK(const paddle::Tensor& x, } PD_BUILD_STATIC_OP(save_output_topk) - .Inputs({"x", "scores", "topk_ids", "topk_scores", "not_need_stop"}) - .Attrs({"k: int", "rank_id: int64_t"}) + .Inputs({"x", "scores", "ranks", "not_need_stop"}) + .Attrs({"rank_id: int64_t"}) .Outputs({"x_out"}) .SetInplaceMap({{"x", "x_out"}}) .SetKernelFn(PD_KERNEL(SaveOutMmsgTopK)); diff --git a/custom_ops/setup_ops_base.py b/custom_ops/setup_ops_base.py index fb8b76b75e..d05b1d39e5 100644 --- a/custom_ops/setup_ops_base.py +++ b/custom_ops/setup_ops_base.py @@ -22,6 +22,7 @@ "gpu_ops/save_with_output_msg.cc", "gpu_ops/get_output.cc", "gpu_ops/get_output_msg_with_topk.cc", + "gpu_ops/save_output_msg_with_topk.cc", "gpu_ops/transfer_output.cc", "cpu_ops/rebuild_padding.cc", ], diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 2611214cf2..b6f0c5d1f3 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -292,6 +292,11 @@ class EngineArgs: Example: max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64]. """ + enable_logprob: bool = False + """ + Flag to enable logprob output. Default is False (disabled). + Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values. + """ def __post_init__(self): """ @@ -413,6 +418,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "Disabled any whitespaces when using guided decoding backend XGrammar." ) + model_group.add_argument("--enable-logprob", + action="store_true", + default=EngineArgs.enable_logprob, + help="Enable output of token-level log probabilities." 
+ ) # Parallel processing parameters group parallel_group = parser.add_argument_group("Parallel Configuration") @@ -784,4 +794,5 @@ def create_engine_config(self) -> Config: max_capture_batch_size=self.max_capture_batch_size, guided_decoding_backend=self.guided_decoding_backend, disable_any_whitespace=self.guided_decoding_disable_any_whitespace, + enable_logprob=self.enable_logprob, ) diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index 65f67254fe..c0ded1427b 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -526,6 +526,8 @@ def __init__( max_capture_batch_size: int = 64, guided_decoding_backend: Optional[str] = None, disable_any_whitespace: bool = False, + enable_logprob: bool = False, + max_logprobs: int = None, ): """ Initialize the Config class. @@ -621,6 +623,13 @@ def __init__( self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)]) self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids) + self.enable_logprob = enable_logprob + # Under the current architecture the communication buffer can only be sized via max_logprobs; default to 10 when unset, and only when enable_logprob is True + if enable_logprob: + self.max_logprobs = 10 if max_logprobs is None else max_logprobs + else: + self.max_logprobs = 0 + self.read_from_config() self.postprocess() self.check() diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 5fca12f0b9..4b04fef06c 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -998,8 +998,8 @@ def _start_worker_service(self): py_script = os.path.join(current_dir_path, worker_path) ori_vocab_size = ( - len(self.data_processor.tokenizer.sp_model) - if hasattr(self.data_processor.tokenizer, 'sp_model') + len(self.data_processor.tokenizer.sp_model) + if hasattr(self.data_processor.tokenizer, 'sp_model') else len(self.data_processor.tokenizer.vocab) ) @@ -1047,6 +1047,7 @@ def _start_worker_service(self): self.cfg.enable_static_graph_inference, "use_cudagraph": self.cfg.use_cudagraph, "disable_any_whitespace": self.cfg.disable_any_whitespace, + "enable_logprob": self.cfg.enable_logprob, } for worker_flag, value in worker_append_flag.items(): if value: diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index e71f398069..96c8ee4448 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -19,6 +19,7 @@ import time from dataclasses import asdict, dataclass, fields from typing import Any, Dict, Optional, Union +from fastdeploy.worker.output import LogprobsLists import numpy @@ -189,6 +190,8 @@ class CompletionOutput: index: int send_idx: int token_ids: list[int] + logprob: Optional[float] = None + top_logprobs: Optional[LogprobsLists] = None draft_token_ids: list[int] = None text: Optional[str] = None reasoning_content: Optional[str] = None @@ -201,6 +204,8 @@ def to_dict(self): "index": self.index, "send_idx": self.send_idx, "token_ids": self.token_ids, + "logprob": self.logprob, + "top_logprobs": self.top_logprobs, "draft_token_ids": self.draft_token_ids, "text": self.text, "reasoning_content": self.reasoning_content diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index 0f60cf36b7..d01c05127b 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -52,6 +52,8 @@ class SamplingParams: the model more random. Zero means greedy sampling. top_p: Float that controls the cumulative probability of the top tokens to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
+ top_k: Integer that controls the number of top tokens to consider. Set + to 0 (or -1) to consider all tokens. seed: Random seed to use for the generation. stop: list of strings that stop the generation when they are generated. The returned output will not contain the stop strings. @@ -82,6 +84,7 @@ class SamplingParams: repetition_penalty: float = None temperature: float = None top_p: float = None + top_k: int = 0 seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = None stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None @@ -111,6 +114,7 @@ def from_optional(cls, repetition_penalty, temperature, top_p, + top_k, seed=None, stop=None, stop_token_ids=None, @@ -130,6 +134,7 @@ def from_optional(cls, if repetition_penalty is not None else 1.0, temperature=temperature if temperature is not None else 1.0, top_p=top_p if top_p is not None else 0.7, + top_k=top_k, seed=seed, stop=stop, stop_token_ids=stop_token_ids, @@ -169,7 +174,13 @@ def _verify_args(self) -> None: f"temperature must be non-negative, got {self.temperature}.") if self.top_p is not None and not 0.0 <= self.top_p <= 1.0: raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.") - + # quietly accept -1 as disabled, but prefer 0 + if self.top_k < -1: + raise ValueError(f"top_k must be 0 (disable), or at least 1, " + f"got {self.top_k}.") + if not isinstance(self.top_k, int): + raise TypeError( + f"top_k must be an integer, got {type(self.top_k).__name__}") if self.max_tokens is not None and self.max_tokens < 1: raise ValueError( f"max_tokens must be at least 1, got {self.max_tokens}.") @@ -189,6 +200,10 @@ def _verify_args(self) -> None: raise ValueError( f"logprobs must be non-negative, got {self.logprobs}.") + if self.logprobs is not None and self.logprobs > 20: + raise ValueError( + f"Invalid value for 'top_logprobs': must be less than or equal to 20.") + if not 0 <= self.seed <= 922337203685477580: raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.") diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 8457867047..c264adf192 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -122,6 +122,7 @@ class ChatCompletionResponseChoice(BaseModel): """ index: int message: ChatMessage + logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] @@ -135,7 +136,15 @@ class ChatCompletionResponse(BaseModel): model: str choices: List[ChatCompletionResponseChoice] usage: UsageInfo +class LogProbEntry(BaseModel): + token: str + logprob: float + bytes: Optional[List[int]] = None + top_logprobs: Optional[List["LogProbEntry"]] = None +class LogProbs(BaseModel): + content: Optional[List[LogProbEntry]] = None + refusal: Optional[Union[str, None]] = None class DeltaMessage(BaseModel): """ @@ -154,6 +163,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): """ index: int delta: DeltaMessage + logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None arrival_time: Optional[float] = None @@ -391,7 +401,9 @@ class ChatCompletionRequest(BaseModel): tools: Optional[List[ChatCompletionToolsParam]] = None model: Optional[str] = "default" frequency_penalty: Optional[float] = None - # remove max_tokens when field is removed from OpenAI API + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = 0 + # TODO(#9845): remove max_tokens when field is removed from OpenAI API 
max_tokens: Optional[int] = Field( default=None, deprecated= @@ -432,6 +444,10 @@ def to_dict_for_infer(self, request_id=None): if request_id is not None: req_dict['request_id'] = request_id + req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens + + req_dict["logprobs"] = self.top_logprobs if self.logprobs else None + if self.metadata is not None: for key, value in self.metadata.items(): req_dict[key] = value @@ -503,3 +519,17 @@ def validate_stream_options(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + + if (top_logprobs := data.get("top_logprobs")) is not None: + if top_logprobs < 0: + raise ValueError("`top_logprobs` must be a positive value.") + + if top_logprobs > 0 and not data.get("logprobs"): + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + + return data diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 2c90026660..448ce6947d 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -22,6 +22,7 @@ from collections.abc import AsyncGenerator, AsyncIterator from typing import Callable, Optional, Union, List import uuid +import traceback from fastapi import Request from pydantic import BaseModel @@ -35,12 +36,12 @@ UsageInfo, PromptTokenUsageInfo, ChatCompletionResponse, - ErrorResponse, + ErrorResponse, LogProbEntry, LogProbs, ) from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.utils import api_server_logger, get_host_ip from fastdeploy.engine.request import RequestOutput - +from fastdeploy.worker.output import LogprobsLists class OpenAIServingChat: @@ -62,8 +63,8 @@ def _check_master(self): return False async def create_chat_completion( - self, - request: ChatCompletionRequest + self, + request: ChatCompletionRequest ): """ Create a new chat completion using the specified parameters. @@ -110,11 +111,11 @@ def _create_streaming_error_response(self, message: str) -> str: return error_response.model_dump_json() async def chat_completion_stream_generator( - self, - request: ChatCompletionRequest, - request_id: str, - model_name: str, - prompt_token_ids: list() + self, + request: ChatCompletionRequest, + request_id: str, + model_name: str, + prompt_token_ids: list() ): """ Streaming chat completion generator. 
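Usage note for the request-side fields added in protocol.py above: `logprobs` and `top_logprobs` must be enabled together (the `check_logprobs` validator rejects `top_logprobs > 0` without `logprobs: true`), and `SamplingParams._verify_args` caps the value at 20. A minimal client sketch, assuming a locally running server; the host, port, and model name are placeholders, not taken from this diff:

    import requests  # any HTTP client works; requests is only for illustration

    payload = {
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,    # must be true whenever top_logprobs > 0
        "top_logprobs": 5,   # forwarded as req_dict["logprobs"]; must be <= 20
        "stream": False,
    }
    resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
    choice = resp.json()["choices"][0]
    # Each content entry carries the sampled token plus its top-k alternatives.
    for entry in choice["logprobs"]["content"]:
        print(entry["token"], entry["logprob"], len(entry["top_logprobs"]))
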
@@ -169,7 +170,7 @@ async def chat_completion_stream_generator( current_waiting_time = 0 await asyncio.sleep(0.1) continue - + res = json.loads(raw_data[-1].decode('utf-8')) if res.get("error_code", 200) != 200: raise ValueError("{}".format(res["error_msg"])) @@ -212,15 +213,28 @@ async def chat_completion_stream_generator( output = res["outputs"] delta_text = output["text"] - + raw_top_logprobs = output["top_logprobs"] + logprobs_res = None + if raw_top_logprobs is not None: + top_logprobs = LogprobsLists( + logprob_token_ids=raw_top_logprobs[0], + logprobs=raw_top_logprobs[1], + sampled_token_ranks=raw_top_logprobs[2], + ) + logprobs_res = self.build_logprobs_response( + logprobs=top_logprobs, + request_top_logprobs=request.top_logprobs, + ) previous_num_tokens += len(output["token_ids"]) delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \ - token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", [])) + token_ids=output.get("token_ids"), + tool_calls=output.get("tool_call_content", [])) choice = ChatCompletionResponseStreamChoice( index=0, delta=delta_message, - arrival_time=arrival_time + logprobs=logprobs_res, + arrival_time=arrival_time, ) if res["finished"]: num_choices -= 1 @@ -232,7 +246,7 @@ async def chat_completion_stream_generator( choice.finish_reason = "tool_calls" else: choice.finish_reason = "length" - + if res.get("error_msg") is not None and "Recover" in res["error_msg"]: choice.finish_reason = "recover_stop" @@ -251,7 +265,6 @@ async def chat_completion_stream_generator( yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" choices = [] - if include_usage: completion_tokens = previous_num_tokens usage = UsageInfo( @@ -270,6 +283,8 @@ async def chat_completion_stream_generator( yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" except Exception as e: + api_server_logger.error("Exception occurred during chat completion: %s", str(e)) + api_server_logger.error("Stack trace:\n%s", traceback.format_exc()) error_data = self._create_streaming_error_response(str(e)) yield f"data: {error_data}\n\n" finally: @@ -277,11 +292,11 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, - request: ChatCompletionRequest, - request_id: str, - model_name: str, - prompt_token_ids: list() + self, + request: ChatCompletionRequest, + request_id: str, + model_name: str, + prompt_token_ids: list() ): """ Full chat completion generator. 
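For reference, a small sketch of the payload shape the streaming branch above (and the full generator below) assumes for output["top_logprobs"]: three parallel lists in the field order of `LogprobsLists`. The numeric values are made up:

    from fastdeploy.worker.output import LogprobsLists

    # One decoded token with top_logprobs=2 requested, so each inner row holds
    # the sampled token first, followed by the higher-ranked alternatives.
    raw_top_logprobs = [
        [[5546, 128, 901]],        # [0] logprob_token_ids
        [[-0.02, -4.13, -5.87]],   # [1] logprobs
        [1],                       # [2] sampled_token_ranks
    ]
    top_logprobs = LogprobsLists(
        logprob_token_ids=raw_top_logprobs[0],
        logprobs=raw_top_logprobs[1],
        sampled_token_ranks=raw_top_logprobs[2],
    )
    # build_logprobs_response() turns column 0 into the sampled-token entry and
    # exposes the remaining columns as that entry's top_logprobs list.
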
@@ -298,6 +313,7 @@ async def chat_completion_full_generator( final_res = None previous_num_tokens = 0 current_waiting_time = 0 + logprob_contents = [] while True: try: raw_data = await asyncio.wait_for(dealer.read(), timeout=10) @@ -322,6 +338,21 @@ async def chat_completion_full_generator( data, stream=False, enable_thinking=enable_thinking) # api_server_logger.debug(f"Client {request_id} received: {data}") previous_num_tokens += len(data["outputs"]["token_ids"]) + # Handle the logprobs carried in the response + output = data["outputs"] + raw_top_logprobs = output["top_logprobs"] + if raw_top_logprobs is not None: + top_logprobs = LogprobsLists( + logprob_token_ids=raw_top_logprobs[0], + logprobs=raw_top_logprobs[1], + sampled_token_ranks=raw_top_logprobs[2], + ) + logprobs_res = self.build_logprobs_response( + logprobs=top_logprobs, + request_top_logprobs=request.top_logprobs, + ) + if logprobs_res and logprobs_res.content is not None: + logprob_contents.extend(logprobs_res.content) if data["finished"]: final_res = data break @@ -337,10 +368,15 @@ async def chat_completion_full_generator( tool_calls=output.get("tool_call_content"), token_ids=output.get("token_ids") ) - + logprobs_full_res = None + if logprob_contents: + logprobs_full_res = LogProbs( + content=logprob_contents + ) choice = ChatCompletionResponseChoice( index=0, message=message, + logprobs=logprobs_full_res, finish_reason=None ) if request.max_tokens is None or previous_num_tokens != request.max_tokens: @@ -350,7 +386,7 @@ async def chat_completion_full_generator( choice.finish_reason = "tool_calls" else: choice.finish_reason = "length" - + if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]: choice.finish_reason = "recover_stop" choices.append(choice) @@ -371,3 +407,54 @@ async def chat_completion_full_generator( choices=choices, usage=usage ) + + def build_logprobs_response( + self, + logprobs: Optional[LogprobsLists], + request_top_logprobs: int, + ) -> Optional[LogProbs]: + """ + Build an OpenAI-style logprobs response object. + Keep the full top-k candidate list and avoid circular references. + """ + + # Validate arguments + if ( + logprobs is None + or request_top_logprobs is None + or request_top_logprobs <= 0 + or len(logprobs.logprob_token_ids) == 0 + ): + return None + + try: + # Top-k candidates for the current token + topk_token_ids = logprobs.logprob_token_ids[0][:request_top_logprobs + 1] + topk_logprobs = logprobs.logprobs[0][:request_top_logprobs + 1] + + # Build LogProbEntry objects for the top-k candidate tokens + top_logprob_entries: List[LogProbEntry] = [] + for tid, lp in zip(topk_token_ids, topk_logprobs): + token_str = self.engine_client.data_processor.process_logprob_response([tid], + clean_up_tokenization_spaces=False) + # token_bytes = token_str.encode("utf-8", errors="replace") + entry = LogProbEntry( + token=token_str, + logprob=lp, + # bytes=list(token_bytes) + ) + top_logprob_entries.append(entry) + # Build the sampled-token entry (avoid sharing references with top_logprob_entries) + sampled_entry = LogProbEntry( + token=top_logprob_entries[0].token, + logprob=top_logprob_entries[0].logprob, + bytes=top_logprob_entries[0].bytes, + top_logprobs=top_logprob_entries[1:] # the full top-k candidate list + ) + + return LogProbs(content=[sampled_entry]) + + except Exception as e: + api_server_logger.error("Error in build_logprobs_response: %s", e) + api_server_logger.error(traceback.format_exc()) + return None diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py index 51dbed7663..23ca0535e4 100644 --- a/fastdeploy/input/ernie_processor.py +++ b/fastdeploy/input/ernie_processor.py @@ -444,3 +444,7 @@ def update_stop_seq(self,
stop_sequences): data_processor_logger.debug( f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}") return stop_seqs, stop_seqs_len + + def process_logprob_response(self, token_ids, **kwargs): + full_text = self.tokenizer.decode(token_ids, **kwargs) + return full_text diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index ae2dc1f29c..9d30dee3e8 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -309,6 +309,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): data_processor_logger.info(f"Processed request {request}") return request + def process_logprob_response(self, token_ids, **kwargs): + full_text = self.tokenizer.decode(token_ids, **kwargs) + return full_text + def process_response(self, response_dict, **kwargs): """ Preprocess the response diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 44ebff8d3e..41a96ee1e8 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -42,3 +42,4 @@ class SamplingMetadata: top_p: paddle.Tensor top_k: Optional[paddle.Tensor] = None + max_num_logprobs: Optional[int] = None diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 217776861c..537e1ae95f 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -29,6 +29,7 @@ apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, top_p_sampling) from fastdeploy.platforms import current_platform +from fastdeploy.worker.output import LogprobsTensors, SamplerOutput class SamplerProcessor: @@ -189,14 +190,67 @@ def pre_process(self, skip_idx_list: List[int] = []): """ pre process before running """ self.processor.pre_process(skip_idx_list) + def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor: + """ + """ + return F.log_softmax(logits, axis=-1) + + def gather_logprobs( + self, + logprobs: paddle.Tensor, + num_logprobs: int, + token_ids: paddle.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + + Args: + logprobs: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + Must be int64. + + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + assert token_ids.dtype == paddle.int64 + # Get with the logprob of the prompt or sampled token. + token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + if num_logprobs >= 1: + # Find the topK values. 
+ topk_logprobs, topk_indices = paddle.topk(logprobs, + num_logprobs, + axis=-1) + indices = paddle.concat([token_ids, topk_indices], axis=1) + top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1) + else: + indices = token_ids + top_logprobs = token_logprobs + + return LogprobsTensors(indices, top_logprobs, token_ranks) + def forward_cuda( self, logits: paddle.Tensor, sampling_metadata: SamplingMetadata, skip_idx_list: List[int] = [], - ) -> paddle.Tensor: + ) -> SamplerOutput: """ """ + num_logprobs = sampling_metadata.max_num_logprobs + if num_logprobs is not None: + raw_logprobs = self.compute_logprobs(logits) + logits = self.processor.apply_token_mask(logits, skip_idx_list) logits = apply_penalty_multi_scores( @@ -216,8 +270,19 @@ def forward_cuda( _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p) + logprobs_tensors = None if num_logprobs is None else \ + self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens) + self.processor.update_output_tokens(next_tokens, skip_idx_list) - return next_tokens + + sampler_output = SamplerOutput( + # The sampled tokens are expanded to 2D tensor with shape + # [num_requests, 1], where each row represents one generated + # token per request. + sampled_token_ids=next_tokens, + logprobs_tensors=logprobs_tensors, + ) + return sampler_output class SpeculativeSampler(nn.Layer): diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 526197f2a4..268b7e62a6 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -20,6 +20,7 @@ from fastdeploy import envs from fastdeploy.engine.config import SpeculativeConfig from fastdeploy.platforms import current_platform + if current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import ( get_padding_offset, save_output, set_stop_value_multi_ends, @@ -32,8 +33,10 @@ speculate_save_output, speculate_set_value_by_flags_and_idx, speculate_step_paddle, speculate_step_system_cache, speculate_update_v3, step_paddle, step_system_cache, update_inputs, - step_reschedule) -from fastdeploy.worker.output import ModelOutputData + step_reschedule, save_output_topk) + +from fastdeploy.worker.output import (ModelOutputData, ModelRunnerOutput, + SamplerOutput) DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1") @@ -109,10 +112,10 @@ def pre_process( cu_seqlens_k, output_cum_offsets, output_padding_offset) -def post_process_normal(sampled_token_ids: paddle.Tensor, +def post_process_normal(sampler_output: SamplerOutput, model_output: ModelOutputData, save_each_rank: bool = False, - skip_save_output: bool = False) -> None: + skip_save_output: bool = False) -> ModelRunnerOutput: """ Post-processing steps after completing a single token generation. """ # 1. Set stop value paddle.assign( @@ -130,7 +133,8 @@ def post_process_normal(sampled_token_ids: paddle.Tensor, model_output.stop_flags, ) # TODO(gongshaotian): Add use_stop_seqs - set_stop_value_multi_ends(sampled_token_ids, model_output.stop_flags, + set_stop_value_multi_ends(sampler_output.sampled_token_ids, + model_output.stop_flags, model_output.seq_lens_this_time, model_output.eos_token_id, model_output.next_tokens, False) # multi ends @@ -145,18 +149,27 @@ def post_process_normal(sampled_token_ids: paddle.Tensor, model_output.seq_lens_decoder, model_output.input_ids, model_output.stop_nums, - sampled_token_ids, + sampler_output.sampled_token_ids, model_output.is_block_step, ) # 3. 
Transmit the model's output and stop generation signal via message queue. # In the future, we will abandon this approach. if not skip_save_output: - save_output( - sampled_token_ids, - model_output.not_need_stop, - model_output.mp_rank, - save_each_rank, # save_each_rank - ) + if sampler_output.logprobs_tensors is None: + save_output( + sampler_output.sampled_token_ids, + model_output.not_need_stop, + model_output.mp_rank, + save_each_rank, # save_each_rank + ) + else: + save_output_topk( + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + model_output.not_need_stop, + model_output.mp_rank, + ) def post_process_specualate(model_output, skip_save_output: bool = False): @@ -201,7 +214,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False): ) -def post_process(sampled_token_ids: paddle.Tensor, +def post_process(sampler_output: SamplerOutput, model_output: ModelOutputData, save_each_rank: bool = False, speculative_decoding: bool = False, @@ -210,7 +223,7 @@ def post_process(sampled_token_ids: paddle.Tensor, if speculative_decoding: post_process_specualate(model_output, skip_save_output) else: - post_process_normal(sampled_token_ids, model_output, save_each_rank, + post_process_normal(sampler_output, model_output, save_each_rank, skip_save_output) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 136197f9cc..ac60c30e5a 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -30,9 +30,11 @@ from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.platforms import current_platform from fastdeploy.utils import llm_logger, spec_logger +from fastdeploy.worker.output import LogprobsLists RECOVERY_STOP_SIGNAL = -3 MAX_BSZ = 512 +K = 20 MAX_DRAFT_TOKENS = 6 SPECULATE_MAX_BSZ = 256 @@ -60,8 +62,15 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, self.output_tokens = paddle.full(shape=[ SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2 ], - fill_value=2, - dtype="int64") + fill_value=2, + dtype="int64") + elif self.cfg.enable_logprob: + self.output_tokens = paddle.full( + shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") + self.output_scores = paddle.full( + shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") + self.output_ranks = paddle.full( + shape=[MAX_BSZ], fill_value=0, dtype="int64") else: self.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, @@ -110,11 +119,51 @@ def run(self): if self.worker is not None: raise Exception("Worker is already running!") - self.worker = threading.Thread(target=self.process_sampling_results, - args=()) + use_logprobs = ( + self.cfg.enable_logprob + and not self.speculative_decoding + and not self.cfg.parallel_config.enable_expert_parallel + ) + + target_func = ( + self.process_sampling_with_logprob_results + if use_logprobs else + self.process_sampling_results + ) + + self.worker = threading.Thread(target=target_func) + self.worker.daemon = True self.worker.start() + def process_sampling_with_logprob_results(self): + """ + read tokens from paddle inference engine and process logprob results + """ + + if current_platform.is_xpu(): + from fastdeploy.model_executor.ops.xpu import get_output_topk + else: + from fastdeploy.model_executor.ops.gpu import get_output_topk + rank_id = self.cfg.parallel_config.local_data_parallel_id + + while True: + try: + is_blocking = True + 
get_output_topk(self.output_tokens, self.output_scores, self.output_ranks, K, rank_id, is_blocking) + + if self.output_tokens[0, 0] == -2: + continue + llm_logger.debug( + f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}" + f"rank_id {rank_id} self.output_scores[0, 0] {self.output_scores[0, 0]}" + ) + self._process_prefill_metrics() + self._process_sampling_with_logprob_batch_output() + except Exception as e: + llm_logger.info("while get input_data error: {0} {1}".format( + e, str(traceback.format_exc()))) + def process_sampling_results(self): """ read tokens from paddle inference engine and process @@ -125,10 +174,8 @@ def process_sampling_results(self): elif current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import get_output else: - from fastdeploy.model_executor.ops.gpu import (get_output, - get_output_ep, - speculate_get_output - ) + from fastdeploy.model_executor.ops.gpu import ( + get_output, get_output_ep, speculate_get_output) rank_id = self.cfg.parallel_config.local_data_parallel_id while True: @@ -142,7 +189,7 @@ def process_sampling_results(self): else: if self.cfg.parallel_config.enable_expert_parallel and \ - self.cfg.parallel_config.data_parallel_size > 1: + self.cfg.parallel_config.data_parallel_size > 1: get_output_ep(self.output_tokens, rank_id, is_blocking) else: @@ -240,7 +287,7 @@ def _compute_speculative_status(self): ) if self.cfg.speculative_config.method in ["mtp"] and \ - self.cfg.speculative_config.num_speculative_tokens == 1: + self.cfg.speculative_config.num_speculative_tokens == 1: single_head_accep_ratio = accept_ratio / (1 - accept_ratio) spec_logger.info( f" Single head accept ratio: {single_head_accep_ratio}") @@ -249,6 +296,122 @@ def _compute_speculative_status(self): self.number_of_output_tokens = 0 self.total_step = 0 + def _process_sampling_with_logprob_batch_output(self): + """ + batch post-processing logprob output function + """ + + batch = self.output_tokens[1, 0] + tokens = self.output_tokens[2:batch * (K + 1) + 2].numpy().reshape( + [batch, K + 1])[:, :(K + 1)] + scores = self.output_scores[:batch * (K + 1)].numpy().reshape( + [batch, K + 1])[:, :(K + 1)] + ranks = self.output_ranks[:batch].numpy() + batch_result = list() + for i in range(batch): + if self.resource_manager.stop_flags[i]: + continue + task = self.resource_manager.tasks_list[i] + task_id = task.request_id + token_id = int(tokens[i, 0]) + token_ids = [token_id] + recovery_stop = token_id == RECOVERY_STOP_SIGNAL + if recovery_stop: + llm_logger.info( + f"recovery stop signal found at task {task_id}") + if not recovery_stop and token_id < 0: + continue + + if task.get("prefill_chunk_info", None) is not None: + prefill_chunk_num = task.get("prefill_chunk_num", 0) + task.prefill_chunk_num = prefill_chunk_num + 1 + + if task.prefill_chunk_num < len(task.prefill_chunk_info): + continue + + self.total_step += 1 + current_time = time.time() + if self.tokens_counter[task_id] == 0: + metrics = RequestMetrics( + arrival_time=task.arrival_time, + inference_start_time=task.inference_start_time, + first_token_time=time.time() - task.inference_start_time, + time_in_queue=task.schedule_start_time - + task.preprocess_end_time, + preprocess_cost_time=task.preprocess_end_time - + task.preprocess_start_time) + + self._record_first_token_metrics(task, current_time) + + else: + metrics = RequestMetrics( + arrival_time=time.time(), + request_start_time=task.arrival_time, + ) + self.number_of_output_tokens += len(token_ids) + self._record_metrics(task, 
current_time, token_ids) + result = RequestOutput(request_id=task_id, + outputs=CompletionOutput( + index=i, + send_idx=self.tokens_counter[task_id], + token_ids=[], + logprob=None, + draft_token_ids=[], + top_logprobs=None, + ), + finished=False, + metrics=metrics) + if self.tokens_counter[task_id] == 0: + if task.messages is not None: + result.prompt = task.messages + result.num_cached_tokens = task.num_cached_tokens + + is_prefill = task.disaggregate_info is not None and task.disaggregate_info[ + "role"] == "prefill" + + if is_prefill and len(token_ids) > 1: + result.outputs.draft_token_ids = copy.deepcopy(token_ids) + + for idx, token_id in enumerate(token_ids): + self.tokens_counter[task_id] += 1 + if token_id != RECOVERY_STOP_SIGNAL: + result.outputs.token_ids.append(token_id) + result.outputs.logprob = float(scores[i, 0]) + # Build top_logprobs + topk_token_ids = tokens[i, :].tolist() + topk_logprobs = scores[i, :].tolist() + sampled_rank = ranks[i].item() + + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank] + ) + if token_id in task.eos_token_ids or is_prefill or recovery_stop: + result.finished = True + result.prompt = task.prompt + result.prompt_token_ids = task.prompt_token_ids + if recovery_stop: + result.error_msg = "Recover is not supported, the result is incomplete!" + llm_logger.info( + f"Request: {task_id} finished, number of " + f"generated tokens: {self.tokens_counter[task_id]}.") + llm_logger.info( + f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}" + ) + llm_logger.info(f"{self.resource_manager.info()}") + if self.cfg.speculative_config.method: + self._compute_speculative_status() + if not is_prefill: + self._record_completion_metrics(task, current_time) + self._recycle_resources(task_id, i, task, result, + is_prefill) + break + if not is_prefill or self.cfg.scheduler_config.name == "splitwise": + batch_result.append(result) + + self.postprocess(batch_result) + def _process_batch_output(self): """ batch post-processing function @@ -420,9 +583,8 @@ def process_sampling_results(self): elif current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import get_output else: - from fastdeploy.model_executor.ops.gpu import (get_output, - speculate_get_output - ) + from fastdeploy.model_executor.ops.gpu import ( get_output, speculate_get_output) while self._is_running: try: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8d6ca79a1b..6183cbf5fb 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -63,7 +63,8 @@ def __init__( self.device_id = device_id self.speculative_method = self.fd_config.speculative_config.method self.speculative_decoding = self.speculative_method is not None - + self.enable_logprob = fd_config.model_config.enable_logprob + self.max_num_logprobs = 20 if fd_config.model_config.enable_logprob else None self.guided_backend = None if self.fd_config.parallel_config.guided_decoding_backend != "off": self.guided_backend = get_guided_backend(fd_config=self.fd_config) @@ -582,6 +583,7 @@ def _prepare_inputs(self) -> None: min_dec_lens=self.share_inputs["min_dec_len"], bad_words_token_ids=self.share_inputs["bad_tokens"], eos_token_ids=self.share_inputs["eos_token_id"], + max_num_logprobs=self.max_num_logprobs, ) def load_model(self) -> None: @@ -786,15 +788,15 @@ def _dummy_run(self,
self.share_inputs["step_idx"], self.share_inputs["stop_flags"], ) - sampled_token_ids = self.sampler(logits, + sampler_output = self.sampler(logits, self.sampling_metadata) if self.parallel_config.tensor_parallel_degree > 1: - paddle.distributed.broadcast(sampled_token_ids, 0) + paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) else: self.sampler(logits, self.sampling_metadata, self.parallel_config.max_model_len, self.share_inputs) - sampled_token_ids = None + sampler_output = None if self.parallel_config.tensor_parallel_degree > 1: paddle.distributed.broadcast( self.share_inputs["accept_tokens"], 0) @@ -834,7 +836,7 @@ def _dummy_run(self, accept_num=self.share_inputs["accept_num"] if self.speculative_decoding else None) - post_process(sampled_token_ids=sampled_token_ids, + post_process(sampler_output=sampler_output, model_output=model_output_data, speculative_decoding=self.speculative_decoding, skip_save_output=True) @@ -1021,18 +1023,18 @@ class at the server level, which is too granular for ModelRunner. self.share_inputs["step_idx"], self.share_inputs["stop_flags"], ) - sampled_token_ids = self.sampler( + sampler_output = self.sampler( logits, self.sampling_metadata, skip_idx_list, ) if self.parallel_config.tensor_parallel_degree > 1: - paddle.distributed.broadcast(sampled_token_ids, 0) + paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) else: self.sampler(logits, self.sampling_metadata, self.parallel_config.max_model_len, self.share_inputs) - sampled_token_ids = None + sampler_output = None if self.parallel_config.tensor_parallel_degree > 1: paddle.distributed.broadcast( self.share_inputs["accept_tokens"], 0) @@ -1075,7 +1077,7 @@ class at the server level, which is too granular for ModelRunner. skip_save_output = True else: skip_save_output = False - post_process(sampled_token_ids=sampled_token_ids, + post_process(sampler_output=sampler_output, model_output=model_output_data, save_each_rank=self.parallel_config.use_ep, speculative_decoding=self.speculative_decoding, diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 7d3c1198fb..209eb52eb6 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -15,11 +15,80 @@ """ from dataclasses import dataclass -from typing import Optional +from typing import NamedTuple, Optional import paddle +class LogprobsLists(NamedTuple): + """ + """ + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: list[list[int]] + # [num_reqs, max_num_logprobs + 1] + logprobs: list[list[float]] + # [num_reqs] + sampled_token_ranks: list[int] + + def slice(self, start: int, end: int): + """slice""" + return LogprobsLists( + self.logprob_token_ids[start:end], + self.logprobs[start:end], + self.sampled_token_ranks[start:end], + ) + + +class LogprobsTensors(NamedTuple): + """ + """ + + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: paddle.Tensor + # [num_reqs, max_num_logprobs + 1] + logprobs: paddle.Tensor + # [num_reqs] + selected_token_ranks: paddle.Tensor + + def tolists(self): + """Convert to lists.""" + return LogprobsLists( + self.logprob_token_ids.tolist(), + self.logprobs.tolist(), + self.selected_token_ranks.tolist(), + ) + + @staticmethod + def empty_cpu(num_positions: int, + num_tokens_per_position: int) -> "LogprobsTensors": + """Create empty LogprobsTensors on CPU.""" + + logprob_token_ids = paddle.empty( + [num_positions, num_tokens_per_position], + dtype=paddle.int64).cpu() + logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32) + 
selected_token_ranks = paddle.empty([num_positions], + dtype=paddle.int64).cpu() + return LogprobsTensors( + logprob_token_ids=logprob_token_ids, + logprobs=logprobs, + selected_token_ranks=selected_token_ranks, + ) + + +@dataclass +class SamplerOutput: + """ + """ + + # [num_reqs, max_num_generated_tokens] + # Different requests can have different number of generated tokens. + # All requests are padded to max_num_generated_tokens. + # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding. + sampled_token_ids: paddle.Tensor + logprobs_tensors: Optional[LogprobsTensors] + @dataclass class ModelOutputData: """ @@ -158,3 +227,8 @@ class ModelRunnerOutput: [num_reqs, num_spec_tokens] """ spec_token_ids: Optional[list[list[int]]] + + # [num_reqs, max_num_logprobs + 1] + # [num_reqs, max_num_logprobs + 1] + # [num_reqs] + logprobs: Optional[LogprobsLists] diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 39c33574ca..8102ad6a29 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -557,6 +557,9 @@ def parse_args(): "'ipc_snapshot': load from disk snapshot of IPC weights, " "'meta': provide RL traing worker, no_weights_load" "'normal':normal load weight") + parser.add_argument("--enable_logprob", + action='store_true', + help="Enable output of token-level log probabilities.") args = parser.parse_args() return args @@ -757,6 +760,7 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig: ) model_config.architectures = config.get("architectures") + model_config.enable_logprob = args.enable_logprob logger.info("===========load_config==============") load_config.dynamic_load_weight = args.dynamic_load_weight
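Closing note on the System V message sizing used by the custom ops at the top of this diff: `save_output_topk` (producer) and `get_output_topk` (consumer) must pass the same byte count to msgsnd/msgrcv, and the new `mtext_ranks` field adds one int per batch slot. A quick arithmetic check, assuming MAX_BSZ = 512 and K = 20 as defined in the .cc files:

    MAX_BSZ, K = 512, 20
    INT_BYTES = FLOAT_BYTES = 4

    token_bytes = (MAX_BSZ * (K + 1) + 2) * INT_BYTES  # stop_flag, bsz, then bsz * (K + 1) token ids
    score_bytes = MAX_BSZ * (K + 1) * FLOAT_BYTES      # matching logprob scores
    rank_bytes = MAX_BSZ * INT_BYTES                   # new: one sampled-token rank per slot

    # 43016 + 43008 + 2048 bytes of payload per message (the mtype field is not counted).
    assert token_bytes + score_bytes + rank_bytes == 88072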