
Commit d33105b

[Feature] Online Chat API Support Return logprobs (#2777)
* online chat support logprobs
* check xpu
* check vl_gpu_model_runner and xpu_model_runner
* get_worker() check platform
Parent: 24f934f · Commit: d33105b

22 files changed: +608 −114 lines

custom_ops/gpu_ops/get_output_msg_with_topk.cc

Lines changed: 10 additions & 6 deletions
@@ -24,16 +24,18 @@
 #endif
 
 #define MAX_BSZ 512
-#define K 10
+#define K 20
 
 struct msgdata {
   long mtype;
   int mtext[MAX_BSZ * (K + 1) + 2];  // stop_flag, bsz, tokens
   float mtext_f[MAX_BSZ * (K + 1)];  // score
+  int mtext_ranks[MAX_BSZ];  // ranks
 };
 
 void GetOutputTopK(const paddle::Tensor& x,
                    const paddle::Tensor& scores,
+                   const paddle::Tensor& ranks,
                    int k,
                    int64_t rank_id,
                    bool wait_flag) {
@@ -66,17 +68,18 @@ void GetOutputTopK(const paddle::Tensor& x,
 
   int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
   float* scores_data = const_cast<float*>(scores.data<float>());
+  int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>());
   int ret = -1;
   if (!wait_flag) {
     ret = msgrcv(msgid,
                  &msg_rcv,
-                 (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
+                 (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
                  0,
                  IPC_NOWAIT);
   } else {
     ret = msgrcv(msgid,
                  &msg_rcv,
-                 (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
+                 (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
                  0,
                  0);
   }
@@ -97,13 +100,14 @@ void GetOutputTopK(const paddle::Tensor& x,
       out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
       scores_data[offset] = msg_rcv.mtext_f[offset];
     }
+    ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i];
   }
   return;
 }
 
 PD_BUILD_STATIC_OP(get_output_topk)
-    .Inputs({"x", "scores"})
+    .Inputs({"x", "scores", "ranks"})
     .Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"})
-    .Outputs({"x_out", "scores_out"})
-    .SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}})
+    .Outputs({"x_out", "scores_out", "ranks_out"})
+    .SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}, {"ranks", "ranks_out"}})
     .SetKernelFn(PD_KERNEL(GetOutputTopK));
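
The size argument handed to msgrcv has to track the enlarged struct: mtext and mtext_f keep their MAX_BSZ * (K + 1) slots (plus the two header ints), and the new mtext_ranks array contributes MAX_BSZ extra 4-byte entries. A quick back-of-the-envelope check, written as a Python sketch that simply assumes 4-byte int and float like the C++ struct:

# Sketch: recompute the payload size used in the msgrcv/msgsnd calls above.
# Assumes 4-byte int and float, matching the C++ msgdata layout (mtype excluded).
MAX_BSZ = 512
K = 20

tokens_bytes = (MAX_BSZ * (K + 1) + 2) * 4  # mtext: stop_flag, bsz, then token ids
scores_bytes = MAX_BSZ * (K + 1) * 4        # mtext_f: per-slot logprob scores
ranks_bytes = MAX_BSZ * 4                   # mtext_ranks: one rank per request (new)

print(tokens_bytes + scores_bytes + ranks_bytes)  # total bytes passed to msgrcv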

custom_ops/gpu_ops/save_output_msg_with_topk.cc

Lines changed: 22 additions & 20 deletions
@@ -23,34 +23,34 @@
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif
 
-#define MAX_BSZ 128
-#define K 10
+#define MAX_BSZ 512
+#define K 20
 // #define SAVE_WITH_OUTPUT_DEBUG
 
 struct msgdata {
   long mtype;
   int mtext[MAX_BSZ * (K + 1) + 2];  // stop_flag, bsz, tokens
   float mtext_f[MAX_BSZ * (K + 1)];  // score
+  int mtext_ranks[MAX_BSZ];  // ranks
 };
 
 void SaveOutMmsgTopK(const paddle::Tensor& x,
-                     const paddle::Tensor& scores,
-                     const paddle::Tensor& topk_ids,
-                     const paddle::Tensor& topk_scores,  // [bsz, k]
+                     const paddle::Tensor& logprob_token_ids,  // [bsz, k+1]
+                     const paddle::Tensor& logprob_scores,  // [bsz, k+1]
+                     const paddle::Tensor& ranks,
                      const paddle::Tensor& not_need_stop,
-                     int k,
                      int64_t rank_id) {
   if (rank_id > 0) {
     return;
   }
   auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
-  auto scores_cpu = scores.copy_to(paddle::CPUPlace(), false);
-  auto topk_ids_cpu = topk_ids.copy_to(paddle::CPUPlace(), false);
-  auto topk_scores_cpu = topk_scores.copy_to(paddle::CPUPlace(), false);
+  auto logprob_token_ids_cpu = logprob_token_ids.copy_to(paddle::CPUPlace(), false);
+  auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
+  auto ranks_cpu = ranks.copy_to(paddle::CPUPlace(), false);
   int64_t* x_data = x_cpu.data<int64_t>();
-  float* scores_data = scores_cpu.data<float>();
-  int64_t* topk_ids_data = topk_ids_cpu.data<int64_t>();
-  float* topk_scores_data = topk_scores_cpu.data<float>();
+  int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
+  float* logprob_scores_data = logprob_scores_cpu.data<float>();
+  int64_t* ranks_data = ranks_cpu.data<int64_t>();
   static struct msgdata msg_sed;
   int msg_queue_id = 1;
   if (const char* inference_msg_queue_id_env_p =
@@ -106,21 +106,23 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
   msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
                                         : -inference_msg_id_from_env;
   int bsz = x.shape()[0];
+  int max_num_logprobs = logprob_token_ids.shape()[1];
   msg_sed.mtext[1] = bsz;
   for (int i = 0; i < bsz; i++) {
-    for (int j = 0; j < k + 1; j++) {
+    for (int j = 0; j < K + 1; j++) {
       const int64_t offset = i * (K + 1) + j;
       if (j == 0) {
         msg_sed.mtext[offset + 2] = (int)x_data[i];
-        msg_sed.mtext_f[offset] = scores_data[i];
-      } else if (j <= k + 1) {
-        msg_sed.mtext[offset + 2] = (int)topk_ids_data[i * k + j - 1];
-        msg_sed.mtext_f[offset] = topk_scores_data[i * k + j - 1];
+        msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
+      } else if (j < max_num_logprobs) {
+        msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[i * max_num_logprobs + j];
+        msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
       } else {
         msg_sed.mtext[offset + 2] = -1;
         msg_sed.mtext_f[offset] = 0.0;
       }
     }
+    msg_sed.mtext_ranks[i] = (int)ranks_data[i];
   }
 #ifdef SAVE_WITH_OUTPUT_DEBUG
   std::cout << "msg data: ";
@@ -131,16 +133,16 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
 #endif
   if ((msgsnd(msgid,
               &msg_sed,
-              (MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4,
+              (MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4 + MAX_BSZ * 4,
               0)) == -1) {
     printf("full msg buffer\n");
   }
   return;
 }
 
 PD_BUILD_STATIC_OP(save_output_topk)
-    .Inputs({"x", "scores", "topk_ids", "topk_scores", "not_need_stop"})
-    .Attrs({"k: int", "rank_id: int64_t"})
+    .Inputs({"x", "topk_ids", "logprob_scores", "ranks", "not_need_stop"})
+    .Attrs({"rank_id: int64_t"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
     .SetKernelFn(PD_KERNEL(SaveOutMmsgTopK));
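
Per request row, the packing convention in the loop above is: slot 0 carries the sampled token together with its logprob, slots 1 through max_num_logprobs - 1 carry the candidate tokens and their logprobs, any remaining slots up to K are padded with -1 / 0.0, and the sampled token's rank goes into mtext_ranks. A rough Python sketch of that per-row loop (illustrative only, not the shipped kernel):

# Sketch of the packing layout used by SaveOutMmsgTopK.
# Each row has K + 1 slots: slot 0 = sampled token + its logprob,
# slots 1..max_num_logprobs-1 = candidate tokens/logprobs, rest padded.
K = 20

def pack_row(sampled_token, logprob_token_ids, logprob_scores):
    """logprob_token_ids / logprob_scores are one row of the [bsz, k+1] inputs."""
    max_num_logprobs = len(logprob_token_ids)
    tokens, scores = [], []
    for j in range(K + 1):
        if j == 0:
            tokens.append(sampled_token)
            scores.append(logprob_scores[0])
        elif j < max_num_logprobs:
            tokens.append(logprob_token_ids[j])
            scores.append(logprob_scores[j])
        else:
            tokens.append(-1)   # padding when fewer than K candidates were requested
            scores.append(0.0)
    return tokens, scores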

custom_ops/setup_ops_base.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
         "gpu_ops/save_with_output_msg.cc",
         "gpu_ops/get_output.cc",
         "gpu_ops/get_output_msg_with_topk.cc",
+        "gpu_ops/save_output_msg_with_topk.cc",
         "gpu_ops/transfer_output.cc",
         "cpu_ops/rebuild_padding.cc",
     ],

fastdeploy/engine/args_utils.py

Lines changed: 12 additions & 0 deletions
@@ -299,6 +299,12 @@ class EngineArgs:
     max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
     """
 
+    enable_logprob: bool = False
+    """
+    Flag to enable logprob output. Default is False (disabled).
+    Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
@@ -419,6 +425,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help=
             "Disabled any whitespaces when using guided decoding backend XGrammar."
         )
+        model_group.add_argument("--enable-logprob",
+                                 action="store_true",
+                                 default=EngineArgs.enable_logprob,
+                                 help="Enable output of token-level log probabilities."
+                                 )
 
         # Parallel processing parameters group
         parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -799,4 +810,5 @@ def create_engine_config(self) -> Config:
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
             enable_custom_all_reduce=self.enable_custom_all_reduce,
+            enable_logprob = self.enable_logprob,
         )
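
For reference, a minimal sketch of how the new store_true flag interacts with the dataclass default (plain argparse stands in for FlexibleArgumentParser here; only the names taken from the diff are real):

import argparse
from dataclasses import dataclass

@dataclass
class EngineArgs:
    enable_logprob: bool = False  # mirrors the new field above

parser = argparse.ArgumentParser()
parser.add_argument("--enable-logprob",
                    action="store_true",
                    default=EngineArgs.enable_logprob,
                    help="Enable output of token-level log probabilities.")

print(parser.parse_args([]).enable_logprob)                    # False: flag omitted
print(parser.parse_args(["--enable-logprob"]).enable_logprob)  # True: flag given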

fastdeploy/engine/config.py

Lines changed: 3 additions & 0 deletions
@@ -590,6 +590,7 @@ def __init__(
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
         enable_custom_all_reduce: bool = False,
+        enable_logprob: bool = False,
     ):
         """
         Initialize the Config class.
@@ -686,6 +687,8 @@ def __init__(
         self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
         self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
 
+        self.enable_logprob = enable_logprob
+
         self.read_from_config()
         self.postprocess()
         self.check()

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
@@ -1052,6 +1052,7 @@ def _start_worker_service(self):
             "use_cudagraph": self.cfg.use_cudagraph,
             "disable_any_whitespace": self.cfg.disable_any_whitespace,
             "enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce,
+            "enable_logprob": self.cfg.enable_logprob,
         }
         for worker_flag, value in worker_append_flag.items():
             if value:
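
The loop visible above only forwards flags whose value is truthy, so the worker only sees enable_logprob when the engine was started with it. A small sketch of that pattern (the exact flag formatting is an assumption, not copied from engine.py):

# Sketch of the append-if-truthy pattern used for worker startup flags.
worker_append_flag = {
    "enable_logprob": True,   # taken from cfg.enable_logprob
    "use_cudagraph": False,   # falsy entries are simply skipped
}
arguments = ""
for worker_flag, value in worker_append_flag.items():
    if value:
        arguments += f" --{worker_flag}"
print(arguments)  # " --enable_logprob"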

fastdeploy/engine/request.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.utils import data_processor_logger
+from fastdeploy.worker.output import LogprobsLists
 
 
 @dataclass
@@ -189,6 +190,8 @@ class CompletionOutput:
     index: int
     send_idx: int
     token_ids: list[int]
+    logprob: Optional[float] = None
+    top_logprobs: Optional[LogprobsLists] = None
     draft_token_ids: list[int] = None
     text: Optional[str] = None
     reasoning_content: Optional[str] = None
@@ -201,6 +204,8 @@ def to_dict(self):
             "index": self.index,
             "send_idx": self.send_idx,
             "token_ids": self.token_ids,
+            "logprob": self.logprob,
+            "top_logprobs": self.top_logprobs,
             "draft_token_ids": self.draft_token_ids,
             "text": self.text,
             "reasoning_content": self.reasoning_content

fastdeploy/engine/sampling_params.py

Lines changed: 10 additions & 0 deletions
@@ -173,6 +173,13 @@ def _verify_args(self) -> None:
                 f"temperature must be non-negative, got {self.temperature}.")
         if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
             raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
+        # quietly accept -1 as disabled, but prefer 0
+        if self.top_k < -1:
+            raise ValueError(f"top_k must be 0 (disable), or at least 1, "
+                             f"got {self.top_k}.")
+        if not isinstance(self.top_k, int):
+            raise TypeError(
+                f"top_k must be an integer, got {type(self.top_k).__name__}")
 
         if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(
@@ -192,6 +199,9 @@ def _verify_args(self) -> None:
         if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
+        if self.logprobs is not None and self.logprobs > 20:
+            raise ValueError(
+                "Invalid value for 'top_logprobs': must be less than or equal to 20.")
 
         if not 0 <= self.seed <= 922337203685477580:
             raise ValueError("seed must be in [0, 922337203685477580], got "

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 37 additions & 0 deletions
@@ -122,6 +122,7 @@ class ChatCompletionResponseChoice(BaseModel):
     """
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
 
 
@@ -136,6 +137,21 @@ class ChatCompletionResponse(BaseModel):
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
 
+class LogProbEntry(BaseModel):
+    """
+    Log probability entry.
+    """
+    token: str
+    logprob: float
+    bytes: Optional[List[int]] = None
+    top_logprobs: Optional[List["LogProbEntry"]] = None
+
+class LogProbs(BaseModel):
+    """
+    LogProbs.
+    """
+    content: Optional[List[LogProbEntry]] = None
+    refusal: Optional[Union[str, None]] = None
 
 class DeltaMessage(BaseModel):
     """
@@ -154,6 +170,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     """
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     arrival_time: Optional[float] = None
 
@@ -392,6 +409,8 @@ class ChatCompletionRequest(BaseModel):
     tools: Optional[List[ChatCompletionToolsParam]] = None
     model: Optional[str] = "default"
     frequency_penalty: Optional[float] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = 0
     # remove max_tokens when field is removed from OpenAI API
     max_tokens: Optional[int] = Field(
         default=None,
@@ -434,6 +453,9 @@ def to_dict_for_infer(self, request_id=None):
         if request_id is not None:
             req_dict['request_id'] = request_id
 
+        req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
+        req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
+
         if self.metadata is not None:
             for key, value in self.metadata.items():
                 req_dict[key] = value
@@ -505,3 +527,18 @@ def validate_stream_options(cls, data):
             )
 
         return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+
+        if (top_logprobs := data.get("top_logprobs")) is not None:
+            if top_logprobs < 0:
+                raise ValueError("`top_logprobs` must be a positive value.")
+
+            if top_logprobs > 0 and not data.get("logprobs"):
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+
+        return data
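
End to end, a client can now ask the OpenAI-compatible chat endpoint for log probabilities. A hedged example using the requests library (host, port, and model name are placeholders for a local deployment started with --enable-logprob; the response shape follows LogProbEntry/LogProbs above):

import requests

# Hypothetical local FastDeploy endpoint; adjust the URL and model to your deployment.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,    # must be true whenever top_logprobs > 0 (see check_logprobs)
        "top_logprobs": 5,   # capped at 20 by the sampling-params check
    },
)
choice = resp.json()["choices"][0]
# Each entry mirrors LogProbEntry: token, logprob, optional bytes, nested top_logprobs.
for entry in choice["logprobs"]["content"]:
    print(entry["token"], entry["logprob"])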
