Skip to content

Commit 710bb05

Browse files
tzolovmarkpollack
authored andcommitted
feat(vertex-ai): Refactor Gemini for new tool calling API
- Migrate from function calling to tool calling API - Add support for Gemini 2.0 models (flash, flash-lite) - Implement JSON schema to OpenAPI schema conversion - Add builder pattern for improved configuration - Deprecate legacy function calling constructors and methods - Update default model to GEMINI_2_0_FLASH - Add comprehensive test coverage for tool calling - Upgrade victools dependency to 4.37.0 - Update the Vertex Tool calling docs Part of the #2207 epic Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent 4d692a5 commit 710bb05

28 files changed

+1850
-541
lines changed

models/spring-ai-vertex-ai-gemini/pom.xml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
-->
1717

1818
<project xmlns="http://maven.apache.org/POM/4.0.0"
19-
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
19+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
20+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
2021
<modelVersion>4.0.0</modelVersion>
2122
<parent>
2223
<groupId>org.springframework.ai</groupId>
@@ -53,6 +54,17 @@
5354

5455
<dependencies>
5556

57+
<dependency>
58+
<groupId>com.github.victools</groupId>
59+
<artifactId>jsonschema-generator</artifactId>
60+
<version>${victools.version}</version>
61+
</dependency>
62+
<dependency>
63+
<groupId>com.github.victools</groupId>
64+
<artifactId>jsonschema-module-jackson</artifactId>
65+
<version>${victools.version}</version>
66+
</dependency>
67+
5668
<dependency>
5769
<groupId>com.google.cloud</groupId>
5870
<artifactId>google-cloud-vertexai</artifactId>

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 313 additions & 97 deletions
Large diffs are not rendered by default.

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 115 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -29,11 +29,12 @@
2929
import com.fasterxml.jackson.annotation.JsonInclude.Include;
3030
import com.fasterxml.jackson.annotation.JsonProperty;
3131

32-
import org.springframework.ai.chat.prompt.ChatOptions;
3332
import org.springframework.ai.model.function.FunctionCallback;
34-
import org.springframework.ai.model.function.FunctionCallingOptions;
33+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
34+
import org.springframework.ai.tool.ToolCallback;
3535
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
3636
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
37+
import org.springframework.lang.Nullable;
3738
import org.springframework.util.Assert;
3839

3940
/**
@@ -46,7 +47,7 @@
4647
* @since 1.0.0
4748
*/
4849
@JsonInclude(Include.NON_NULL)
49-
public class VertexAiGeminiChatOptions implements FunctionCallingOptions {
50+
public class VertexAiGeminiChatOptions implements ToolCallingChatOptions {
5051

5152
// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig
5253

@@ -95,40 +96,36 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions {
9596
private @JsonProperty("responseMimeType") String responseMimeType;
9697

9798
/**
98-
* Tool Function Callbacks to register with the ChatModel.
99-
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
100-
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
101-
* from the registry to be used by the ChatModel chat completion requests.
99+
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
100+
* completion requests.
102101
*/
103102
@JsonIgnore
104-
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
103+
private List<FunctionCallback> toolCallbacks = new ArrayList<>();
105104

106105
/**
107-
* List of functions, identified by their names, to configure for function calling in
108-
* the chat completion requests.
109-
* Functions with those names must exist in the functionCallbacks registry.
110-
* The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution.
111-
*
112-
* Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing.
113-
* If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution.
106+
* Collection of tool names to be resolved at runtime and used for tool calling in the
107+
* chat completion requests.
114108
*/
115109
@JsonIgnore
116-
private Set<String> functions = new HashSet<>();
110+
private Set<String> toolNames = new HashSet<>();
117111

118112
/**
119-
* Use Google search Grounding feature
113+
* Whether to enable the tool execution lifecycle internally in ChatModel.
120114
*/
121115
@JsonIgnore
122-
private boolean googleSearchRetrieval = false;
116+
private Boolean internalToolExecutionEnabled;
123117

124118
@JsonIgnore
125-
private List<VertexAiGeminiSafetySetting> safetySettings = new ArrayList<>();
119+
private Map<String, Object> toolContext = new HashMap<>();
126120

121+
/**
122+
* Use Google search Grounding feature
123+
*/
127124
@JsonIgnore
128-
private Boolean proxyToolCalls;
125+
private Boolean googleSearchRetrieval = false;
129126

130127
@JsonIgnore
131-
private Map<String, Object> toolContext;
128+
private List<VertexAiGeminiSafetySetting> safetySettings = new ArrayList<>();
132129

133130
public static Builder builder() {
134131
return new Builder();
@@ -145,13 +142,13 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
145142
options.setCandidateCount(fromOptions.getCandidateCount());
146143
options.setMaxOutputTokens(fromOptions.getMaxOutputTokens());
147144
options.setModel(fromOptions.getModel());
148-
options.setFunctionCallbacks(fromOptions.getFunctionCallbacks());
145+
options.setToolCallbacks(fromOptions.getToolCallbacks());
149146
options.setResponseMimeType(fromOptions.getResponseMimeType());
150-
options.setFunctions(fromOptions.getFunctions());
147+
options.setToolNames(fromOptions.getToolNames());
151148
options.setResponseMimeType(fromOptions.getResponseMimeType());
152149
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
153150
options.setSafetySettings(fromOptions.getSafetySettings());
154-
options.setProxyToolCalls(fromOptions.getProxyToolCalls());
151+
options.setInternalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled());
155152
options.setToolContext(fromOptions.getToolContext());
156153
return options;
157154
}
@@ -236,20 +233,67 @@ public void setResponseMimeType(String mimeType) {
236233
this.responseMimeType = mimeType;
237234
}
238235

236+
@Override
237+
@JsonIgnore
238+
@Deprecated
239239
public List<FunctionCallback> getFunctionCallbacks() {
240-
return this.functionCallbacks;
240+
return this.getToolCallbacks();
241241
}
242242

243+
@Override
244+
@JsonIgnore
245+
@Deprecated
243246
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
244-
this.functionCallbacks = functionCallbacks;
247+
this.setToolCallbacks(functionCallbacks);
245248
}
246249

250+
@Override
251+
public List<FunctionCallback> getToolCallbacks() {
252+
return this.toolCallbacks;
253+
}
254+
255+
@Override
256+
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
257+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
258+
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
259+
this.toolCallbacks = toolCallbacks;
260+
}
261+
262+
@Override
263+
@JsonIgnore
264+
@Deprecated
247265
public Set<String> getFunctions() {
248-
return this.functions;
266+
return this.getToolNames();
249267
}
250268

