Skip to content

Commit ffd8222

Browse files
tzolovmarkpollack
authored andcommitted
OpenAI ChatClient tools support
* Add function call OpenAiApi IT tests * Add function call docs. Unit and IT func tests * Add function call diagram * Allow to opt-in/enable the fuctions to be used in request.
1 parent 254b863 commit ffd8222

File tree

22 files changed

+1417
-36
lines changed

22 files changed

+1417
-36
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 194 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
package org.springframework.ai.openai;
1717

1818
import java.time.Duration;
19+
import java.util.ArrayList;
20+
import java.util.HashMap;
21+
import java.util.HashSet;
1922
import java.util.List;
2023
import java.util.Map;
24+
import java.util.Set;
2125
import java.util.concurrent.ConcurrentHashMap;
2226

2327
import org.slf4j.Logger;
@@ -33,9 +37,12 @@
3337
import org.springframework.ai.chat.metadata.RateLimit;
3438
import org.springframework.ai.chat.prompt.Prompt;
3539
import org.springframework.ai.model.ModelOptionsUtils;
40+
import org.springframework.ai.model.ToolFunctionCallback;
3641
import org.springframework.ai.openai.api.OpenAiApi;
3742
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
3843
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
44+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
45+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
3946
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
4047
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
4148
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
@@ -46,6 +53,7 @@
4653
import org.springframework.retry.RetryListener;
4754
import org.springframework.retry.support.RetryTemplate;
4855
import org.springframework.util.Assert;
56+
import org.springframework.util.CollectionUtils;
4957

5058
/**
5159
* {@link ChatClient} implementation for {@literal OpenAI} backed by {@link OpenAiApi}.
@@ -66,11 +74,14 @@ public class OpenAiChatClient implements ChatClient, StreamingChatClient {
6674

6775
private OpenAiChatOptions defaultOptions;
6876

77+
private Map<String, ToolFunctionCallback> toolCallbackRegister = new ConcurrentHashMap<>();
78+
6979
public final RetryTemplate retryTemplate = RetryTemplate.builder()
7080
.maxAttempts(10)
7181
.retryOn(OpenAiApiException.class)
7282
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
7383
.withListener(new RetryListener() {
84+
@Override
7485
public <T extends Object, E extends Throwable> void onError(RetryContext context,
7586
RetryCallback<T, E> callback, Throwable throwable) {
7687
logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
@@ -108,18 +119,18 @@ public ChatResponse call(Prompt prompt) {
108119

109120
ChatCompletionRequest request = createRequest(prompt, false);
110121

111-
ResponseEntity<ChatCompletion> completionEntity = this.openAiApi.chatCompletionEntity(request);
122+
ResponseEntity<ChatCompletion> completionEntity = this.chatCompletionWithTools(request);
112123

113124
var chatCompletion = completionEntity.getBody();
114125
if (chatCompletion == null) {
115-
logger.warn("No chat completion returned for request: {}", prompt);
126+
logger.warn("No chat completion returned for prompt: {}", prompt);
116127
return new ChatResponse(List.of());
117128
}
118129

119130
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
120131

121132
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
122-
return new Generation(choice.message().content(), Map.of("role", choice.message().role().name()))
133+
return new Generation(choice.message().content(), toMap(choice.message()))
123134
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
124135
}).toList();
125136

@@ -162,6 +173,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
162173
*/
163174
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
164175

176+
Set<String> enabledFunctionsForRequest = new HashSet<>();
177+
165178
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
166179
.stream()
167180
.map(m -> new ChatCompletionMessage(m.getContent(),
@@ -170,14 +183,15 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
170183

171184
ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
172185

173-
if (this.defaultOptions != null) {
174-
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
175-
}
176-
177186
if (prompt.getOptions() != null) {
178187
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
179188
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
180189
ChatOptions.class, OpenAiChatOptions.class);
190+
191+
Set<String> promptEnabledFunctions = handleToolFunctionConfigurations(updatedRuntimeOptions, true,
192+
true);
193+
enabledFunctionsForRequest.addAll(promptEnabledFunctions);
194+
181195
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
182196
}
183197
else {
@@ -186,7 +200,180 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
186200
}
187201
}
188202

203+
if (this.defaultOptions != null) {
204+
205+
Set<String> defaultEnabledFunctions = handleToolFunctionConfigurations(this.defaultOptions, false, false);
206+
207+
enabledFunctionsForRequest.addAll(defaultEnabledFunctions);
208+
209+
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
210+
}
211+
212+
// Add the enabled functions definitions to the request's tools parameter.
213+
if (!CollectionUtils.isEmpty(enabledFunctionsForRequest)) {
214+
215+
if (stream) {
216+
throw new IllegalArgumentException("Currently tool functions are not supported in streaming mode");
217+
}
218+
219+
request = ModelOptionsUtils.merge(
220+
OpenAiChatOptions.builder().withTools(this.getFunctionTools(enabledFunctionsForRequest)).build(),
221+
request, ChatCompletionRequest.class);
222+
}
223+
189224
return request;
190225
}
191226

