Skip to content

Commit 41a459a

Browse files
ThomasVitaletzolov
authored andcommitted
Add OpenAiImageClient auto-configuration
Fixes gh-289 Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent 6cf2f86 commit 41a459a

File tree

4 files changed

+176
-3
lines changed

4 files changed

+176
-3
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,7 +20,9 @@
2020
import org.springframework.ai.embedding.EmbeddingClient;
2121
import org.springframework.ai.openai.OpenAiChatClient;
2222
import org.springframework.ai.openai.OpenAiEmbeddingClient;
23+
import org.springframework.ai.openai.OpenAiImageClient;
2324
import org.springframework.ai.openai.api.OpenAiApi;
25+
import org.springframework.ai.openai.api.OpenAiImageApi;
2426
import org.springframework.boot.autoconfigure.AutoConfiguration;
2527
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
2628
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -34,7 +36,7 @@
3436
@AutoConfiguration
3537
@ConditionalOnClass(OpenAiApi.class)
3638
@EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class,
37-
OpenAiEmbeddingProperties.class })
39+
OpenAiEmbeddingProperties.class, OpenAiImageProperties.class })
3840
@ImportRuntimeHints(NativeHints.class)
3941
public class OpenAiAutoConfiguration {
4042

@@ -78,4 +80,22 @@ public EmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties commonPr
7880
return new OpenAiEmbeddingClient(openAiApi).withDefaultOptions(embeddingProperties.getOptions());
7981
}
8082

83+
@Bean
84+
@ConditionalOnMissingBean
85+
public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProperties,
86+
OpenAiImageProperties imageProperties) {
87+
String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey()
88+
: commonProperties.getApiKey();
89+
90+
String baseUrl = StringUtils.hasText(imageProperties.getBaseUrl()) ? imageProperties.getBaseUrl()
91+
: commonProperties.getBaseUrl();
92+
93+
Assert.hasText(apiKey, "OpenAI API key must be set");
94+
Assert.hasText(baseUrl, "OpenAI base URL must be set");
95+
96+
var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, RestClient.builder());
97+
98+
return new OpenAiImageClient(openAiImageApi).withDefaultOptions(imageProperties.getOptions());
99+
}
100+
81101
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.autoconfigure.openai;
18+
19+
import org.springframework.ai.openai.OpenAiImageOptions;
20+
import org.springframework.boot.context.properties.ConfigurationProperties;
21+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
22+
23+
/**
24+
* OpenAI Image autoconfiguration properties.
25+
*
26+
* @author Thomas Vitale
27+
* @since 0.8.0
28+
*/
29+
@ConfigurationProperties(OpenAiImageProperties.CONFIG_PREFIX)
30+
public class OpenAiImageProperties extends OpenAiParentProperties {
31+
32+
public static final String CONFIG_PREFIX = "spring.ai.openai.image";
33+
34+
/**
35+
* Options for OpenAI Image API.
36+
*/
37+
@NestedConfigurationProperty
38+
private OpenAiImageOptions options = OpenAiImageOptions.builder().build();
39+
40+
public OpenAiImageOptions getOptions() {
41+
return options;
42+
}
43+
44+
public void setOptions(OpenAiImageOptions options) {
45+
this.options = options;
46+
}
47+
48+
}

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2023 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,6 +25,9 @@
2525
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2626
import org.springframework.ai.chat.messages.UserMessage;
2727
import org.springframework.ai.chat.prompt.Prompt;
28+
import org.springframework.ai.image.ImagePrompt;
29+
import org.springframework.ai.image.ImageResponse;
30+
import org.springframework.ai.openai.OpenAiImageClient;
2831
import reactor.core.publisher.Flux;
2932

3033
import org.springframework.ai.chat.ChatResponse;
@@ -86,4 +89,15 @@ void embedding() {
8689
});
8790
}
8891

92+
@Test
93+
void generateImage() {
94+
contextRunner.withPropertyValues("spring.ai.openai.image.options.size=256x256").run(context -> {
95+
OpenAiImageClient client = context.getBean(OpenAiImageClient.class);
96+
ImageResponse imageResponse = client.call(new ImagePrompt("forest"));
97+
assertThat(imageResponse.getResults()).hasSize(1);
98+
assertThat(imageResponse.getResult().getOutput().getUrl()).isNotEmpty();
99+
logger.info("Generated image: " + imageResponse.getResult().getOutput().getUrl());
100+
});
101+
}
102+
89103
}

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
* {@link OpenAiEmbeddingProperties}.
3434
*
3535
* @author Christian Tzolov
36+
* @author Thomas Vitale
3637
* @since 0.8.0
3738
*/
3839
public class OpenAiPropertiesTests {
@@ -141,6 +142,58 @@ public void embeddingOverrideConnectionProperties() {
141142
});
142143
}
143144

