
Commit aa9f1ae

feat: Add logprobs support to chat completions (#1311)
* Add logprobs return in ChatCompletionResponse
* Fix duplicate field
* Set default to false
* Simplify check
* Add server example

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
Parent: 1e60dba · Commit: aa9f1ae

5 files changed (+28, -1 lines changed)

llama_cpp/llama.py

Lines changed: 1 addition & 0 deletions
@@ -1653,6 +1653,7 @@ def create_chat_completion(
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
             stop=stop,
             seed=seed,
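The change above forwards the new chat-level options into the underlying completion call: top_logprobs is passed through as the completion's logprobs count whenever logprobs is set. A minimal usage sketch (not part of this commit; it assumes create_chat_completion() exposes logprobs and top_logprobs parameters as wired here, and the model path is a placeholder):

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/model.gguf")  # placeholder path

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    logprobs=True,    # opt in; without it the field stays None
    top_logprobs=10,  # forwarded as `logprobs` to the underlying completion
)

# Each choice now carries "logprobs" next to "message" and "finish_reason".
print(response["choices"][0]["logprobs"])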

llama_cpp/llama_chat_format.py

Lines changed: 5 additions & 0 deletions
@@ -231,6 +231,7 @@ def _convert_text_completion_to_chat(
                     "role": "assistant",
                     "content": completion["choices"][0]["text"],
                 },
+                "logprobs": completion["choices"][0]["logprobs"],
                 "finish_reason": completion["choices"][0]["finish_reason"],
             }
         ],
@@ -254,6 +255,7 @@ def _convert_text_completion_chunks_to_chat(
                         "delta": {
                             "role": "assistant",
                         },
+                        "logprobs": None,
                         "finish_reason": None,
                     }
                 ],
@@ -273,6 +275,7 @@ def _convert_text_completion_chunks_to_chat(
                         if chunk["choices"][0]["finish_reason"] is None
                         else {}
                     ),
+                    "logprobs": chunk["choices"][0]["logprobs"],
                     "finish_reason": chunk["choices"][0]["finish_reason"],
                 }
             ],
@@ -487,6 +490,7 @@ def chat_completion_handler(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        logprobs: int = 0,
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
@@ -576,6 +580,7 @@ def chat_completion_handler(
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
+            logprobs=logprobs,
             stream=stream,
             stop=stop,
             seed=seed,
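With the converter changes above, the chat-format choices keep the underlying completion's logprobs instead of dropping them (the first streamed chunk has no text yet, so its logprobs are explicitly None). An illustrative sketch of the resulting non-streaming choice, with made-up values and assuming the usual CompletionLogprobs layout (tokens, token_logprobs, top_logprobs, text_offset):

# Illustrative shape only; values are invented.
choice = {
    "index": 0,
    "message": {
        "role": "assistant",
        "content": "Paris.",
    },
    # Copied verbatim from completion["choices"][0]["logprobs"]
    "logprobs": {
        "tokens": ["Paris", "."],
        "token_logprobs": [-0.12, -0.03],
        "top_logprobs": [{"Paris": -0.12, "The": -2.4}, {".": -0.03, "!": -3.1}],
        "text_offset": [0, 5],
    },
    "finish_reason": "stop",
}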

llama_cpp/llama_types.py

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ class ChatCompletionFunction(TypedDict):
 class ChatCompletionResponseChoice(TypedDict):
     index: int
     message: "ChatCompletionResponseMessage"
+    logprobs: Optional[CompletionLogprobs]
     finish_reason: Optional[str]
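Since the new field is typed Optional[CompletionLogprobs], consumers should expect None whenever logprobs were not requested. A small sketch of a defensive helper (the function itself is hypothetical, written only to illustrate handling both cases):

from typing import Optional

from llama_cpp.llama_types import ChatCompletionResponseChoice


def mean_token_logprob(choice: ChatCompletionResponseChoice) -> Optional[float]:
    # The field may be None, e.g. when the request did not set logprobs=True.
    lp = choice["logprobs"]
    if lp is None:
        return None
    values = [v for v in lp["token_logprobs"] if v is not None]
    return sum(values) / len(values) if values else None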

llama_cpp/server/app.py

Lines changed: 12 additions & 0 deletions
@@ -405,6 +405,18 @@ async def create_chat_completion(
                     }
                 },
             },
+            "logprobs": {
+                "summary": "Logprobs",
+                "value": {
+                    "model": "gpt-3.5-turbo",
+                    "messages": [
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": "What is the capital of France?"},
+                    ],
+                    "logprobs": True,
+                    "top_logprobs": 10
+                },
+            },
         }
     ),
     llama_proxy: LlamaProxy = Depends(get_llama_proxy),
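The OpenAPI example above documents the new request fields on the server's /v1/chat/completions endpoint. A sketch of the equivalent client call (assumes a llama-cpp-python server is already running on the default http://localhost:8000; the requests library is used here purely for illustration):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ],
        "logprobs": True,
        "top_logprobs": 10,
    },
)
resp.raise_for_status()
print(resp.json()["choices"][0]["logprobs"])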

llama_cpp/server/types.py

Lines changed: 9 additions & 1 deletion
@@ -130,7 +130,6 @@ class CreateCompletionRequest(BaseModel):
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
     logit_bias: Optional[Dict[str, float]] = Field(None)
-    logprobs: Optional[int] = Field(None)
     seed: Optional[int] = Field(None)
 
     # ignored or currently unsupported
@@ -209,6 +208,15 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="The maximum number of tokens to generate. Defaults to inf",
     )
+    logprobs: Optional[bool] = Field(
+        default=False,
+        description="Whether to output the logprobs or not. Default is True"
+    )
+    top_logprobs: Optional[int] = Field(
+        default=None,
+        ge=0,
+        description="The number of logprobs to generate. If None, no logprobs are generated. logprobs need to set to True.",
+    )
     temperature: float = temperature_field
     top_p: float = top_p_field
     min_p: float = min_p_field
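The first hunk removes the duplicate logprobs declaration from CreateCompletionRequest noted in the commit message; the second makes logprobs an explicit opt-in on chat requests and constrains top_logprobs to non-negative values. A quick sketch of how the request model behaves (assumes CreateChatCompletionRequest is importable from llama_cpp.server.types and that messages is the only field that must be supplied):

from llama_cpp.server.types import CreateChatCompletionRequest

req = CreateChatCompletionRequest(
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    logprobs=True,
    top_logprobs=10,
)
print(req.logprobs, req.top_logprobs)  # True 10

# top_logprobs carries ge=0, so a negative value fails validation.
try:
    CreateChatCompletionRequest(
        messages=[{"role": "user", "content": "hi"}],
        top_logprobs=-1,
    )
except Exception as exc:  # pydantic.ValidationError
    print("rejected:", type(exc).__name__)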
