Skip to content

Commit 65d42c9

Browse files
committed
Add full support or Vertex AI Gemini and Azure OpenAI function calling
- Extend the Spring AI Message with getMediaData() : List<MediaData> MediaData is a pair of MimeType and data of type Object. Message#getContent() return text only. - VertexAI Gemini Support - implement VertexAiGeiminChatClient for ChatClient and StreamingChat client and support for MediaData content. add IT tests for Chat, Streaming and Multimodality - add Auto-configuration + ITs - add Gemini Spring Boot starter. - add clients and boot starters to the Spring AI BOM. - add Anotra documentation for the Gemini chat client. - update gemini to latest 26.33.0 BOM. - add vertex ai gemini dependencies to the BOM. - add Vertex AI Gemini API Function Calling support - add Gemini API Function Calling Streaming support - add vertex ai gemini function calling documentation - factor out the Function Calling functionality into common abstraction used by OpenAI, Azure and Gemini. - group the VertexAI documentation under a common parent - add PortableFunctionCallingOption that implements FunctionCallingOptions and ChatOptions and provide builder for it. - remove some deprecated code. - allow authorization with GoogleCredentials form json file. - add AOT support for VertexAI Gemini. - move legacy Vertex AI into VertexAI PaLM2. - better handling for empty chat responses. - update the Gemini version to latest 26.33.0. This required lifting the protobuf-java to 3.25.2 as well. - fix a bug for handling System messages with Gemini. - Implement Azure OpenAI Function Calling Uses the same the common abstractions used by OpenAI and Gemini: FunctionCallingOptions and AbstractFunctionCallSupport
1 parent 3580849 commit 65d42c9

File tree

98 files changed

+4877
-377
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

98 files changed

+4877
-377
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java

Lines changed: 130 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,49 @@
1717
package org.springframework.ai.azure.openai;
1818

1919
import java.util.Collections;
20+
import java.util.HashSet;
2021
import java.util.List;
22+
import java.util.Set;
2123

2224
import com.azure.ai.openai.OpenAIClient;
2325
import com.azure.ai.openai.models.ChatChoice;
2426
import com.azure.ai.openai.models.ChatCompletions;
27+
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
28+
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
2529
import com.azure.ai.openai.models.ChatCompletionsOptions;
30+
import com.azure.ai.openai.models.ChatCompletionsToolCall;
31+
import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
2632
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
2733
import com.azure.ai.openai.models.ChatRequestMessage;
2834
import com.azure.ai.openai.models.ChatRequestSystemMessage;
35+
import com.azure.ai.openai.models.ChatRequestToolMessage;
2936
import com.azure.ai.openai.models.ChatRequestUserMessage;
37+
import com.azure.ai.openai.models.ChatResponseMessage;
38+
import com.azure.ai.openai.models.CompletionsFinishReason;
3039
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
40+
import com.azure.ai.openai.models.FunctionDefinition;
41+
import com.azure.core.util.BinaryData;
3142
import com.azure.core.util.IterableStream;
3243
import org.slf4j.Logger;
3344
import org.slf4j.LoggerFactory;
3445
import reactor.core.publisher.Flux;
3546

3647
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
3748
import org.springframework.ai.chat.ChatClient;
38-
import org.springframework.ai.chat.prompt.ChatOptions;
3949
import org.springframework.ai.chat.ChatResponse;
4050
import org.springframework.ai.chat.Generation;
4151
import org.springframework.ai.chat.StreamingChatClient;
4252
import org.springframework.ai.chat.messages.Message;
4353
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
4454
import org.springframework.ai.chat.metadata.PromptMetadata;
4555
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
56+
import org.springframework.ai.chat.prompt.ChatOptions;
4657
import org.springframework.ai.chat.prompt.Prompt;
4758
import org.springframework.ai.model.ModelOptionsUtils;
59+
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
60+
import org.springframework.ai.model.function.FunctionCallbackContext;
4861
import org.springframework.util.Assert;
62+
import org.springframework.util.CollectionUtils;
4963

5064
/**
5165
* {@link ChatClient} implementation for {@literal Microsoft Azure AI} backed by
@@ -58,7 +72,9 @@
5872
* @see ChatClient
5973
* @see com.azure.ai.openai.OpenAIClient
6074
*/
61-
public class AzureOpenAiChatClient implements ChatClient, StreamingChatClient {
75+
public class AzureOpenAiChatClient
76+
extends AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions>
77+
implements ChatClient, StreamingChatClient {
6278

6379
private static final String DEFAULT_MODEL = "gpt-35-turbo";
6480

@@ -82,6 +98,12 @@ public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
8298
}
8399

84100
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
101+
this(microsoftOpenAiClient, options, null);
102+
}
103+
104+
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
105+
FunctionCallbackContext functionCallbackContext) {
106+
super(functionCallbackContext);
85107
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
86108
Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
87109
this.openAIClient = microsoftOpenAiClient;
@@ -100,7 +122,7 @@ public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOp
100122
}
101123