227+
private Set<String> handleToolFunctionConfigurations(OpenAiChatOptions options, boolean autoEnableCallbackFunctions,
228+
boolean overrideCallbackFunctionsRegister) {
229+
230+
Set<String> enabledFunctions = new HashSet<>();
231+
232+
if (options != null) {
233+
if (!CollectionUtils.isEmpty(options.getToolCallbacks())) {
234+
options.getToolCallbacks().stream().forEach(toolCallback -> {
235+
236+
// Register the tool callback.
237+
if (overrideCallbackFunctionsRegister) {
238+
this.toolCallbackRegister.put(toolCallback.getName(), toolCallback);
239+
}
240+
else {
241+
this.toolCallbackRegister.putIfAbsent(toolCallback.getName(), toolCallback);
242+
}
243+
244+
// Automatically enable the function, usually from prompt callback.
245+
if (autoEnableCallbackFunctions) {
246+
enabledFunctions.add(toolCallback.getName());
247+
}
248+
});
249+
}
250+
251+
// Add the explicitly enabled functions.
252+
if (!CollectionUtils.isEmpty(options.getEnabledFunctions())) {
253+
enabledFunctions.addAll(options.getEnabledFunctions());
254+
}
255+
}
256+
257+
return enabledFunctions;
258+
}
259+
260+
/**
261+
* @return returns the registered tool callbacks.
262+
*/
263+
Map<String, ToolFunctionCallback> getToolCallbackRegister() {
264+
return toolCallbackRegister;
265+
}
266+
267+
public List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
268+
269+
List<OpenAiApi.FunctionTool> functionTools = new ArrayList<>();
270+
for (String functionName : functionNames) {
271+
if (!this.toolCallbackRegister.containsKey(functionName)) {
272+
throw new IllegalStateException("No function callback found for function name: " + functionName);
273+
}
274+
ToolFunctionCallback functionCallback = this.toolCallbackRegister.get(functionName);
275+
276+
var function = new OpenAiApi.FunctionTool.Function(functionCallback.getDescription(),
277+
functionCallback.getName(), functionCallback.getInputTypeSchema());
278+
functionTools.add(new OpenAiApi.FunctionTool(function));
279+
}
280+
281+
return functionTools;
282+
}
283+
284+
/**
285+
* Function Call handling. If the model calls a function, the function is called and
286+
* the response is added to the conversation history. The conversation history is then
287+
* sent back to the model.
288+
* @param request the chat completion request
289+
* @return the chat completion response.
290+
*/
291+
@SuppressWarnings("null")
292+
private ResponseEntity<ChatCompletion> chatCompletionWithTools(OpenAiApi.ChatCompletionRequest request) {
293+
294+
ResponseEntity<ChatCompletion> chatCompletion = this.openAiApi.chatCompletionEntity(request);
295+
296+
// Return the result if the model is not calling a function.
297+
if (Boolean.FALSE.equals(this.isToolCall(chatCompletion))) {
298+
return chatCompletion;
299+
}
300+
301+
// The OpenAI chat completion tool call API requires the complete conversation
302+
// history. Including the initial user message.
303+
List<ChatCompletionMessage> conversationMessages = new ArrayList<>(request.messages());
304+
305+
// We assume that the tool calling information is inside the response's first
306+
// choice.
307+
ChatCompletionMessage responseMessage = chatCompletion.getBody().choices().iterator().next().message();
308+
309+
if (chatCompletion.getBody().choices().size() > 1) {
310+
logger.warn("More than one choice returned. Only the first choice is processed.");
311+
}
312+
313+
// Add the assistant response to the message conversation history.
314+
conversationMessages.add(responseMessage);
315+
316+
// Every tool-call item requires a separate function call and a response (TOOL)
317+
// message.
318+
for (ToolCall toolCall : responseMessage.toolCalls()) {
319+
320+
var functionName = toolCall.function().name();
321+
String functionArguments = toolCall.function().arguments();
322+
323+
if (!this.toolCallbackRegister.containsKey(functionName)) {
324+
throw new IllegalStateException("No function callback found for function name: " + functionName);
325+
}
326+
327+
String functionResponse = this.toolCallbackRegister.get(functionName).call(functionArguments);
328+
329+
// Add the function response to the conversation.
330+
conversationMessages.add(new ChatCompletionMessage(functionResponse, Role.TOOL, null, toolCall.id(), null));
331+
}
332+
333+
// Recursively call chatCompletionWithTools until the model doesn't call a
334+
// functions anymore.
335+
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationMessages, request.stream());
336+
newRequest = ModelOptionsUtils.merge(newRequest, request, ChatCompletionRequest.class);
337+
338+
return this.chatCompletionWithTools(newRequest);
339+
}
340+
341+
private Map<String, Object> toMap(ChatCompletionMessage message) {
342+
Map<String, Object> map = new HashMap<>();
343+
344+
// The tool_calls and tool_call_id are not used by the OpenAiChatClient functions
345+
// call support! Useful only for users that want to use the tool_calls and
346+
// tool_call_id in their applications.
347+
if (message.toolCalls() != null) {
348+
map.put("tool_calls", message.toolCalls());
349+
}
350+
if (message.toolCallId() != null) {
351+
map.put("tool_call_id", message.toolCallId());
352+
}
353+
354+
if (message.role() != null) {
355+
map.put("role", message.role().name());
356+
}
357+
return map;
358+
}
359+
360+
/**
361+
* Check if it is a model calls function response.
362+
* @param chatCompletion the chat completion response.
363+
* @return true if the model expects a function call.
364+
*/
365+
private Boolean isToolCall(ResponseEntity<ChatCompletion> chatCompletion) {
366+
var body = chatCompletion.getBody();
367+
if (body == null) {
368+
return false;
369+
}
370+
371+
var choices = body.choices();
372+
if (CollectionUtils.isEmpty(choices)) {
373+
return false;
374+
}
375+
376+
return choices.get(0).message().toolCalls() != null;
377+
}
378+
192379
}

0 commit comments

Comments
 (0)