145+
@Test
146+
public void imageProperties() {
147+
new ApplicationContextRunner().withPropertyValues(
148+
// @formatter:off
149+
"spring.ai.openai.base-url=TEST_BASE_URL",
150+
"spring.ai.openai.api-key=abc123",
151+
"spring.ai.openai.image.options.model=MODEL_XYZ",
152+
"spring.ai.openai.image.options.n=3")
153+
// @formatter:on
154+
.withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class))
155+
.run(context -> {
156+
var imageProperties = context.getBean(OpenAiImageProperties.class);
157+
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
158+
159+
assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
160+
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
161+
162+
assertThat(imageProperties.getApiKey()).isNull();
163+
assertThat(imageProperties.getBaseUrl()).isNull();
164+
165+
assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
166+
assertThat(imageProperties.getOptions().getN()).isEqualTo(3);
167+
});
168+
}
169+
170+
@Test
171+
public void imageOverrideConnectionProperties() {
172+
new ApplicationContextRunner().withPropertyValues(
173+
// @formatter:off
174+
"spring.ai.openai.base-url=TEST_BASE_URL",
175+
"spring.ai.openai.api-key=abc123",
176+
"spring.ai.openai.image.base-url=TEST_BASE_URL2",
177+
"spring.ai.openai.image.api-key=456",
178+
"spring.ai.openai.image.options.model=MODEL_XYZ",
179+
"spring.ai.openai.image.options.n=3")
180+
// @formatter:on
181+
.withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class))
182+
.run(context -> {
183+
var imageProperties = context.getBean(OpenAiImageProperties.class);
184+
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
185+
186+
assertThat(connectionProperties.getApiKey()).isEqualTo("abc123");
187+
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
188+
189+
assertThat(imageProperties.getApiKey()).isEqualTo("456");
190+
assertThat(imageProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL2");
191+
192+
assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
193+
assertThat(imageProperties.getOptions().getN()).isEqualTo(3);
194+
});
195+
}
196+
144197
@Test
145198
public void chatOptionsTest() {
146199

@@ -256,4 +309,42 @@ public void embeddingOptionsTest() {
256309
});
257310
}
258311

312+
@Test
313+
public void imageOptionsTest() {
314+
new ApplicationContextRunner().withPropertyValues(
315+
// @formatter:off
316+
"spring.ai.openai.api-key=API_KEY",
317+
"spring.ai.openai.base-url=TEST_BASE_URL",
318+
319+
"spring.ai.openai.image.options.n=3",
320+
"spring.ai.openai.image.options.model=MODEL_XYZ",
321+
"spring.ai.openai.image.options.quality=hd",
322+
"spring.ai.openai.image.options.response_format=url",
323+
"spring.ai.openai.image.options.size=1024x1024",
324+
"spring.ai.openai.image.options.width=1024",
325+
"spring.ai.openai.image.options.height=1024",
326+
"spring.ai.openai.image.options.style=vivid",
327+
"spring.ai.openai.image.options.user=userXYZ"
328+
)
329+
// @formatter:on
330+
.withConfiguration(AutoConfigurations.of(OpenAiAutoConfiguration.class))
331+
.run(context -> {
332+
var imageProperties = context.getBean(OpenAiImageProperties.class);
333+
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
334+
335+
assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL");
336+
assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY");
337+
338+
assertThat(imageProperties.getOptions().getN()).isEqualTo(3);
339+
assertThat(imageProperties.getOptions().getModel()).isEqualTo("MODEL_XYZ");
340+
assertThat(imageProperties.getOptions().getQuality()).isEqualTo("hd");
341+
assertThat(imageProperties.getOptions().getResponseFormat()).isEqualTo("url");
342+
assertThat(imageProperties.getOptions().getSize()).isEqualTo("1024x1024");
343+
assertThat(imageProperties.getOptions().getWidth()).isEqualTo(1024);
344+
assertThat(imageProperties.getOptions().getHeight()).isEqualTo(1024);
345+
assertThat(imageProperties.getOptions().getStyle()).isEqualTo("vivid");
346+
assertThat(imageProperties.getOptions().getUser()).isEqualTo("userXYZ");
347+
});
348+
}
349+
259350
}

0 commit comments

Comments
 (0)