16
16
package org .springframework .ai .openai ;
17
17
18
18
import java .time .Duration ;
19
+ import java .util .ArrayList ;
20
+ import java .util .HashMap ;
21
+ import java .util .HashSet ;
19
22
import java .util .List ;
20
23
import java .util .Map ;
24
+ import java .util .Set ;
21
25
import java .util .concurrent .ConcurrentHashMap ;
22
26
23
27
import org .slf4j .Logger ;
33
37
import org .springframework .ai .chat .metadata .RateLimit ;
34
38
import org .springframework .ai .chat .prompt .Prompt ;
35
39
import org .springframework .ai .model .ModelOptionsUtils ;
40
+ import org .springframework .ai .model .ToolFunctionCallback ;
36
41
import org .springframework .ai .openai .api .OpenAiApi ;
37
42
import org .springframework .ai .openai .api .OpenAiApi .ChatCompletion ;
38
43
import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage ;
44
+ import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage .Role ;
45
+ import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage .ToolCall ;
39
46
import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionRequest ;
40
47
import org .springframework .ai .openai .api .OpenAiApi .OpenAiApiException ;
41
48
import org .springframework .ai .openai .metadata .OpenAiChatResponseMetadata ;
46
53
import org .springframework .retry .RetryListener ;
47
54
import org .springframework .retry .support .RetryTemplate ;
48
55
import org .springframework .util .Assert ;
56
+ import org .springframework .util .CollectionUtils ;
49
57
50
58
/**
51
59
* {@link ChatClient} implementation for {@literal OpenAI} backed by {@link OpenAiApi}.
@@ -66,11 +74,14 @@ public class OpenAiChatClient implements ChatClient, StreamingChatClient {
66
74
67
75
private OpenAiChatOptions defaultOptions ;
68
76
77
+ private Map <String , ToolFunctionCallback > toolCallbackRegister = new ConcurrentHashMap <>();
78
+
69
79
public final RetryTemplate retryTemplate = RetryTemplate .builder ()
70
80
.maxAttempts (10 )
71
81
.retryOn (OpenAiApiException .class )
72
82
.exponentialBackoff (Duration .ofMillis (2000 ), 5 , Duration .ofMillis (3 * 60000 ))
73
83
.withListener (new RetryListener () {
84
+ @ Override
74
85
public <T extends Object , E extends Throwable > void onError (RetryContext context ,
75
86
RetryCallback <T , E > callback , Throwable throwable ) {
76
87
logger .warn ("Retry error. Retry count:" + context .getRetryCount (), throwable );
@@ -108,18 +119,18 @@ public ChatResponse call(Prompt prompt) {
108
119
109
120
ChatCompletionRequest request = createRequest (prompt , false );
110
121
111
- ResponseEntity <ChatCompletion > completionEntity = this .openAiApi . chatCompletionEntity (request );
122
+ ResponseEntity <ChatCompletion > completionEntity = this .chatCompletionWithTools (request );
112
123
113
124
var chatCompletion = completionEntity .getBody ();
114
125
if (chatCompletion == null ) {
115
- logger .warn ("No chat completion returned for request : {}" , prompt );
126
+ logger .warn ("No chat completion returned for prompt : {}" , prompt );
116
127
return new ChatResponse (List .of ());
117
128
}
118
129
119
130
RateLimit rateLimits = OpenAiResponseHeaderExtractor .extractAiResponseHeaders (completionEntity );
120
131
121
132
List <Generation > generations = chatCompletion .choices ().stream ().map (choice -> {
122
- return new Generation (choice .message ().content (), Map . of ( "role" , choice .message (). role (). name ()))
133
+ return new Generation (choice .message ().content (), toMap ( choice .message ()))
123
134
.withGenerationMetadata (ChatGenerationMetadata .from (choice .finishReason ().name (), null ));
124
135
}).toList ();
125
136
@@ -162,6 +173,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
162
173
*/
163
174
ChatCompletionRequest createRequest (Prompt prompt , boolean stream ) {
164
175
176
+ Set <String > enabledFunctionsForRequest = new HashSet <>();
177
+
165
178
List <ChatCompletionMessage > chatCompletionMessages = prompt .getInstructions ()
166
179
.stream ()
167
180
.map (m -> new ChatCompletionMessage (m .getContent (),
@@ -170,14 +183,15 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
170
183
171
184
ChatCompletionRequest request = new ChatCompletionRequest (chatCompletionMessages , stream );
172
185
173
- if (this .defaultOptions != null ) {
174
- request = ModelOptionsUtils .merge (request , this .defaultOptions , ChatCompletionRequest .class );
175
- }
176
-
177
186
if (prompt .getOptions () != null ) {
178
187
if (prompt .getOptions () instanceof ChatOptions runtimeOptions ) {
179
188
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils .copyToTarget (runtimeOptions ,
180
189
ChatOptions .class , OpenAiChatOptions .class );
190
+
191
+ Set <String > promptEnabledFunctions = handleToolFunctionConfigurations (updatedRuntimeOptions , true ,
192
+ true );
193
+ enabledFunctionsForRequest .addAll (promptEnabledFunctions );
194
+
181
195
request = ModelOptionsUtils .merge (updatedRuntimeOptions , request , ChatCompletionRequest .class );
182
196
}
183
197
else {
@@ -186,7 +200,180 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
186
200
}
187
201
}
188
202
203
+ if (this .defaultOptions != null ) {
204
+
205
+ Set <String > defaultEnabledFunctions = handleToolFunctionConfigurations (this .defaultOptions , false , false );
206
+
207
+ enabledFunctionsForRequest .addAll (defaultEnabledFunctions );
208
+
209
+ request = ModelOptionsUtils .merge (request , this .defaultOptions , ChatCompletionRequest .class );
210
+ }
211
+
212
+ // Add the enabled functions definitions to the request's tools parameter.
213
+ if (!CollectionUtils .isEmpty (enabledFunctionsForRequest )) {
214
+
215
+ if (stream ) {
216
+ throw new IllegalArgumentException ("Currently tool functions are not supported in streaming mode" );
217
+ }
218
+
219
+ request = ModelOptionsUtils .merge (
220
+ OpenAiChatOptions .builder ().withTools (this .getFunctionTools (enabledFunctionsForRequest )).build (),
221
+ request , ChatCompletionRequest .class );
222
+ }
223
+
189
224
return request ;
190
225
}
191
226
227
+ private Set <String > handleToolFunctionConfigurations (OpenAiChatOptions options , boolean autoEnableCallbackFunctions ,
228
+ boolean overrideCallbackFunctionsRegister ) {
229
+
230
+ Set <String > enabledFunctions = new HashSet <>();
231
+
232
+ if (options != null ) {
233
+ if (!CollectionUtils .isEmpty (options .getToolCallbacks ())) {
234
+ options .getToolCallbacks ().stream ().forEach (toolCallback -> {
235
+
236
+ // Register the tool callback.
237
+ if (overrideCallbackFunctionsRegister ) {
238
+ this .toolCallbackRegister .put (toolCallback .getName (), toolCallback );
239
+ }
240
+ else {
241
+ this .toolCallbackRegister .putIfAbsent (toolCallback .getName (), toolCallback );
242
+ }
243
+
244
+ // Automatically enable the function, usually from prompt callback.
245
+ if (autoEnableCallbackFunctions ) {
246
+ enabledFunctions .add (toolCallback .getName ());
247
+ }
248
+ });
249
+ }
250
+
251
+ // Add the explicitly enabled functions.
252
+ if (!CollectionUtils .isEmpty (options .getEnabledFunctions ())) {
253
+ enabledFunctions .addAll (options .getEnabledFunctions ());
254
+ }
255
+ }
256
+
257
+ return enabledFunctions ;
258
+ }
259
+
260
+ /**
261
+ * @return returns the registered tool callbacks.
262
+ */
263
+ Map <String , ToolFunctionCallback > getToolCallbackRegister () {
264
+ return toolCallbackRegister ;
265
+ }
266
+
267
+ public List <OpenAiApi .FunctionTool > getFunctionTools (Set <String > functionNames ) {
268
+
269
+ List <OpenAiApi .FunctionTool > functionTools = new ArrayList <>();
270
+ for (String functionName : functionNames ) {
271
+ if (!this .toolCallbackRegister .containsKey (functionName )) {
272
+ throw new IllegalStateException ("No function callback found for function name: " + functionName );
273
+ }
274
+ ToolFunctionCallback functionCallback = this .toolCallbackRegister .get (functionName );
275
+
276
+ var function = new OpenAiApi .FunctionTool .Function (functionCallback .getDescription (),
277
+ functionCallback .getName (), functionCallback .getInputTypeSchema ());
278
+ functionTools .add (new OpenAiApi .FunctionTool (function ));
279
+ }
280
+
281
+ return functionTools ;
282
+ }
283
+
284
+ /**
285
+ * Function Call handling. If the model calls a function, the function is called and
286
+ * the response is added to the conversation history. The conversation history is then
287
+ * sent back to the model.
288
+ * @param request the chat completion request
289
+ * @return the chat completion response.
290
+ */
291
+ @ SuppressWarnings ("null" )
292
+ private ResponseEntity <ChatCompletion > chatCompletionWithTools (OpenAiApi .ChatCompletionRequest request ) {
293
+
294
+ ResponseEntity <ChatCompletion > chatCompletion = this .openAiApi .chatCompletionEntity (request );
295
+
296
+ // Return the result if the model is not calling a function.
297
+ if (Boolean .FALSE .equals (this .isToolCall (chatCompletion ))) {
298
+ return chatCompletion ;
299
+ }
300
+
301
+ // The OpenAI chat completion tool call API requires the complete conversation
302
+ // history. Including the initial user message.
303
+ List <ChatCompletionMessage > conversationMessages = new ArrayList <>(request .messages ());
304
+
305
+ // We assume that the tool calling information is inside the response's first
306
+ // choice.
307
+ ChatCompletionMessage responseMessage = chatCompletion .getBody ().choices ().iterator ().next ().message ();
308
+
309
+ if (chatCompletion .getBody ().choices ().size () > 1 ) {
310
+ logger .warn ("More than one choice returned. Only the first choice is processed." );
311
+ }
312
+
313
+ // Add the assistant response to the message conversation history.
314
+ conversationMessages .add (responseMessage );
315
+
316
+ // Every tool-call item requires a separate function call and a response (TOOL)
317
+ // message.
318
+ for (ToolCall toolCall : responseMessage .toolCalls ()) {
319
+
320
+ var functionName = toolCall .function ().name ();
321
+ String functionArguments = toolCall .function ().arguments ();
322
+
323
+ if (!this .toolCallbackRegister .containsKey (functionName )) {
324
+ throw new IllegalStateException ("No function callback found for function name: " + functionName );
325
+ }
326
+
327
+ String functionResponse = this .toolCallbackRegister .get (functionName ).call (functionArguments );
328
+
329
+ // Add the function response to the conversation.
330
+ conversationMessages .add (new ChatCompletionMessage (functionResponse , Role .TOOL , null , toolCall .id (), null ));
331
+ }
332
+
333
+ // Recursively call chatCompletionWithTools until the model doesn't call a
334
+ // functions anymore.
335
+ ChatCompletionRequest newRequest = new ChatCompletionRequest (conversationMessages , request .stream ());
336
+ newRequest = ModelOptionsUtils .merge (newRequest , request , ChatCompletionRequest .class );
337
+
338
+ return this .chatCompletionWithTools (newRequest );
339
+ }
340
+
341
+ private Map <String , Object > toMap (ChatCompletionMessage message ) {
342
+ Map <String , Object > map = new HashMap <>();
343
+
344
+ // The tool_calls and tool_call_id are not used by the OpenAiChatClient functions
345
+ // call support! Useful only for users that want to use the tool_calls and
346
+ // tool_call_id in their applications.
347
+ if (message .toolCalls () != null ) {
348
+ map .put ("tool_calls" , message .toolCalls ());
349
+ }
350
+ if (message .toolCallId () != null ) {
351
+ map .put ("tool_call_id" , message .toolCallId ());
352
+ }
353
+
354
+ if (message .role () != null ) {
355
+ map .put ("role" , message .role ().name ());
356
+ }
357
+ return map ;
358
+ }
359
+
360
+ /**
361
+ * Check if it is a model calls function response.
362
+ * @param chatCompletion the chat completion response.
363
+ * @return true if the model expects a function call.
364
+ */
365
+ private Boolean isToolCall (ResponseEntity <ChatCompletion > chatCompletion ) {
366
+ var body = chatCompletion .getBody ();
367
+ if (body == null ) {
368
+ return false ;
369
+ }
370
+
371
+ var choices = body .choices ();
372
+ if (CollectionUtils .isEmpty (choices )) {
373
+ return false ;
374
+ }
375
+
376
+ return choices .get (0 ).message ().toolCalls () != null ;
377
+ }
378
+
192
379
}
0 commit comments