From 90eb41a2779546caf712e3f4051b6ff9f53447b6 Mon Sep 17 00:00:00 2001 From: Thingersoft Date: Mon, 26 Feb 2024 20:07:28 +0100 Subject: [PATCH] Provide the ability to configure Open AI client read timeout --- .../openai/OpenAiAutoConfiguration.java | 87 +++++++++++++------ .../openai/OpenAiConnectionProperties.java | 5 ++ .../openai/OpenAiParentProperties.java | 12 +++ .../openai/OpenAiPropertiesTests.java | 28 ++++++ 4 files changed, 105 insertions(+), 27 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 724d2e20d90..bf6216d2404 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -16,6 +16,9 @@ package org.springframework.ai.autoconfigure.openai; +import java.time.Duration; +import java.util.List; + import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -29,15 +32,17 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.ssl.SslBundle; +import org.springframework.boot.web.client.ClientHttpRequestFactories; +import org.springframework.boot.web.client.ClientHttpRequestFactorySettings; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.client.RestClient; -import java.util.List; - @AutoConfiguration(after = { RestClientAutoConfiguration.class }) @ConditionalOnClass(OpenAiApi.class) @EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class, @@ -51,22 +56,21 @@ public class OpenAiAutoConfiguration { public static final String OPEN_AI_BASE_URL_MUST_BE_SET = "OpenAI base URL must be set"; + public static final String OPEN_AI_READ_TIMEOUT_MUST_BE_SET = "OpenAI base read timeout must be set"; + @Bean @ConditionalOnMissingBean public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProperties, OpenAiChatProperties chatProperties, RestClient.Builder restClientBuilder, List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) { - String apiKey = StringUtils.hasText(chatProperties.getApiKey()) ? chatProperties.getApiKey() - : commonProperties.getApiKey(); - - String baseUrl = StringUtils.hasText(chatProperties.getBaseUrl()) ? chatProperties.getBaseUrl() - : commonProperties.getBaseUrl(); - - Assert.hasText(apiKey, OPEN_AI_API_KEY_MUST_BE_SET); - Assert.hasText(baseUrl, OPEN_AI_BASE_URL_MUST_BE_SET); + OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties, + chatProperties); + RestClient.Builder overridenRestClientBuilder = overrideRestClientBuilder(restClientBuilder, + overridenCommonProperties); - var openAiApi = new OpenAiApi(baseUrl, apiKey, restClientBuilder); + var openAiApi = new OpenAiApi(overridenCommonProperties.getBaseUrl(), overridenCommonProperties.getApiKey(), + overridenRestClientBuilder); if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); @@ -80,15 +84,13 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper public EmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties commonProperties, OpenAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder) { - String apiKey = StringUtils.hasText(embeddingProperties.getApiKey()) ? embeddingProperties.getApiKey() - : commonProperties.getApiKey(); - String baseUrl = StringUtils.hasText(embeddingProperties.getBaseUrl()) ? embeddingProperties.getBaseUrl() - : commonProperties.getBaseUrl(); + OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties, + embeddingProperties); + RestClient.Builder overridenRestClientBuilder = overrideRestClientBuilder(restClientBuilder, + overridenCommonProperties); - Assert.hasText(apiKey, OPEN_AI_API_KEY_MUST_BE_SET); - Assert.hasText(baseUrl, OPEN_AI_BASE_URL_MUST_BE_SET); - - var openAiApi = new OpenAiApi(baseUrl, apiKey, restClientBuilder); + var openAiApi = new OpenAiApi(overridenCommonProperties.getBaseUrl(), overridenCommonProperties.getApiKey(), + overridenRestClientBuilder); return new OpenAiEmbeddingClient(openAiApi, embeddingProperties.getMetadataMode(), embeddingProperties.getOptions()); @@ -98,16 +100,14 @@ public EmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties commonPr @ConditionalOnMissingBean public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProperties, OpenAiImageProperties imageProperties, RestClient.Builder restClientBuilder) { - String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey() - : commonProperties.getApiKey(); - - String baseUrl = StringUtils.hasText(imageProperties.getBaseUrl()) ? imageProperties.getBaseUrl() - : commonProperties.getBaseUrl(); - Assert.hasText(apiKey, OPEN_AI_API_KEY_MUST_BE_SET); - Assert.hasText(baseUrl, OPEN_AI_BASE_URL_MUST_BE_SET); + OpenAiConnectionProperties overridenCommonProperties = checkAndOverrideProperties(commonProperties, + imageProperties); + RestClient.Builder overridenRestClientBuilder = overrideRestClientBuilder(restClientBuilder, + overridenCommonProperties); - var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, restClientBuilder); + var openAiImageApi = new OpenAiImageApi(overridenCommonProperties.getBaseUrl(), + overridenCommonProperties.getApiKey(), overridenRestClientBuilder); return new OpenAiImageClient(openAiImageApi).withDefaultOptions(imageProperties.getOptions()); } @@ -120,4 +120,37 @@ public FunctionCallbackContext springAiFunctionManager(ApplicationContext contex return manager; } + private static OpenAiConnectionProperties checkAndOverrideProperties( + OpenAiConnectionProperties commonProperties, T specificProperties) { + + String apiKey = StringUtils.hasText(specificProperties.getApiKey()) ? specificProperties.getApiKey() + : commonProperties.getApiKey(); + + String baseUrl = StringUtils.hasText(specificProperties.getBaseUrl()) ? specificProperties.getBaseUrl() + : commonProperties.getBaseUrl(); + + Duration readTimeout = specificProperties.getReadTimeout() != null ? specificProperties.getReadTimeout() + : commonProperties.getReadTimeout(); + + Assert.hasText(apiKey, OPEN_AI_API_KEY_MUST_BE_SET); + Assert.hasText(baseUrl, OPEN_AI_BASE_URL_MUST_BE_SET); + Assert.notNull(readTimeout, OPEN_AI_READ_TIMEOUT_MUST_BE_SET); + + OpenAiConnectionProperties overridenCommonProperties = new OpenAiConnectionProperties(); + overridenCommonProperties.setApiKey(apiKey); + overridenCommonProperties.setBaseUrl(baseUrl); + overridenCommonProperties.setReadTimeout(readTimeout); + + return overridenCommonProperties; + + } + + private static RestClient.Builder overrideRestClientBuilder(RestClient.Builder restClientBuilder, + OpenAiConnectionProperties overridenCommonProperties) { + ClientHttpRequestFactorySettings requestFactorySettings = new ClientHttpRequestFactorySettings( + Duration.ofHours(1l), overridenCommonProperties.getReadTimeout(), SslBundle.of(null)); + ClientHttpRequestFactory requestFactory = ClientHttpRequestFactories.get(requestFactorySettings); + return restClientBuilder.clone().requestFactory(requestFactory); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java index ebdb74f50f0..9a9f6d53cad 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiConnectionProperties.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.openai; +import java.time.Duration; + import org.springframework.boot.context.properties.ConfigurationProperties; @ConfigurationProperties(OpenAiConnectionProperties.CONFIG_PREFIX) @@ -25,8 +27,11 @@ public class OpenAiConnectionProperties extends OpenAiParentProperties { public static final String DEFAULT_BASE_URL = "https://api.openai.com"; + public static final Duration DEFAULT_READ_TIMEOUT = Duration.ofMinutes(1); + public OpenAiConnectionProperties() { super.setBaseUrl(DEFAULT_BASE_URL); + super.setReadTimeout(DEFAULT_READ_TIMEOUT); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java index 1250e3698a9..b268e1fd5ec 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiParentProperties.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.openai; +import java.time.Duration; + /** * Internal parent properties for the OpenAI properties. * @@ -28,6 +30,8 @@ class OpenAiParentProperties { private String baseUrl; + private Duration readTimeout; + public String getApiKey() { return apiKey; } @@ -44,4 +48,12 @@ public void setBaseUrl(String baseUrl) { this.baseUrl = baseUrl; } + public Duration getReadTimeout() { + return readTimeout; + } + + public void setReadTimeout(Duration readTimeout) { + this.readTimeout = readTimeout; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index 41a1147166f..57dc3242de5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -16,6 +16,7 @@ package org.springframework.ai.autoconfigure.openai; +import java.time.Duration; import java.util.Map; import org.junit.jupiter.api.Test; @@ -46,6 +47,7 @@ public void chatProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.chat.options.model=MODEL_XYZ", "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on @@ -56,9 +58,11 @@ public void chatProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(chatProperties.getApiKey()).isNull(); assertThat(chatProperties.getBaseUrl()).isNull(); + assertThat(chatProperties.getReadTimeout()).isNull(); assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); @@ -72,8 +76,10 @@ public void chatOverrideConnectionProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.chat.base-url=TEST_BASE_URL2", "spring.ai.openai.chat.api-key=456", + "spring.ai.openai.chat.read-timeout=5m", "spring.ai.openai.chat.options.model=MODEL_XYZ", "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on @@ -84,9 +90,11 @@ public void chatOverrideConnectionProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(chatProperties.getApiKey()).isEqualTo("456"); assertThat(chatProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + assertThat(chatProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5)); assertThat(chatProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(chatProperties.getOptions().getTemperature()).isEqualTo(0.55f); @@ -100,6 +108,7 @@ public void embeddingProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.embedding.options.model=MODEL_XYZ") // @formatter:on .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) @@ -109,9 +118,11 @@ public void embeddingProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(embeddingProperties.getApiKey()).isNull(); assertThat(embeddingProperties.getBaseUrl()).isNull(); + assertThat(embeddingProperties.getReadTimeout()).isNull(); assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); }); @@ -124,8 +135,10 @@ public void embeddingOverrideConnectionProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.embedding.base-url=TEST_BASE_URL2", "spring.ai.openai.embedding.api-key=456", + "spring.ai.openai.embedding.read-timeout=5m", "spring.ai.openai.embedding.options.model=MODEL_XYZ") // @formatter:on .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) @@ -135,9 +148,11 @@ public void embeddingOverrideConnectionProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(embeddingProperties.getApiKey()).isEqualTo("456"); assertThat(embeddingProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + assertThat(embeddingProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5)); assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); }); @@ -149,6 +164,7 @@ public void imageProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.image.options.model=MODEL_XYZ", "spring.ai.openai.image.options.n=3") // @formatter:on @@ -159,9 +175,11 @@ public void imageProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(imageProperties.getApiKey()).isNull(); assertThat(imageProperties.getBaseUrl()).isNull(); + assertThat(imageProperties.getReadTimeout()).isNull(); assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(imageProperties.getOptions().getN()).isEqualTo(3); @@ -174,8 +192,10 @@ public void imageOverrideConnectionProperties() { // @formatter:off "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.api-key=abc123", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.image.base-url=TEST_BASE_URL2", "spring.ai.openai.image.api-key=456", + "spring.ai.openai.image.read-timeout=5m", "spring.ai.openai.image.options.model=MODEL_XYZ", "spring.ai.openai.image.options.n=3") // @formatter:on @@ -186,9 +206,11 @@ public void imageOverrideConnectionProperties() { assertThat(connectionProperties.getApiKey()).isEqualTo("abc123"); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(imageProperties.getApiKey()).isEqualTo("456"); assertThat(imageProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2"); + assertThat(imageProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(5)); assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(imageProperties.getOptions().getN()).isEqualTo(3); @@ -202,6 +224,7 @@ public void chatOptionsTest() { // @formatter:off "spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.chat.options.model=MODEL_XYZ", "spring.ai.openai.chat.options.frequencyPenalty=-1.5", @@ -254,6 +277,7 @@ public void chatOptionsTest() { assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("text-embedding-ada-002"); @@ -290,6 +314,7 @@ public void embeddingOptionsTest() { // @formatter:off "spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.embedding.options.model=MODEL_XYZ", "spring.ai.openai.embedding.options.encodingFormat=MyEncodingFormat", @@ -303,6 +328,7 @@ public void embeddingOptionsTest() { assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(embeddingProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ"); assertThat(embeddingProperties.getOptions().getEncodingFormat()).isEqualTo("MyEncodingFormat"); @@ -316,6 +342,7 @@ public void imageOptionsTest() { // @formatter:off "spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", + "spring.ai.openai.read-timeout=2m", "spring.ai.openai.image.options.n=3", "spring.ai.openai.image.options.model=MODEL_XYZ", @@ -335,6 +362,7 @@ public void imageOptionsTest() { assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); + assertThat(connectionProperties.getReadTimeout()).isEqualTo(Duration.ofMinutes(2)); assertThat(imageProperties.getOptions().getN()).isEqualTo(3); assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");