269+
@JsonIgnore
270+
@Deprecated
251271
public void setFunctions(Set<String> functions) {
252-
this.functions = functions;
272+
this.setToolNames(functions);
273+
}
274+
275+
@Override
276+
public Set<String> getToolNames() {
277+
return this.toolNames;
278+
}
279+
280+
@Override
281+
public void setToolNames(Set<String> toolNames) {
282+
Assert.notNull(toolNames, "toolNames cannot be null");
283+
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
284+
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
285+
this.toolNames = toolNames;
286+
}
287+
288+
@Override
289+
@Nullable
290+
public Boolean isInternalToolExecutionEnabled() {
291+
return internalToolExecutionEnabled;
292+
}
293+
294+
@Override
295+
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
296+
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
253297
}
254298

255299
@Override
@@ -264,11 +308,11 @@ public Double getPresencePenalty() {
264308
return null;
265309
}
266310

267-
public boolean getGoogleSearchRetrieval() {
311+
public Boolean getGoogleSearchRetrieval() {
268312
return this.googleSearchRetrieval;
269313
}
270314

271-
public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) {
315+
public void setGoogleSearchRetrieval(Boolean googleSearchRetrieval) {
272316
this.googleSearchRetrieval = googleSearchRetrieval;
273317
}
274318

@@ -281,13 +325,17 @@ public void setSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings)
281325
this.safetySettings = safetySettings;
282326
}
283327

328+
@Deprecated
284329
@Override
330+
@JsonIgnore
285331
public Boolean getProxyToolCalls() {
286-
return this.proxyToolCalls;
332+
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
287333
}
288334

335+
@Deprecated
336+
@JsonIgnore
289337
public void setProxyToolCalls(Boolean proxyToolCalls) {
290-
this.proxyToolCalls = proxyToolCalls;
338+
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
291339
}
292340

293341
@Override
@@ -314,96 +362,35 @@ public boolean equals(Object o) {
314362
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount)
315363
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
316364
&& Objects.equals(this.responseMimeType, that.responseMimeType)
317-
&& Objects.equals(this.functionCallbacks, that.functionCallbacks)
318-
&& Objects.equals(this.functions, that.functions)
365+
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
366+
&& Objects.equals(this.toolNames, that.toolNames)
319367
&& Objects.equals(this.safetySettings, that.safetySettings)
320-
&& Objects.equals(this.proxyToolCalls, that.proxyToolCalls)
368+
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
321369
&& Objects.equals(this.toolContext, that.toolContext);
322370
}
323371

324372
@Override
325373
public int hashCode() {
326374
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
327-
this.maxOutputTokens, this.model, this.responseMimeType, this.functionCallbacks, this.functions,
328-
this.googleSearchRetrieval, this.safetySettings, this.proxyToolCalls, this.toolContext);
375+
this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames,
376+
this.googleSearchRetrieval, this.safetySettings, this.internalToolExecutionEnabled, this.toolContext);
329377
}
330378

331379
@Override
332380
public String toString() {
333381
return "VertexAiGeminiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature="
334382
+ this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", candidateCount="
335383
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
336-
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", functionCallbacks="
337-
+ this.functionCallbacks + ", functions=" + this.functions + ", googleSearchRetrieval="
338-
+ this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}';
384+
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks
385+
+ ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval
386+
+ ", safetySettings=" + this.safetySettings + '}';
339387
}
340388

