|
24 | 24 |
|
25 | 25 | import org.junit.jupiter.api.Test;
|
26 | 26 | import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
| 27 | +import org.junit.jupiter.params.ParameterizedTest; |
| 28 | +import org.junit.jupiter.params.provider.ValueSource; |
27 | 29 | import org.slf4j.Logger;
|
28 | 30 | import org.slf4j.LoggerFactory;
|
29 | 31 | import reactor.core.publisher.Flux;
|
@@ -243,33 +245,37 @@ void streamFunctionCallTest() {
|
243 | 245 | assertThat(content).containsAnyOf("15.0", "15");
|
244 | 246 | }
|
245 | 247 |
|
246 |
| - @Test |
247 |
| - void multiModalityEmbeddedImage() throws IOException { |
| 248 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 249 | + @ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" }) |
| 250 | + void multiModalityEmbeddedImage(String modelName) throws IOException { |
248 | 251 |
|
249 | 252 | byte[] imageData = new ClassPathResource("/test.png").getContentAsByteArray();
|
250 | 253 |
|
251 | 254 | var userMessage = new UserMessage("Explain what do you see on this picture?",
|
252 | 255 | List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
|
253 | 256 |
|
254 |
| - ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), |
255 |
| - OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build())); |
| 257 | + ChatResponse response = chatClient |
| 258 | + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); |
256 | 259 |
|
257 | 260 | logger.info(response.getResult().getOutput().getContent());
|
258 |
| - assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "bowl"); |
| 261 | + assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple"); |
| 262 | + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket"); |
259 | 263 | }
|
260 | 264 |
|
261 |
| - @Test |
262 |
| - void multiModalityImageUrl() throws IOException { |
| 265 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 266 | + @ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" }) |
| 267 | + void multiModalityImageUrl(String modelName) throws IOException { |
263 | 268 |
|
264 | 269 | var userMessage = new UserMessage("Explain what do you see on this picture?",
|
265 | 270 | List.of(new Media(MimeTypeUtils.IMAGE_PNG,
|
266 | 271 | "https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png")));
|
267 | 272 |
|
268 |
| - ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), |
269 |
| - OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build())); |
| 273 | + ChatResponse response = chatClient |
| 274 | + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); |
270 | 275 |
|
271 | 276 | logger.info(response.getResult().getOutput().getContent());
|
272 |
| - assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "bowl"); |
| 277 | + assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple"); |
| 278 | + assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket"); |
273 | 279 | }
|
274 | 280 |
|
275 | 281 | @Test
|
|
0 commit comments