Skip to content

Commit e268975

Browse files
committed
Fix Geminie GenerativeModel handling between calls
Resolves #560
1 parent 2b421a4 commit e268975

File tree

4 files changed

+19
-5
lines changed

4 files changed

+19
-5
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ public class VertexAiGeminiChatClient
7878

7979
private final GenerationConfig generationConfig;
8080

81-
private GenerativeModel generativeModel;
82-
8381
public enum GeminiMessageType {
8482

8583
USER("user"),
@@ -140,7 +138,6 @@ public VertexAiGeminiChatClient(VertexAI vertexAI, VertexAiGeminiChatOptions opt
140138
this.vertexAI = vertexAI;
141139
this.defaultOptions = options;
142140
this.generationConfig = toGenerationConfig(options);
143-
this.generativeModel = new GenerativeModel(options.getModel(), vertexAI);
144141
}
145142

146143
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
@@ -204,7 +201,8 @@ private GeminiRequest createGeminiRequest(Prompt prompt) {
204201
Set<String> functionsForThisRequest = new HashSet<>();
205202

206203
GenerationConfig generationConfig = this.generationConfig;
207-
GenerativeModel generativeModel = this.generativeModel;
204+
205+
GenerativeModel generativeModel = new GenerativeModel(this.defaultOptions.getModel(), this.vertexAI);
208206

209207
VertexAiGeminiChatOptions updatedRuntimeOptions = null;
210208

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatClientFunctionCallingIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ public void functionCallTestInferredOpenApiSchema() {
146146
public void functionCallTestInferredOpenApiSchemaStream() {
147147

148148
UserMessage userMessage = new UserMessage(
149-
"What's the weather like in San Francisco, in Paris and in Tokyo? Use Multi-turn function calling.");
149+
"What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations.");
150150

151151
List<Message> messages = new ArrayList<>(List.of(userMessage));
152152

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,13 @@ void functionCallTest() {
8383

8484
assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
8585

86+
response = chatClient
87+
.call(new Prompt(List.of(systemMessage, userMessage), VertexAiGeminiChatOptions.builder().build()));
88+
89+
logger.info("Response: {}", response);
90+
91+
assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15");
92+
8693
});
8794
}
8895

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ void functionCallTest() {
7777
logger.info("Response: {}", response);
7878

7979
assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
80+
81+
// Verify that no function call is made.
82+
response = chatClient
83+
.call(new Prompt(List.of(systemMessage, userMessage), VertexAiGeminiChatOptions.builder().build()));
84+
85+
logger.info("Response: {}", response);
86+
87+
assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15");
88+
8089
});
8190
}
8291

0 commit comments

Comments
 (0)