Skip to content

Commit 1f4cce1

Browse files
authored
feat(chat): allow specifying JSON schema for chat completions (#398)
1 parent 593e327 commit 1f4cce1

File tree

3 files changed

+99
-2
lines changed

3 files changed

+99
-2
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Unreleased
2+
3+
### Added
4+
- **chat**: Add support for structured outputs (#397)
5+
16
## 4.0.0-beta01
27
> Published 27 Oct 2024
38

openai-client/src/commonTest/kotlin/com/aallam/openai/client/TestChatCompletions.kt

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.aallam.openai.client
22

33
import com.aallam.openai.api.chat.*
4+
import com.aallam.openai.api.chat.ChatResponseFormat.Companion.jsonSchema
45
import com.aallam.openai.api.model.ModelId
56
import kotlinx.coroutines.flow.collect
67
import kotlinx.coroutines.flow.launchIn
@@ -9,6 +10,9 @@ import kotlinx.coroutines.launch
910
import kotlinx.coroutines.test.advanceTimeBy
1011
import kotlinx.serialization.Serializable
1112
import kotlinx.serialization.json.Json
13+
import kotlinx.serialization.json.JsonArray
14+
import kotlinx.serialization.json.JsonObject
15+
import kotlinx.serialization.json.JsonPrimitive
1216
import kotlin.coroutines.cancellation.CancellationException
1317
import kotlin.test.*
1418

@@ -131,6 +135,62 @@ class TestChatCompletions : TestOpenAI() {
131135
assertNotNull(answer.response)
132136
}
133137

138+
@Test
139+
fun jsonSchema() = test {
140+
val schemaJson = JsonObject(mapOf(
141+
"type" to JsonPrimitive("object"),
142+
"properties" to JsonObject(mapOf(
143+
"question" to JsonObject(mapOf(
144+
"type" to JsonPrimitive("string"),
145+
"description" to JsonPrimitive("The question that was asked")
146+
)),
147+
"response" to JsonObject(mapOf(
148+
"type" to JsonPrimitive("string"),
149+
"description" to JsonPrimitive("The answer to the question")
150+
))
151+
)),
152+
"required" to JsonArray(listOf(
153+
JsonPrimitive("question"),
154+
JsonPrimitive("response")
155+
))
156+
))
157+
158+
val jsonSchema = JsonSchema(
159+
name = "AnswerSchema",
160+
schema = schemaJson,
161+
strict = true
162+
)
163+
164+
val request = chatCompletionRequest {
165+
model = ModelId("gpt-4o-mini-2024-07-18")
166+
responseFormat = jsonSchema(jsonSchema)
167+
messages {
168+
message {
169+
role = ChatRole.System
170+
content = "You are a helpful assistant.!"
171+
}
172+
message {
173+
role = ChatRole.System
174+
content = """All your answers should be a valid JSON
175+
""".trimMargin()
176+
}
177+
message {
178+
role = ChatRole.User
179+
content = "Who won the world cup in 1998?"
180+
}
181+
}
182+
}
183+
val response = openAI.chatCompletion(request)
184+
val content = response.choices.first().message.content.orEmpty()
185+
186+
@Serializable
187+
data class Answer(val question: String? = null, val response: String? = null)
188+
189+
val answer = Json.decodeFromString<Answer>(content)
190+
assertNotNull(answer.question)
191+
assertNotNull(answer.response)
192+
}
193+
134194
@Test
135195
fun logprobs() = test {
136196
val request = chatCompletionRequest {

openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatResponseFormat.kt

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package com.aallam.openai.api.chat
22

33
import kotlinx.serialization.SerialName
44
import kotlinx.serialization.Serializable
5+
import kotlinx.serialization.json.JsonObject
56

67
/**
78
* An object specifying the format that the model must output.
@@ -11,9 +12,13 @@ public data class ChatResponseFormat(
1112
/**
1213
* Response format type.
1314
*/
14-
@SerialName("type") val type: String
15-
) {
15+
@SerialName("type") val type: String,
1616

17+
/**
18+
* Optional JSON schema specification when type is "json_schema"
19+
*/
20+
@SerialName("json_schema") val jsonSchema: JsonSchema? = null
21+
) {
1722
public companion object {
1823
/**
1924
* JSON mode, which guarantees the message the model generates, is valid JSON.
@@ -24,5 +29,32 @@ public data class ChatResponseFormat(
2429
* Default text mode.
2530
*/
2631
public val Text: ChatResponseFormat = ChatResponseFormat(type = "text")
32+
33+
/**
34+
* Creates a JSON schema response format with the specified schema
35+
*/
36+
public fun jsonSchema(schema: JsonSchema): ChatResponseFormat =
37+
ChatResponseFormat(type = "json_schema", jsonSchema = schema)
2738
}
2839
}
40+
41+
/**
42+
* Specification for JSON schema response format
43+
*/
44+
@Serializable
45+
public data class JsonSchema(
46+
/**
47+
* Optional name for the schema
48+
*/
49+
@SerialName("name") val name: String? = null,
50+
51+
/**
52+
* The JSON schema specification
53+
*/
54+
@SerialName("schema") val schema: JsonObject,
55+
56+
/**
57+
* Whether to enforce strict schema validation
58+
*/
59+
@SerialName("strict") val strict: Boolean = true
60+
)

0 commit comments

Comments
 (0)