102124
public AzureOpenAiChatOptions getDefaultOptions() {
103-
return defaultOptions;
125+
return this.defaultOptions;
104126
}
105127

106128
@Override
@@ -111,7 +133,10 @@ public ChatResponse call(Prompt prompt) {
111133

112134
logger.trace("Azure ChatCompletionsOptions: {}", options);
113135

114-
ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
136+
ChatCompletions chatCompletions = this.callWithFunctionSupport(options);
137+
138+
// ChatCompletions chatCompletions =
139+
// this.openAIClient.getChatCompletions(options.getModel(), options);
115140

116141
logger.trace("Azure ChatCompletions: {}", chatCompletions);
117142

@@ -154,6 +179,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
154179
*/
155180
ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
156181

182+
Set<String> functionsForThisRequest = new HashSet<>();
183+
157184
List<ChatRequestMessage> azureMessages = prompt.getInstructions()
158185
.stream()
159186
.map(this::fromSpringAiMessage)
@@ -167,6 +194,10 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
167194
// options = ModelOptionsUtils.merge(options, this.defaultOptions,
168195
// ChatCompletionsOptions.class);
169196
options = merge(options, this.defaultOptions);
197+
198+
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions,
199+
!IS_RUNTIME_CALL);
200+
functionsForThisRequest.addAll(defaultEnabledFunctions);
170201
}
171202

172203
if (prompt.getOptions() != null) {
@@ -178,16 +209,43 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
178209
// options = ModelOptionsUtils.merge(runtimeOptions, options,
179210
// ChatCompletionsOptions.class);
180211
options = merge(updatedRuntimeOptions, options);
212+
213+
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
214+
IS_RUNTIME_CALL);
215+
functionsForThisRequest.addAll(promptEnabledFunctions);
216+
181217
}
182218
else {
183219
throw new IllegalArgumentException("Prompt options are not of type ChatCompletionsOptions:"
184220
+ prompt.getOptions().getClass().getSimpleName());
185221
}
186222
}
187223

224+
// Add the enabled functions definitions to the request's tools parameter.
225+
226+
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
227+
List<ChatCompletionsFunctionToolDefinition> tools = this.getFunctionTools(functionsForThisRequest);
228+
List<ChatCompletionsToolDefinition> tools2 = tools.stream()
229+
.map(t -> ((ChatCompletionsToolDefinition) t))
230+
.toList();
231+
options.setTools(tools2);
232+
}
233+
188234
return options;
189235
}
190236

237+
private List<ChatCompletionsFunctionToolDefinition> getFunctionTools(Set<String> functionNames) {
238+
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
239+
240+
FunctionDefinition functionDefinition = new FunctionDefinition(functionCallback.getName());
241+
functionDefinition.setDescription(functionCallback.getDescription());
242+
BinaryData parameters = BinaryData
243+
.fromObject(ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema()));
244+
functionDefinition.setParameters(parameters);
245+
return new ChatCompletionsFunctionToolDefinition(functionDefinition);
246+
}).toList();
247+
}
248+
191249
private ChatRequestMessage fromSpringAiMessage(Message message) {
192250

193251
switch (message.getMessageType()) {
@@ -281,6 +339,8 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, Cha
281339
ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
282340
mergedAzureOptions = merge(azureOptions, mergedAzureOptions);
283341

342+
mergedAzureOptions.setStream(azureOptions.isStream());
343+
284344
if (springAiOptions.getMaxTokens() != null) {
285345
mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
286346
}
@@ -324,6 +384,8 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, Cha
324384
return mergedAzureOptions;
325385
}
326386

387+
// https://github.com/Azure/azure-sdk-for-java/blob/azure-ai-openai_1.0.0-beta.6/sdk/openai/azure-ai-openai/src/samples/java/com/azure/ai/openai/usage/GetChatCompletionsToolCallSample.java
388+
327389
private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCompletionsOptions toOptions) {
328390

329391
if (fromOptions == null) {
@@ -367,4 +429,68 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCom
367429
return mergedOptions;
368430
}
369431

432+
@Override
433+
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
434+
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
435+
436+
// Every tool-call item requires a separate function call and a response (TOOL)
437+
// message.
438+
for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) {
439+
440+
var functionName = ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName();
441+
String functionArguments = ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getArguments();
442+
443+
if (!this.functionCallbackRegister.containsKey(functionName)) {
444+
throw new IllegalStateException("No function callback found for function name: " + functionName);
445+
}
446+
447+
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
448+
449+
// Add the function response to the conversation.
450+
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
451+
}
452+
453+
// Recursively call chatCompletionWithTools until the model doesn't call a
454+
// functions anymore.
455+
ChatCompletionsOptions newRequest = new ChatCompletionsOptions(conversationHistory);
456+
457+
newRequest = merge(previousRequest, newRequest);
458+
459+
return newRequest;
460+
}
461+
462+
@Override
463+
protected List<ChatRequestMessage> doGetUserMessages(ChatCompletionsOptions request) {
464+
return request.getMessages();
465+
}
466+
467+
@Override
468+
protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) {
469+
ChatResponseMessage responseMessage = response.getChoices().get(0).getMessage();
470+
ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage("");
471+
assistantMessage.setToolCalls(responseMessage.getToolCalls());
472+
return assistantMessage;
473+
}
474+
475+
@Override
476+
protected ChatCompletions doChatCompletion(ChatCompletionsOptions request) {
477+
return this.openAIClient.getChatCompletions(request.getModel(), request);
478+
}
479+
480+
@Override
481+
protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {
482+
483+
if (chatCompletions == null || CollectionUtils.isEmpty(chatCompletions.getChoices())) {
484+
return false;
485+
}
486+
487+
var choice = chatCompletions.getChoices().get(0);
488+
489+
if (choice == null || choice.getFinishReason() == null) {
490+
return false;
491+
}
492+
493+
return choice.getFinishReason() == CompletionsFinishReason.TOOL_CALLS;
494+
}
495+
370496
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,22 @@
1616

