Skip to content

Commit 91afed5

Browse files
tzolovmarkpollack
authored andcommitted
OpenAi: Add support for structured outputs and JSON schema
- Added support for OpenAI's structured outputs feature, which allows specifying a JSON schema for the model to match - Introduced new record to configure the desired response format - Added support for configuring the response format via application properties or the chat options builder - Extend teh BeanOutputConverter to help generate JSON schema from a target domain object and convert the response. - Added comprehensive tests to cover the new response format functionality Resolves #1196
1 parent 866b262 commit 91afed5

File tree

7 files changed

+634
-25
lines changed

7 files changed

+634
-25
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.util.CollectionUtils;
3333
import org.springframework.util.LinkedMultiValueMap;
3434
import org.springframework.util.MultiValueMap;
35+
import org.springframework.util.StringUtils;
3536
import org.springframework.web.client.ResponseErrorHandler;
3637
import org.springframework.web.client.RestClient;
3738
import org.springframework.web.reactive.function.client.WebClient;
@@ -521,7 +522,53 @@ public static Object FUNCTION(String functionName) {
521522
*/
522523
@JsonInclude(Include.NON_NULL)
523524
public record ResponseFormat(
524-
@JsonProperty("type") String type) {
525+
@JsonProperty("type") Type type,
526+
@JsonProperty("json_schema") JsonSchema jsonSchema ) {
527+
528+
public enum Type {
529+
/**
530+
* Enables JSON mode, which guarantees the message
531+
* the model generates is valid JSON.
532+
*/
533+
@JsonProperty("json_object")
534+
JSON_OBJECT,
535+
536+
/**
537+
* Enables Structured Outputs which guarantees the model
538+
* will match your supplied JSON schema.
539+
*/
540+
@JsonProperty("json_schema")
541+
JSON_SCHEMA
542+
}
543+
544+
@JsonInclude(Include.NON_NULL)
545+
public record JsonSchema(
546+
@JsonProperty("name") String name,
547+
@JsonProperty("schema") Map<String, Object> schema,
548+
@JsonProperty("strict") Boolean strict) {
549+
550+
public JsonSchema(String name, String schema) {
551+
this(name, ModelOptionsUtils.jsonToMap(schema), true);
552+
}
553+
554+
public JsonSchema(String name, String schema, Boolean strict) {
555+
this(StringUtils.hasText(name)? name : "custom_response_format_schema", ModelOptionsUtils.jsonToMap(schema), strict);
556+
}
557+
}
558+
559+
public ResponseFormat(Type type) {
560+
this(type, (JsonSchema) null);
561+
}
562+
563+
public ResponseFormat(Type type, String jsonSchena) {
564+
this(type, "custom_response_format_schema", jsonSchena, true);
565+
}
566+
567+
@ConstructorBinding
568+
public ResponseFormat(Type type, String name, String schema, Boolean strict) {
569+
this(type, StringUtils.hasText(schema)? new JsonSchema(name, schema, strict): null);
570+
}
571+
525572
}
526573

527574
/**

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModel2IT.java renamed to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java

Lines changed: 99 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,43 +15,46 @@
1515
*/
1616
package org.springframework.ai.openai.chat;
1717

18-
import com.fasterxml.jackson.core.JacksonException;
19-
import com.fasterxml.jackson.core.JsonProcessingException;
20-
import com.fasterxml.jackson.databind.DeserializationFeature;
21-
import com.fasterxml.jackson.databind.JsonMappingException;
22-
import com.fasterxml.jackson.databind.ObjectMapper;
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
2320
import org.junit.jupiter.api.Test;
2421
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2522
import org.slf4j.Logger;
2623
import org.slf4j.LoggerFactory;
27-
2824
import org.springframework.ai.chat.model.ChatResponse;
2925
import org.springframework.ai.chat.prompt.Prompt;
26+
import org.springframework.ai.converter.BeanOutputConverter;
3027
import org.springframework.ai.openai.OpenAiChatModel;
3128
import org.springframework.ai.openai.OpenAiChatOptions;
3229
import org.springframework.ai.openai.api.OpenAiApi;
33-
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
30+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
31+
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
3432
import org.springframework.beans.factory.annotation.Autowired;
3533
import org.springframework.boot.SpringBootConfiguration;
3634
import org.springframework.boot.test.context.SpringBootTest;
3735
import org.springframework.context.annotation.Bean;
3836

39-
import static org.assertj.core.api.Assertions.assertThat;
37+
import com.fasterxml.jackson.annotation.JsonProperty;
38+
import com.fasterxml.jackson.core.JacksonException;
39+
import com.fasterxml.jackson.core.JsonProcessingException;
40+
import com.fasterxml.jackson.databind.DeserializationFeature;
41+
import com.fasterxml.jackson.databind.JsonMappingException;
42+
import com.fasterxml.jackson.databind.ObjectMapper;
4043

4144
/**
4245
* @author Christian Tzolov
4346
*/
44-
@SpringBootTest(classes = OpenAiChatModel2IT.Config.class)
47+
@SpringBootTest(classes = OpenAiChatModelResponseFormatIT.Config.class)
4548
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
46-
public class OpenAiChatModel2IT {
49+
public class OpenAiChatModelResponseFormatIT {
4750

4851
private final Logger logger = LoggerFactory.getLogger(getClass());
4952

5053
@Autowired
5154
private OpenAiChatModel openAiChatModel;
5255

5356
@Test
54-
void responseFormatTest() throws JsonMappingException, JsonProcessingException {
57+
void jsonObject() throws JsonMappingException, JsonProcessingException {
5558

5659
// 400 - ResponseError[error=Error[message='json' is not one of ['json_object',
5760
// 'text'] -
@@ -64,7 +67,50 @@ void responseFormatTest() throws JsonMappingException, JsonProcessingException {
6467

6568
Prompt prompt = new Prompt("List 8 planets. Use JSON response",
6669
OpenAiChatOptions.builder()
67-
.withResponseFormat(new ChatCompletionRequest.ResponseFormat("json_object"))
70+
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_OBJECT))
71+
.build());
72+
73+
ChatResponse response = this.openAiChatModel.call(prompt);
74+
75+
assertThat(response).isNotNull();
76+
77+
String content = response.getResult().getOutput().getContent();
78+
79+
logger.info("Response content: {}", content);
80+
81+
assertThat(isValidJson(content)).isTrue();
82+
}
83+
84+
@Test
85+
void jsonSchema() throws JsonMappingException, JsonProcessingException {
86+
87+
var jsonSchema = """
88+
{
89+
"type": "object",
90+
"properties": {
91+
"steps": {
92+
"type": "array",
93+
"items": {
94+
"type": "object",
95+
"properties": {
96+
"explanation": { "type": "string" },
97+
"output": { "type": "string" }
98+
},
99+
"required": ["explanation", "output"],
100+
"additionalProperties": false
101+
}
102+
},
103+
"final_answer": { "type": "string" }
104+
},
105+
"required": ["steps", "final_answer"],
106+
"additionalProperties": false
107+
}
108+
""";
109+
110+
Prompt prompt = new Prompt("how can I solve 8x + 7 = -23",
111+
OpenAiChatOptions.builder()
112+
.withModel(ChatModel.GPT_4_O_MINI)
113+
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, jsonSchema))
68114
.build());
69115

70116
ChatResponse response = this.openAiChatModel.call(prompt);
@@ -78,6 +124,47 @@ void responseFormatTest() throws JsonMappingException, JsonProcessingException {
78124
assertThat(isValidJson(content)).isTrue();
79125
}
80126

127+
@Test
128+
void jsonSchemaBeanConverter() throws JsonMappingException, JsonProcessingException {
129+
130+
record MathReasoning(@JsonProperty(required = true, value = "steps") Steps steps,
131+
@JsonProperty(required = true, value = "final_answer") String finalAnswer) {
132+
133+
record Steps(@JsonProperty(required = true, value = "items") Items[] items) {
134+
135+
record Items(@JsonProperty(required = true, value = "explanation") String explanation,
136+
@JsonProperty(required = true, value = "output") String output) {
137+
}
138+
}
139+
}
140+
141+
var outputConverter = new BeanOutputConverter<>(MathReasoning.class);
142+
143+
var jsonSchema1 = outputConverter.getJsonSchema();
144+
145+
System.out.println(jsonSchema1);
146+
147+
Prompt prompt = new Prompt("how can I solve 8x + 7 = -23",
148+
OpenAiChatOptions.builder()
149+
.withModel(ChatModel.GPT_4_O_MINI)
150+
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, jsonSchema1))
151+
.build());
152+
153+
ChatResponse response = this.openAiChatModel.call(prompt);
154+
155+
assertThat(response).isNotNull();
156+
157+
String content = response.getResult().getOutput().getContent();
158+
159+
logger.info("Response content: {}", content);
160+
161+
MathReasoning mathReasoning = outputConverter.convert(content);
162+
163+
System.out.println(mathReasoning);
164+
165+
assertThat(isValidJson(content)).isTrue();
166+
}
167+
81168
private static ObjectMapper MAPPER = new ObjectMapper().enable(DeserializationFeature.FAIL_ON_TRAILING_TOKENS);
82169