341389
@Override
342390
public VertexAiGeminiChatOptions copy() {
343391
return fromOptions(this);
344392
}
345393

346-
public FunctionCallingOptions merge(ChatOptions options) {
347-
VertexAiGeminiChatOptions.Builder builder = VertexAiGeminiChatOptions.builder();
348-
349-
// Merge chat-specific options
350-
builder.model(options.getModel() != null ? options.getModel() : this.getModel())
351-
.maxOutputTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxOutputTokens())
352-
.stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.getStopSequences())
353-
.temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature())
354-
.topP(options.getTopP() != null ? options.getTopP() : this.getTopP())
355-
.topK(options.getTopK() != null ? options.getTopK() : this.getTopK());
356-
357-
// Try to get function-specific properties if options is a FunctionCallingOptions
358-
if (options instanceof FunctionCallingOptions functionOptions) {
359-
builder.proxyToolCalls(functionOptions.getProxyToolCalls() != null ? functionOptions.getProxyToolCalls()
360-
: this.proxyToolCalls);
361-
362-
Set<String> functions = new HashSet<>();
363-
if (this.functions != null) {
364-
functions.addAll(this.functions);
365-
}
366-
if (functionOptions.getFunctions() != null) {
367-
functions.addAll(functionOptions.getFunctions());
368-
}
369-
builder.functions(functions);
370-
371-
List<FunctionCallback> functionCallbacks = new ArrayList<>();
372-
if (this.functionCallbacks != null) {
373-
functionCallbacks.addAll(this.functionCallbacks);
374-
}
375-
if (functionOptions.getFunctionCallbacks() != null) {
376-
functionCallbacks.addAll(functionOptions.getFunctionCallbacks());
377-
}
378-
builder.functionCallbacks(functionCallbacks);
379-
380-
Map<String, Object> context = new HashMap<>();
381-
if (this.toolContext != null) {
382-
context.putAll(this.toolContext);
383-
}
384-
if (functionOptions.getToolContext() != null) {
385-
context.putAll(functionOptions.getToolContext());
386-
}
387-
builder.toolContext(context);
388-
}
389-
else {
390-
// If not a FunctionCallingOptions, preserve current function-specific
391-
// properties
392-
builder.proxyToolCalls(this.proxyToolCalls);
393-
builder.functions(this.functions != null ? new HashSet<>(this.functions) : null);
394-
builder.functionCallbacks(this.functionCallbacks != null ? new ArrayList<>(this.functionCallbacks) : null);
395-
builder.toolContext(this.toolContext != null ? new HashMap<>(this.toolContext) : null);
396-
}
397-
398-
// Preserve Vertex AI Gemini-specific properties
399-
builder.candidateCount(this.candidateCount)
400-
.responseMimeType(this.responseMimeType)
401-
.googleSearchRetrieval(this.googleSearchRetrieval)
402-
.safetySettings(this.safetySettings != null ? new ArrayList<>(this.safetySettings) : null);
403-
404-
return builder.build();
405-
}
406-
407394
public enum TransportType {
408395

409396
GRPC, REST
@@ -460,20 +447,35 @@ public Builder responseMimeType(String mimeType) {
460447
return this;
461448
}
462449

450+
@Deprecated
463451
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
464-
this.options.functionCallbacks = functionCallbacks;
452+
return toolCallbacks(functionCallbacks);
453+
}
454+
455+
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
456+
this.options.toolCallbacks = toolCallbacks;
465457
return this;
466458
}
467459

460+
@Deprecated
468461
public Builder functions(Set<String> functionNames) {
469-
Assert.notNull(functionNames, "Function names must not be null");
470-
this.options.functions = functionNames;
462+
return this.toolNames(functionNames);
463+
}
464+
465+
public Builder toolNames(Set<String> toolNames) {
466+
Assert.notNull(toolNames, "Function names must not be null");
467+
this.options.toolNames = toolNames;
471468
return this;
472469
}
473470

471+
@Deprecated
474472
public Builder function(String functionName) {
475-
Assert.hasText(functionName, "Function name must not be empty");
476-
this.options.functions.add(functionName);
473+
return this.toolName(functionName);
474+
}
475+
476+
public Builder toolName(String toolName) {
477+
Assert.hasText(toolName, "Function name must not be empty");
478+
this.options.toolNames.add(toolName);
477479
return this;
478480
}
479481

@@ -488,8 +490,13 @@ public Builder safetySettings(List<VertexAiGeminiSafetySetting> safetySettings)
488490
return this;
489491
}
490492

493+
@Deprecated
491494
public Builder proxyToolCalls(boolean proxyToolCalls) {
492-
this.options.proxyToolCalls = proxyToolCalls;
495+
return this.internalToolExecutionEnabled(proxyToolCalls);
496+
}
497+
498+
public Builder internalToolExecutionEnabled(boolean internalToolExecutionEnabled) {
499+
this.options.internalToolExecutionEnabled = internalToolExecutionEnabled;
493500
return this;
494501
}
495502

0 commit comments

Comments
 (0)