Skip to content

Commit 036bd79

Browse files
feat(chat): add logprob and topLogprobs (#328)
1 parent 2d870f9 commit 036bd79

File tree

7 files changed

+149
-5
lines changed

7 files changed

+149
-5
lines changed

openai-client/src/commonTest/kotlin/com/aallam/openai/client/TestChatCompletions.kt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,43 @@ class TestChatCompletions : TestOpenAI() {
126126
assertNotNull(answer.question)
127127
assertNotNull(answer.response)
128128
}
129+
130+
@Test
131+
fun logprobs() = test {
132+
val request = chatCompletionRequest {
133+
model = ModelId("gpt-3.5-turbo-0125")
134+
messages {
135+
message {
136+
role = ChatRole.User
137+
content = "What's the weather like in Boston?"
138+
}
139+
}
140+
logprobs = true
141+
}
142+
val response = openAI.chatCompletion(request)
143+
val logprobs = response.choices.first().logprobs
144+
assertNotNull(logprobs)
145+
assertEquals(response.usage!!.completionTokens, logprobs.content!!.size)
146+
}
147+
148+
@Test
149+
fun top_logprobs() = test {
150+
val expectedTopLogProbs = 5
151+
val request = chatCompletionRequest {
152+
model = ModelId("gpt-3.5-turbo-0125")
153+
messages {
154+
message {
155+
role = ChatRole.User
156+
content = "What's the weather like in Boston?"
157+
}
158+
}
159+
logprobs = true
160+
topLogprobs = expectedTopLogProbs
161+
}
162+
val response = openAI.chatCompletion(request)
163+
val logprobs = response.choices.first().logprobs
164+
assertNotNull(logprobs)
165+
assertEquals(response.usage!!.completionTokens, logprobs.content!!.size)
166+
assertEquals(logprobs.content!![0].topLogprobs?.size, expectedTopLogProbs)
167+
}
129168
}

openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatChoice.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
package com.aallam.openai.api.chat;
1+
package com.aallam.openai.api.chat
22

3-
import com.aallam.openai.api.BetaOpenAI
43
import com.aallam.openai.api.core.FinishReason
54
import kotlinx.serialization.SerialName
65
import kotlinx.serialization.Serializable
@@ -20,9 +19,12 @@ public data class ChatChoice(
2019
* The generated chat message.
2120
*/
2221
@SerialName("message") public val message: ChatMessage,
23-
2422
/**
2523
* The reason why OpenAI stopped generating.
2624
*/
2725
@SerialName("finish_reason") public val finishReason: FinishReason? = null,
26+
/**
27+
* Log probability information for the choice.
28+
*/
29+
@SerialName("logprobs") public val logprobs: Logprobs? = null,
2830
)

openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletion.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package com.aallam.openai.api.chat
22

3-
import com.aallam.openai.api.BetaOpenAI
43
import com.aallam.openai.api.core.Usage
54
import com.aallam.openai.api.model.ModelId
65
import kotlinx.serialization.SerialName

openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletionRequest.kt

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,18 @@ public class ChatCompletionRequest(
146146
*/
147147
@property:BetaOpenAI
148148
@SerialName("seed") public val seed: Int? = null,
149+
150+
/**
151+
* Whether to return log probabilities of the output tokens or not. If true,
152+
* returns the log probabilities of each output token returned in the content of message.
153+
*/
154+
@SerialName("logprobs") public val logprobs: Boolean? = null,
155+
156+
/**
157+
* An integer between 0 and 20 specifying the number of most likely tokens to return at each token position,
158+
* each with an associated log probability. logprobs must be set to true if this parameter is used.
159+
*/
160+
@SerialName("top_logprobs") public val topLogprobs: Int? = null,
149161
)
150162

151163
/**
@@ -282,6 +294,18 @@ public class ChatCompletionRequestBuilder {
282294
*/
283295
public var toolChoice: ToolChoice? = null
284296

297+
/**
298+
* Whether to return log probabilities of the output tokens or not. If true,
299+
* returns the log probabilities of each output token returned in the content of message.
300+
*/
301+
public var logprobs: Boolean? = null
302+
303+
/**
304+
* An integer between 0 and 20 specifying the number of most likely tokens to return at each token position,
305+
* each with an associated log probability. logprobs must be set to true if this parameter is used.
306+
*/
307+
public var topLogprobs: Int? = null
308+
285309
/**
286310
* The messages to generate chat completions for.
287311
*/
@@ -323,7 +347,9 @@ public class ChatCompletionRequestBuilder {
323347
functionCall = functionCall,
324348
responseFormat = responseFormat,
325349
toolChoice = toolChoice,
326-
tools = tools
350+
tools = tools,
351+
logprobs = logprobs,
352+
topLogprobs = topLogprobs
327353
)
328354
}
329355

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.aallam.openai.api.chat
2+
3+
import kotlinx.serialization.SerialName
4+
import kotlinx.serialization.Serializable
5+
6+
/**
7+
* An object containing log probability information for the choice.
8+
*
9+
* [documentation](https://platform.openai.com/docs/api-reference/chat/object)
10+
*/
11+
@Serializable
12+
public data class Logprobs(
13+
/**
14+
* A list of message content tokens with log probability information.
15+
*/
16+
@SerialName("content") public val content: List<LogprobsContent>? = null,
17+
)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package com.aallam.openai.api.chat
2+
3+
import kotlinx.serialization.SerialName
4+
import kotlinx.serialization.Serializable
5+
6+
/**
7+
* An object containing logprobs for a single token
8+
*
9+
* [documentation](https://platform.openai.com/docs/api-reference/chat/object)
10+
*/
11+
@Serializable
12+
public data class LogprobsContent(
13+
/**
14+
* The token.
15+
*/
16+
@SerialName("token") public val token: String,
17+
/**
18+
* The log probability of this token, if it is within the top 20 most likely tokens.
19+
* Otherwise, the value -9999.0 is used to signify that the token is very unlikely.
20+
*/
21+
@SerialName("logprob") public val logprob: Double,
22+
/**
23+
* A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where
24+
* characters are represented by multiple tokens and their byte representations must be combined to generate
25+
* the correct text representation. Can be `null` if there is no bytes representation for the token.
26+
*/
27+
@SerialName("bytes") public val bytes: List<Int>? = null,
28+
/**
29+
* List of the most likely tokens and their log probability, at this token position.
30+
* In rare cases, there may be fewer than the number of requested top_logprobs returned.
31+
*/
32+
@SerialName("top_logprobs") public val topLogprobs: List<TopLogprob>,
33+
)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.aallam.openai.api.chat
2+
3+
import kotlinx.serialization.SerialName
4+
import kotlinx.serialization.Serializable
5+
6+
/**
7+
* An object containing a token and their log probability.
8+
*
9+
* [documentation](https://platform.openai.com/docs/api-reference/chat/object)
10+
*/
11+
@Serializable
12+
public data class TopLogprob(
13+
/**
14+
* The token
15+
*/
16+
@SerialName("token") public val token: String,
17+
/**
18+
* The log probability of this token, if it is within the top 20 most likely tokens.
19+
* Otherwise, the value `-9999.0` is used to signify that the token is very unlikely.
20+
*/
21+
@SerialName("logprob") public val logprob: Double,
22+
/**
23+
* A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where
24+
* characters are represented by multiple tokens and their byte representations must be combined to generate
25+
* the correct text representation. Can be `null` if there is no bytes representation for the token.
26+
*/
27+
@SerialName("bytes") public val bytes: List<Int>? = null,
28+
)

0 commit comments

Comments
 (0)