83170
public static boolean isValidJson(String json) {

spring-ai-core/src/main/java/org/springframework/ai/converter/BeanOutputConverter.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@
3535
import com.fasterxml.jackson.databind.ObjectMapper;
3636
import com.fasterxml.jackson.databind.ObjectWriter;
3737
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
38+
import com.github.victools.jsonschema.generator.Option;
3839
import com.github.victools.jsonschema.generator.SchemaGenerator;
3940
import com.github.victools.jsonschema.generator.SchemaGeneratorConfig;
4041
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
4142
import com.github.victools.jsonschema.module.jackson.JacksonModule;
43+
import com.github.victools.jsonschema.module.jackson.JacksonOption;
4244

4345
/**
4446
* An implementation of {@link StructuredOutputConverter} that transforms the LLM output
@@ -140,9 +142,10 @@ private BeanOutputConverter(TypeReference<T> typeRef, ObjectMapper objectMapper)
140142
* Generates the JSON schema for the target type.
141143
*/
142144
private void generateSchema() {
143-
JacksonModule jacksonModule = new JacksonModule();
145+
JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED);
144146
SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(DRAFT_2020_12, PLAIN_JSON)
145-
.with(jacksonModule);
147+
.with(jacksonModule)
148+
.with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT);
146149
SchemaGeneratorConfig config = configBuilder.build();
147150
SchemaGenerator generator = new SchemaGenerator(config);
148151
JsonNode jsonNode = generator.generateSchema(this.typeRef.getType());
@@ -205,4 +208,12 @@ public String getFormat() {
205208
return String.format(template, this.jsonSchema);
206209
}
207210

