15
15
*/
16
16
package org .springframework .ai .openai .chat ;
17
17
18
- import com .fasterxml .jackson .core .JacksonException ;
19
- import com .fasterxml .jackson .core .JsonProcessingException ;
20
- import com .fasterxml .jackson .databind .DeserializationFeature ;
21
- import com .fasterxml .jackson .databind .JsonMappingException ;
22
- import com .fasterxml .jackson .databind .ObjectMapper ;
18
+ import static org .assertj .core .api .Assertions .assertThat ;
19
+
23
20
import org .junit .jupiter .api .Test ;
24
21
import org .junit .jupiter .api .condition .EnabledIfEnvironmentVariable ;
25
22
import org .slf4j .Logger ;
26
23
import org .slf4j .LoggerFactory ;
27
-
28
24
import org .springframework .ai .chat .model .ChatResponse ;
29
25
import org .springframework .ai .chat .prompt .Prompt ;
26
+ import org .springframework .ai .converter .BeanOutputConverter ;
30
27
import org .springframework .ai .openai .OpenAiChatModel ;
31
28
import org .springframework .ai .openai .OpenAiChatOptions ;
32
29
import org .springframework .ai .openai .api .OpenAiApi ;
33
- import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionRequest ;
30
+ import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionRequest .ResponseFormat ;
31
+ import org .springframework .ai .openai .api .OpenAiApi .ChatModel ;
34
32
import org .springframework .beans .factory .annotation .Autowired ;
35
33
import org .springframework .boot .SpringBootConfiguration ;
36
34
import org .springframework .boot .test .context .SpringBootTest ;
37
35
import org .springframework .context .annotation .Bean ;
38
36
39
- import static org .assertj .core .api .Assertions .assertThat ;
37
+ import com .fasterxml .jackson .annotation .JsonProperty ;
38
+ import com .fasterxml .jackson .core .JacksonException ;
39
+ import com .fasterxml .jackson .core .JsonProcessingException ;
40
+ import com .fasterxml .jackson .databind .DeserializationFeature ;
41
+ import com .fasterxml .jackson .databind .JsonMappingException ;
42
+ import com .fasterxml .jackson .databind .ObjectMapper ;
40
43
41
44
/**
42
45
* @author Christian Tzolov
43
46
*/
44
- @ SpringBootTest (classes = OpenAiChatModel2IT .Config .class )
47
+ @ SpringBootTest (classes = OpenAiChatModelResponseFormatIT .Config .class )
45
48
@ EnabledIfEnvironmentVariable (named = "OPENAI_API_KEY" , matches = ".+" )
46
- public class OpenAiChatModel2IT {
49
+ public class OpenAiChatModelResponseFormatIT {
47
50
48
51
private final Logger logger = LoggerFactory .getLogger (getClass ());
49
52
50
53
@ Autowired
51
54
private OpenAiChatModel openAiChatModel ;
52
55
53
56
@ Test
54
- void responseFormatTest () throws JsonMappingException , JsonProcessingException {
57
+ void jsonObject () throws JsonMappingException , JsonProcessingException {
55
58
56
59
// 400 - ResponseError[error=Error[message='json' is not one of ['json_object',
57
60
// 'text'] -
@@ -64,7 +67,50 @@ void responseFormatTest() throws JsonMappingException, JsonProcessingException {
64
67
65
68
Prompt prompt = new Prompt ("List 8 planets. Use JSON response" ,
66
69
OpenAiChatOptions .builder ()
67
- .withResponseFormat (new ChatCompletionRequest .ResponseFormat ("json_object" ))
70
+ .withResponseFormat (new ResponseFormat (ResponseFormat .Type .JSON_OBJECT ))
71
+ .build ());
72
+
73
+ ChatResponse response = this .openAiChatModel .call (prompt );
74
+
75
+ assertThat (response ).isNotNull ();
76
+
77
+ String content = response .getResult ().getOutput ().getContent ();
78
+
79
+ logger .info ("Response content: {}" , content );
80
+
81
+ assertThat (isValidJson (content )).isTrue ();
82
+ }
83
+
84
+ @ Test
85
+ void jsonSchema () throws JsonMappingException , JsonProcessingException {
86
+
87
+ var jsonSchema = """
88
+ {
89
+ "type": "object",
90
+ "properties": {
91
+ "steps": {
92
+ "type": "array",
93
+ "items": {
94
+ "type": "object",
95
+ "properties": {
96
+ "explanation": { "type": "string" },
97
+ "output": { "type": "string" }
98
+ },
99
+ "required": ["explanation", "output"],
100
+ "additionalProperties": false
101
+ }
102
+ },
103
+ "final_answer": { "type": "string" }
104
+ },
105
+ "required": ["steps", "final_answer"],
106
+ "additionalProperties": false
107
+ }
108
+ """ ;
109
+
110
+ Prompt prompt = new Prompt ("how can I solve 8x + 7 = -23" ,
111
+ OpenAiChatOptions .builder ()
112
+ .withModel (ChatModel .GPT_4_O_MINI )
113
+ .withResponseFormat (new ResponseFormat (ResponseFormat .Type .JSON_SCHEMA , jsonSchema ))
68
114
.build ());
69
115
70
116
ChatResponse response = this .openAiChatModel .call (prompt );
@@ -78,6 +124,47 @@ void responseFormatTest() throws JsonMappingException, JsonProcessingException {
78
124
assertThat (isValidJson (content )).isTrue ();
79
125
}
80
126
127
+ @ Test
128
+ void jsonSchemaBeanConverter () throws JsonMappingException , JsonProcessingException {
129
+
130
+ record MathReasoning (@ JsonProperty (required = true , value = "steps" ) Steps steps ,
131
+ @ JsonProperty (required = true , value = "final_answer" ) String finalAnswer ) {
132
+
133
+ record Steps (@ JsonProperty (required = true , value = "items" ) Items [] items ) {
134
+
135
+ record Items (@ JsonProperty (required = true , value = "explanation" ) String explanation ,
136
+ @ JsonProperty (required = true , value = "output" ) String output ) {
137
+ }
138
+ }
139
+ }
140
+
141
+ var outputConverter = new BeanOutputConverter <>(MathReasoning .class );
142
+
143
+ var jsonSchema1 = outputConverter .getJsonSchema ();
144
+
145
+ System .out .println (jsonSchema1 );
146
+
147
+ Prompt prompt = new Prompt ("how can I solve 8x + 7 = -23" ,
148
+ OpenAiChatOptions .builder ()
149
+ .withModel (ChatModel .GPT_4_O_MINI )
150
+ .withResponseFormat (new ResponseFormat (ResponseFormat .Type .JSON_SCHEMA , jsonSchema1 ))
151
+ .build ());
152
+
153
+ ChatResponse response = this .openAiChatModel .call (prompt );
154
+
155
+ assertThat (response ).isNotNull ();
156
+
157
+ String content = response .getResult ().getOutput ().getContent ();
158
+
159
+ logger .info ("Response content: {}" , content );
160
+
161
+ MathReasoning mathReasoning = outputConverter .convert (content );
162
+
163
+ System .out .println (mathReasoning );
164
+
165
+ assertThat (isValidJson (content )).isTrue ();
166
+ }
167
+
81
168
private static ObjectMapper MAPPER = new ObjectMapper ().enable (DeserializationFeature .FAIL_ON_TRAILING_TOKENS );
82
169
83
170
public static boolean isValidJson (String json ) {
0 commit comments