
Logprob #2775

Open · wants to merge 15 commits into base: develop

16 changes: 10 additions & 6 deletions custom_ops/gpu_ops/get_output_msg_with_topk.cc
@@ -24,16 +24,18 @@
#endif

#define MAX_BSZ 512
#define K 10
#define K 20

struct msgdata {
long mtype;
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
float mtext_f[MAX_BSZ * (K + 1)]; // score
int mtext_ranks[MAX_BSZ]; // ranks
};

void GetOutputTopK(const paddle::Tensor& x,
const paddle::Tensor& scores,
const paddle::Tensor& ranks,
int k,
int64_t rank_id,
bool wait_flag) {
@@ -66,17 +68,18 @@ void GetOutputTopK(const paddle::Tensor& x,

int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
float* scores_data = const_cast<float*>(scores.data<float>());
int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid,
&msg_rcv,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
0,
IPC_NOWAIT);
} else {
ret = msgrcv(msgid,
&msg_rcv,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
0,
0);
}
@@ -97,13 +100,14 @@ void GetOutputTopK(const paddle::Tensor& x,
out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
scores_data[offset] = msg_rcv.mtext_f[offset];
}
ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i];
}
return;
}

PD_BUILD_STATIC_OP(get_output_topk)
.Inputs({"x", "scores"})
.Inputs({"x", "scores", "ranks"})
.Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"})
.Outputs({"x_out", "scores_out"})
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}})
.Outputs({"x_out", "scores_out", "ranks_out"})
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}, {"ranks", "ranks_out"}})
.SetKernelFn(PD_KERNEL(GetOutputTopK));
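Not part of the diff, just a reading aid: a minimal sketch of where the new msgrcv size argument comes from, assuming 4-byte int and float. The payload now covers the token ids, the per-token scores, and the added per-request ranks.

```python
# Byte budget of the System V message payload read by get_output_topk,
# assuming 4-byte int/float and the constants above (MAX_BSZ = 512, K = 20).
MAX_BSZ, K = 512, 20

mtext_bytes   = (MAX_BSZ * (K + 1) + 2) * 4  # stop_flag, bsz, then MAX_BSZ*(K+1) token slots
mtext_f_bytes = MAX_BSZ * (K + 1) * 4        # one float score (logprob) per token slot
ranks_bytes   = MAX_BSZ * 4                  # new: one sampled-token rank per request

payload = mtext_bytes + mtext_f_bytes + ranks_bytes
print(payload)  # 43016 + 43008 + 2048 = 88072, matching the size passed to msgrcv
```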
38 changes: 15 additions & 23 deletions custom_ops/gpu_ops/save_output_msg_with_topk.cc
@@ -23,34 +23,31 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 128
#define K 10
#define MAX_BSZ 512
#define K 20
// #define SAVE_WITH_OUTPUT_DEBUG

struct msgdata {
long mtype;
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
float mtext_f[MAX_BSZ * (K + 1)]; // score
int mtext_ranks[MAX_BSZ]; // ranks
};

void SaveOutMmsgTopK(const paddle::Tensor& x,
const paddle::Tensor& scores,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_scores, // [bsz, k]
const paddle::Tensor& ranks,
const paddle::Tensor& not_need_stop,
int k,
int64_t rank_id) {
if (rank_id > 0) {
return;
}
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
auto scores_cpu = scores.copy_to(paddle::CPUPlace(), false);
auto topk_ids_cpu = topk_ids.copy_to(paddle::CPUPlace(), false);
auto topk_scores_cpu = topk_scores.copy_to(paddle::CPUPlace(), false);
auto ranks_cpu = ranks.copy_to(paddle::CPUPlace(), false);
int64_t* x_data = x_cpu.data<int64_t>();
float* scores_data = scores_cpu.data<float>();
int64_t* topk_ids_data = topk_ids_cpu.data<int64_t>();
float* topk_scores_data = topk_scores_cpu.data<float>();
int64_t* ranks_data = ranks_cpu.data<int64_t>();
static struct msgdata msg_sed;
int msg_queue_id = 1;
if (const char* inference_msg_queue_id_env_p =
@@ -106,21 +103,16 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
int token_num = x.shape()[1];
int k = token_num - 1;
msg_sed.mtext[1] = bsz;
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < k + 1; j++) {
for (int j = 0; j < token_num; j++) {
const int64_t offset = i * (K + 1) + j;
if (j == 0) {
msg_sed.mtext[offset + 2] = (int)x_data[i];
msg_sed.mtext_f[offset] = scores_data[i];
} else if (j <= k + 1) {
msg_sed.mtext[offset + 2] = (int)topk_ids_data[i * k + j - 1];
msg_sed.mtext_f[offset] = topk_scores_data[i * k + j - 1];
} else {
msg_sed.mtext[offset + 2] = -1;
msg_sed.mtext_f[offset] = 0.0;
}
msg_sed.mtext[offset + 2] = (int)x_data[i * token_num + j];
msg_sed.mtext_f[offset] = scores_data[i * token_num + j];
}
msg_sed.mtext_ranks[i] = (int)ranks_data[i];
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: ";
@@ -131,16 +123,16 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
#endif
if ((msgsnd(msgid,
&msg_sed,
(MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4 + MAX_BSZ * 4,
0)) == -1) {
printf("full msg buffer\n");
}
return;
}

PD_BUILD_STATIC_OP(save_output_topk)
.Inputs({"x", "scores", "topk_ids", "topk_scores", "not_need_stop"})
.Attrs({"k: int", "rank_id: int64_t"})
.Inputs({"x", "scores", "ranks", "not_need_stop"})
.Attrs({"rank_id: int64_t"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SaveOutMmsgTopK));
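For orientation (not part of the diff): a sketch of how a consumer could unpack one such message, assuming the flat layout written above, where request i occupies slots i*(K+1) .. i*(K+1)+k of mtext/mtext_f (the sampled token plus its k alternatives) and mtext_ranks[i] carries the rank of the sampled token.

```python
# Illustrative unpacking of one message produced by save_output_topk, assuming
# mtext = [stop_flag, bsz, <MAX_BSZ*(K+1) token ids>], mtext_f = per-slot scores,
# mtext_ranks = per-request rank of the sampled token.
MAX_BSZ, K = 512, 20

def unpack(mtext, mtext_f, mtext_ranks, k):
    """Return per-request token ids, their scores, and the sampled-token rank."""
    bsz = mtext[1]
    results = []
    for i in range(bsz):
        base = i * (K + 1)                           # rows are padded to K + 1 slots
        tokens = mtext[2 + base : 2 + base + k + 1]  # slot 0: sampled token, 1..k: alternatives
        scores = mtext_f[base : base + k + 1]
        results.append({"tokens": tokens, "logprobs": scores, "rank": mtext_ranks[i]})
    return results
```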
1 change: 1 addition & 0 deletions custom_ops/setup_ops_base.py
@@ -22,6 +22,7 @@
"gpu_ops/save_with_output_msg.cc",
"gpu_ops/get_output.cc",
"gpu_ops/get_output_msg_with_topk.cc",
"gpu_ops/save_output_msg_with_topk.cc",
"gpu_ops/transfer_output.cc",
"cpu_ops/rebuild_padding.cc",
],
11 changes: 11 additions & 0 deletions fastdeploy/engine/args_utils.py
@@ -292,6 +292,11 @@ class EngineArgs:
Example:
max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
"""
enable_logprob: bool = False
"""
Flag to enable logprob output. Default is False (disabled).
Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
"""

def __post_init__(self):
"""
@@ -413,6 +418,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help=
"Disabled any whitespaces when using guided decoding backend XGrammar."
)
model_group.add_argument("--enable-logprob",
action="store_true",
default=EngineArgs.enable_logprob,
help="Enable output of token-level log probabilities."
)

# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -784,4 +794,5 @@ def create_engine_config(self) -> Config:
max_capture_batch_size=self.max_capture_batch_size,
guided_decoding_backend=self.guided_decoding_backend,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
enable_logprob=self.enable_logprob,
)
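As a sanity check on the flag wiring above (illustrative, standard argparse behavior): with action="store_true" the option defaults to False and only flips when --enable-logprob is given on the command line.

```python
# Minimal sketch of the --enable-logprob flag behavior (store_true, default False).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-logprob",
                    action="store_true",
                    default=False,
                    help="Enable output of token-level log probabilities.")

assert parser.parse_args([]).enable_logprob is False            # disabled unless requested
assert parser.parse_args(["--enable-logprob"]).enable_logprob   # enabled explicitly
```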
9 changes: 9 additions & 0 deletions fastdeploy/engine/config.py
@@ -526,6 +526,8 @@ def __init__(
max_capture_batch_size: int = 64,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
enable_logprob: bool = False,
max_logprobs: int = None,
):
"""
Initialize the Config class.
@@ -621,6 +623,13 @@ def __init__(
self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)

self.enable_logprob = enable_logprob
# Under the current architecture, the communication buffer can only be allocated by specifying max_logprobs; default to 10, but only when enable_logprob is True.
if enable_logprob:
self.max_logprobs = 10 if max_logprobs is None else max_logprobs
else:
self.max_logprobs = 0

self.read_from_config()
self.postprocess()
self.check()
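The buffer-sizing rule above, restated as a small standalone helper (hypothetical name, same logic): a default of 10 log probabilities is reserved only when logprob output is enabled, otherwise no buffer is requested.

```python
# Hypothetical helper mirroring the Config logic above.
def resolve_max_logprobs(enable_logprob: bool, max_logprobs: int | None = None) -> int:
    """Size of the logprob communication buffer under the current architecture."""
    if not enable_logprob:
        return 0
    return 10 if max_logprobs is None else max_logprobs

assert resolve_max_logprobs(False) == 0
assert resolve_max_logprobs(True) == 10
assert resolve_max_logprobs(True, 5) == 5
```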
5 changes: 3 additions & 2 deletions fastdeploy/engine/engine.py
@@ -998,8 +998,8 @@ def _start_worker_service(self):
py_script = os.path.join(current_dir_path, worker_path)

ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, 'sp_model')
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, 'sp_model')
else len(self.data_processor.tokenizer.vocab)
)

@@ -1047,6 +1047,7 @@ def _start_worker_service(self):
self.cfg.enable_static_graph_inference,
"use_cudagraph": self.cfg.use_cudagraph,
"disable_any_whitespace": self.cfg.disable_any_whitespace,
"enable_logprob": self.cfg.enable_logprob,
}
for worker_flag, value in worker_append_flag.items():
if value:
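A sketch of the flag propagation above (exact command-line rendering assumed): entries of worker_append_flag are appended to the worker launch command only when truthy, so the worker sees --enable_logprob only if the engine was started with logprobs enabled.

```python
# Illustrative only: boolean engine options forwarded to the worker command line.
worker_append_flag = {
    "use_cudagraph": False,
    "enable_logprob": True,
}

arguments = ""
for worker_flag, value in worker_append_flag.items():
    if value:
        arguments += f" --{worker_flag}"

print(arguments)  # " --enable_logprob"
```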
5 changes: 5 additions & 0 deletions fastdeploy/engine/request.py
@@ -19,6 +19,7 @@
import time
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, Optional, Union
from fastdeploy.worker.output import LogprobsLists

import numpy

@@ -189,6 +190,8 @@ class CompletionOutput:
index: int
send_idx: int
token_ids: list[int]
logprob: Optional[float] = None
top_logprobs: Optional[LogprobsLists] = None
draft_token_ids: list[int] = None
text: Optional[str] = None
reasoning_content: Optional[str] = None
@@ -201,6 +204,8 @@ def to_dict(self):
"index": self.index,
"send_idx": self.send_idx,
"token_ids": self.token_ids,
"logprob": self.logprob,
"top_logprobs": self.top_logprobs,
"draft_token_ids": self.draft_token_ids,
"text": self.text,
"reasoning_content": self.reasoning_content
17 changes: 16 additions & 1 deletion fastdeploy/engine/sampling_params.py
@@ -52,6 +52,8 @@ class SamplingParams:
the model more random. Zero means greedy sampling.
top_p: Float that controls the cumulative probability of the top tokens
to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set
to 0 (or -1) to consider all tokens.
seed: Random seed to use for the generation.
stop: list of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
@@ -82,6 +84,7 @@ class SamplingParams:
repetition_penalty: float = None
temperature: float = None
top_p: float = None
top_k: int = 0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -111,6 +114,7 @@ def from_optional(cls,
repetition_penalty,
temperature,
top_p,
top_k,
seed=None,
stop=None,
stop_token_ids=None,
@@ -130,6 +134,7 @@ def from_optional(cls,
if repetition_penalty is not None else 1.0,
temperature=temperature if temperature is not None else 1.0,
top_p=top_p if top_p is not None else 0.7,
top_k=top_k,
seed=seed,
stop=stop,
stop_token_ids=stop_token_ids,
@@ -169,7 +174,13 @@ def _verify_args(self) -> None:
f"temperature must be non-negative, got {self.temperature}.")
if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")

# quietly accept -1 as disabled, but prefer 0
if self.top_k < -1:
raise ValueError(f"top_k must be -1 or 0 (disable), or at least 1, "
f"got {self.top_k}.")
if not isinstance(self.top_k, int):
raise TypeError(
f"top_k must be an integer, got {type(self.top_k).__name__}")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -189,6 +200,10 @@ def _verify_args(self) -> None:
raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.")

if self.logprobs is not None and self.logprobs > 20:
raise ValueError(
f"Invalid value for 'top_logprobs': must be less than or equal to 20, got {self.logprobs}.")

if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got "
f"{self.seed}.")
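Restating the new sampling-parameter checks as a standalone function (illustrative only, not the class itself): top_k accepts -1 or 0 as "disabled" and any value of at least 1, and the requested number of log probabilities is capped at 20.

```python
# Illustrative restatement of the top_k / logprobs validation added above.
def check_top_k_and_logprobs(top_k: int, logprobs: int | None) -> None:
    if not isinstance(top_k, int):
        raise TypeError(f"top_k must be an integer, got {type(top_k).__name__}")
    if top_k < -1:
        raise ValueError(f"top_k must be -1 or 0 (disable), or at least 1, got {top_k}.")
    if logprobs is not None and not 0 <= logprobs <= 20:
        raise ValueError(f"logprobs must be in [0, 20], got {logprobs}.")

check_top_k_and_logprobs(0, 5)      # OK: top-k filtering disabled, 5 logprobs requested
check_top_k_and_logprobs(-1, None)  # OK: -1 is also treated as disabled
```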
32 changes: 31 additions & 1 deletion fastdeploy/entrypoints/openai/protocol.py
@@ -122,6 +122,7 @@ class ChatCompletionResponseChoice(BaseModel):
"""
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]


@@ -135,7 +136,15 @@ class ChatCompletionResponse(BaseModel):
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class LogProbEntry(BaseModel):
token: str
logprob: float
bytes: Optional[List[int]] = None
top_logprobs: Optional[List["LogProbEntry"]] = None

class LogProbs(BaseModel):
content: Optional[List[LogProbEntry]] = None
refusal: Optional[Union[str, None]] = None

class DeltaMessage(BaseModel):
"""
@@ -154,6 +163,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
"""
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
arrival_time: Optional[float] = None

@@ -391,7 +401,9 @@ class ChatCompletionRequest(BaseModel):
tools: Optional[List[ChatCompletionToolsParam]] = None
model: Optional[str] = "default"
frequency_penalty: Optional[float] = None
# remove max_tokens when field is removed from OpenAI API
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field(
default=None,
deprecated=
@@ -432,6 +444,10 @@ def to_dict_for_infer(self, request_id=None):
if request_id is not None:
req_dict['request_id'] = request_id

req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens

req_dict["logprobs"] = self.top_logprobs if self.logprobs else None

if self.metadata is not None:
for key, value in self.metadata.items():
req_dict[key] = value
@@ -503,3 +519,17 @@ def validate_stream_options(cls, data):
)

return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):

if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0:
raise ValueError("`top_logprobs` must be a non-negative value.")

if top_logprobs > 0 and not data.get("logprobs"):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)

return data
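For illustration, the rule enforced by check_logprobs, extracted as a standalone check with a couple of example payloads (other request fields omitted):

```python
# Standalone restatement of the check_logprobs validator above.
def validate_logprob_fields(data: dict) -> dict:
    top_logprobs = data.get("top_logprobs")
    if top_logprobs is not None:
        if top_logprobs < 0:
            raise ValueError("`top_logprobs` must be a non-negative value.")
        if top_logprobs > 0 and not data.get("logprobs"):
            raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
    return data

validate_logprob_fields({"logprobs": True, "top_logprobs": 5})   # accepted
try:
    validate_logprob_fields({"top_logprobs": 5})                 # rejected: logprobs not set
except ValueError as err:
    print(err)
```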