211+
/**
212+
* Provides the generated JSON schema for the target type.
213+
* @return The generated JSON schema.
214+
*/
215+
public String getJsonSchema() {
216+
return this.jsonSchema;
217+
}
218+
208219
}

spring-ai-core/src/test/java/org/springframework/ai/converter/BeanOutputConverterTest.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ public void formatClassType() {
133133
"someString" : {
134134
"type" : "string"
135135
}
136-
}
136+
},
137+
"additionalProperties" : false
137138
}```
138139
""");
139140
}
@@ -156,7 +157,8 @@ public void formatTypeReference() {
156157
"someString" : {
157158
"type" : "string"
158159
}
159-
}
160+
},
161+
"additionalProperties" : false
160162
}```
161163
""");
162164
}
@@ -181,7 +183,8 @@ public void formatTypeReferenceArray() {
181183
"someString" : {
182184
"type" : "string"
183185
}
184-
}
186+
},
187+
"additionalProperties" : false
185188
}
186189
}```
187190
""");
@@ -199,7 +202,8 @@ public void formatClassTypeWithAnnotations() {
199202
"type" : "string",
200203
"description" : "string_property_description"
201204
}
202-
}
205+
},
206+
"additionalProperties" : false
203207
}```
204208
""");
205209
}
@@ -217,7 +221,8 @@ public void formatTypeReferenceWithAnnotations() {
217221
"type" : "string",
218222
"description" : "string_property_description"
219223
}
220-
}
224+
},
225+
"additionalProperties" : false
221226
}```
222227
""");
223228
}

0 commit comments

Comments
 (0)