|
15 | 15 | */
|
16 | 16 | package org.springframework.ai.vertexai.gemini.function;
|
17 | 17 |
|
18 |
| -import java.util.ArrayList; |
19 |
| -import java.util.List; |
20 |
| -import java.util.stream.Collectors; |
21 |
| - |
22 | 18 | import com.google.cloud.vertexai.Transport;
|
23 | 19 | import com.google.cloud.vertexai.VertexAI;
|
24 | 20 | import org.junit.jupiter.api.AfterEach;
|
25 | 21 | import org.junit.jupiter.api.Test;
|
26 | 22 | import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
27 | 23 | import org.slf4j.Logger;
|
28 | 24 | import org.slf4j.LoggerFactory;
|
29 |
| -import reactor.core.publisher.Flux; |
30 |
| - |
31 | 25 | import org.springframework.ai.chat.ChatResponse;
|
32 | 26 | import org.springframework.ai.chat.Generation;
|
33 | 27 | import org.springframework.ai.chat.messages.AssistantMessage;
|
|
42 | 36 | import org.springframework.boot.SpringBootConfiguration;
|
43 | 37 | import org.springframework.boot.test.context.SpringBootTest;
|
44 | 38 | import org.springframework.context.annotation.Bean;
|
| 39 | +import reactor.core.publisher.Flux; |
| 40 | + |
| 41 | +import java.util.ArrayList; |
| 42 | +import java.util.List; |
| 43 | +import java.util.function.Function; |
| 44 | +import java.util.stream.Collectors; |
45 | 45 |
|
46 | 46 | import static org.assertj.core.api.Assertions.assertThat;
|
| 47 | +import static org.junit.jupiter.api.Assertions.assertNotNull; |
47 | 48 |
|
48 | 49 | @SpringBootTest
|
49 | 50 | @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
|
@@ -178,6 +179,48 @@ public void functionCallTestInferredOpenApiSchemaStream() {
|
178 | 179 |
|
179 | 180 | }
|
180 | 181 |
|
| 182 | + //Gemini wants single tool with multiple function, instead multiple tools with single function |
| 183 | + @Test |
| 184 | + public void canDeclareMultipleFunctions() { |
| 185 | + |
| 186 | + UserMessage userMessage = new UserMessage( |
| 187 | + "What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations."); |
| 188 | + |
| 189 | + List<Message> messages = new ArrayList<>(List.of(userMessage)); |
| 190 | + |
| 191 | + final var weatherFunction = FunctionCallbackWrapper.builder(new MockWeatherService()) |
| 192 | + .withSchemaType(SchemaType.OPEN_API_SCHEMA) |
| 193 | + .withName("getCurrentWeather") |
| 194 | + .withDescription("Get the current weather in a given location") |
| 195 | + .build(); |
| 196 | + final var theAnswer = FunctionCallbackWrapper.builder(new TheAnswerMock()) |
| 197 | + .withSchemaType(SchemaType.OPEN_API_SCHEMA) |
| 198 | + .withName("theAnswerToTheUniverse") |
| 199 | + .withDescription("the answer to the ultimate question of life, the universe, and everything") |
| 200 | + .build(); |
| 201 | + var promptOptions = VertexAiGeminiChatOptions.builder() |
| 202 | + .withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue()) |
| 203 | + .withFunctionCallbacks(List.of(weatherFunction, theAnswer)) |
| 204 | + .build(); |
| 205 | + |
| 206 | + ChatResponse response = vertexGeminiClient.call(new Prompt(messages, promptOptions)); |
| 207 | + |
| 208 | + String responseString = response.getResult().getOutput().getContent(); |
| 209 | + |
| 210 | + logger.info("Response: {}", responseString); |
| 211 | + assertNotNull(responseString); |
| 212 | + |
| 213 | + } |
| 214 | + |
| 215 | + public static class TheAnswerMock implements Function<String, Integer> { |
| 216 | + |
| 217 | + @Override |
| 218 | + public Integer apply(String s) { |
| 219 | + return 42; |
| 220 | + } |
| 221 | + |
| 222 | + } |
| 223 | + |
181 | 224 | @SpringBootConfiguration
|
182 | 225 | public static class TestConfiguration {
|
183 | 226 |
|
|
0 commit comments