1717
package org.springframework.ai.azure.openai;
1818

19+
import java.util.ArrayList;
20+
import java.util.HashSet;
1921
import java.util.List;
2022
import java.util.Map;
23+
import java.util.Set;
2124

2225
import com.fasterxml.jackson.annotation.JsonIgnore;
2326
import com.fasterxml.jackson.annotation.JsonInclude;
2427
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2528
import com.fasterxml.jackson.annotation.JsonProperty;
2629

2730
import org.springframework.ai.chat.prompt.ChatOptions;
31+
import org.springframework.ai.model.function.FunctionCallback;
32+
import org.springframework.ai.model.function.FunctionCallingOptions;
33+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
34+
import org.springframework.util.Assert;
2835

2936
/**
3037
* The configuration information for a chat completions request. Completions support a
@@ -34,7 +41,7 @@
3441
* @author Christian Tzolov
3542
*/
3643
@JsonInclude(Include.NON_NULL)
37-
public class AzureOpenAiChatOptions implements ChatOptions {
44+
public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
3845

3946
/**
4047
* The maximum number of tokens to generate.
@@ -121,6 +128,32 @@ public class AzureOpenAiChatOptions implements ChatOptions {
121128
@JsonProperty(value = "model")
122129
private String model;
123130

131+
/**
132+
* OpenAI Tool Function Callbacks to register with the ChatClient. For Prompt Options
133+
* the functionCallbacks are automatically enabled for the duration of the prompt
134+
* execution. For Default Options the functionCallbacks are registered but disabled by
135+
* default. Use the enableFunctions to set the functions from the registry to be used
136+
* by the ChatClient chat completion requests.
137+
*/
138+
@NestedConfigurationProperty
139+
@JsonIgnore
140+
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
141+
142+
/**
143+
* List of functions, identified by their names, to configure for function calling in
144+
* the chat completion requests. Functions with those names must exist in the
145+
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
146+
* are automatically enabled for the duration of the prompt execution.
147+
*
148+
* Note that function enabled with the default options are enabled for all chat
149+
* completion requests. This could impact the token count and the billing. If the
150+
* functions is set in a prompt options, then the enabled functions are only active
151+
* for the duration of this prompt execution.
152+
*/
153+
@NestedConfigurationProperty
154+
@JsonIgnore
155+
private Set<String> functions = new HashSet<>();
156+
124157
public static Builder builder() {
125158
return new Builder();
126159
}
@@ -187,6 +220,23 @@ public Builder withUser(String user) {
187220
return this;
188221
}
189222

223+
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
224+
this.options.functionCallbacks = functionCallbacks;
225+
return this;
226+
}
227+
228+
public Builder withFunctions(Set<String> functionNames) {
229+
Assert.notNull(functionNames, "Function names must not be null");
230+
this.options.functions = functionNames;
231+
return this;
232+
}
233+
234+
public Builder withFunction(String functionName) {
235+
Assert.hasText(functionName, "Function name must not be empty");
236+
this.options.functions.add(functionName);
237+
return this;
238+
}
239+
190240
public AzureOpenAiChatOptions build() {
191241
return this.options;
192242
}
@@ -289,4 +339,24 @@ public void setTopK(Integer topK) {
289339
throw new UnsupportedOperationException("Unimplemented method 'setTopK'");
290340
}
291341

342+
@Override
343+
public List<FunctionCallback> getFunctionCallbacks() {
344+
return this.functionCallbacks;
345+
}
346+
347+
@Override
348+
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
349+
this.functionCallbacks = functionCallbacks;
350+
}
351+
352+
@Override
353+
public Set<String> getFunctions() {
354+
return this.functions;
355+
}
356+
357+
@Override
358+
public void setFunctions(Set<String> functions) {
359+
this.functions = functions;
360+
}
361+
292362
}

0 commit comments

Comments
 (0)