Skip to content

Commit b49458b

Browse files
committed
Improve OpenAiImage/Client implementation
- Impove the code style. - Merge the OpenAiImageOptions interface impl and builder into one class. - Improve the OpenAiImageClient call implementation. Resolves #274
1 parent 9b79359 commit b49458b

File tree

8 files changed

+366
-326
lines changed

8 files changed

+366
-326
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,22 @@
1616

1717
package org.springframework.ai.openai;
1818

19+
import java.time.Duration;
20+
import java.util.List;
21+
1922
import org.slf4j.Logger;
2023
import org.slf4j.LoggerFactory;
21-
import org.springframework.ai.image.*;
24+
25+
import org.springframework.ai.image.Image;
26+
import org.springframework.ai.image.ImageClient;
27+
import org.springframework.ai.image.ImageGeneration;
28+
import org.springframework.ai.image.ImageOptions;
29+
import org.springframework.ai.image.ImagePrompt;
30+
import org.springframework.ai.image.ImageResponse;
31+
import org.springframework.ai.image.ImageResponseMetadata;
2232
import org.springframework.ai.model.ModelOptionsUtils;
23-
import org.springframework.ai.openai.api.*;
33+
import org.springframework.ai.openai.api.OpenAiApi;
34+
import org.springframework.ai.openai.api.OpenAiImageApi;
2435
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
2536
import org.springframework.ai.openai.metadata.OpenAiImageResponseMetadata;
2637
import org.springframework.http.ResponseEntity;
@@ -30,14 +41,19 @@
3041
import org.springframework.retry.support.RetryTemplate;
3142
import org.springframework.util.Assert;
3243

33-
import java.time.Duration;
34-
import java.util.List;
35-
44+
/**
45+
* OpenAiImageClient is a class that implements the ImageClient interface. It provides a
46+
* client for calling the OpenAI image generation API.
47+
*
48+
* @author Mark Pollack
49+
* @author Christian Tzolov
50+
* @since 0.8.0
51+
*/
3652
public class OpenAiImageClient implements ImageClient {
3753

3854
private final Logger logger = LoggerFactory.getLogger(getClass());
3955

40-
private OpenAiImageOptions options;
56+
private OpenAiImageOptions defaultOptions;
4157

4258
private final OpenAiImageApi openAiImageApi;
4359

@@ -58,45 +74,40 @@ public OpenAiImageClient(OpenAiImageApi openAiImageApi) {
5874
this.openAiImageApi = openAiImageApi;
5975
}
6076

61-
public OpenAiImageOptions getOptions() {
62-
return options;
77+
public OpenAiImageOptions getDefaultOptions() {
78+
return this.defaultOptions;
79+
}
80+
81+
public OpenAiImageClient withDefaultOptions(OpenAiImageOptions defaultOptions) {
82+
this.defaultOptions = defaultOptions;
83+
return this;
6384
}
6485

6586
@Override
6687
public ImageResponse call(ImagePrompt imagePrompt) {
6788
return this.retryTemplate.execute(ctx -> {
68-
ImageOptions runtimeOptions = imagePrompt.getOptions();
69-
OpenAiImageOptions imageOptionsToUse = updateImageOptions(imagePrompt.getOptions());
70-
71-
// Merge the runtime options passed via the prompt with the
72-
// StabilityAiImageClient
73-
// options configured via Autoconfiguration.
74-
// Runtime options overwrite StabilityAiImageClient options
75-
OpenAiImageOptions optionsToUse = ModelOptionsUtils.merge(runtimeOptions, this.options,
76-
OpenAiImageOptionsImpl.class);
77-
78-
// Copy the org.springframework.ai.model derived ImagePrompt and ImageOptions
79-
// data
80-
// types to the data types used in OpenAiImageApi
89+
8190
String instructions = imagePrompt.getInstructions().get(0).getText();
82-
String size;
83-
if (imageOptionsToUse.getWidth() != null && imageOptionsToUse.getHeight() != null) {
84-
size = imageOptionsToUse.getWidth() + "x" + imageOptionsToUse.getHeight();
91+
92+
OpenAiImageApi.OpenAiImageRequest imageRequest = new OpenAiImageApi.OpenAiImageRequest(instructions,
93+
OpenAiImageApi.DEFAULT_IMAGE_MODEL);
94+
95+
if (this.defaultOptions != null) {
96+
imageRequest = ModelOptionsUtils.merge(this.defaultOptions, imageRequest,
97+
OpenAiImageApi.OpenAiImageRequest.class);
8598
}
86-
else {
87-
size = null;
99+
100+
if (imagePrompt.getOptions() != null) {
101+
imageRequest = ModelOptionsUtils.merge(toOpenAiImageOptions(imagePrompt.getOptions()), imageRequest,
102+
OpenAiImageApi.OpenAiImageRequest.class);
88103
}
89-
OpenAiImageApi.OpenAiImageRequest openAiImageRequest = new OpenAiImageApi.OpenAiImageRequest(instructions,
90-
imageOptionsToUse.getModel(), imageOptionsToUse.getN(), imageOptionsToUse.getQuality(), size,
91-
imageOptionsToUse.getResponseFormat(), imageOptionsToUse.getStyle(), imageOptionsToUse.getUser());
92104

93105
// Make the request
94106
ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity = this.openAiImageApi
95-
.createImage(openAiImageRequest);
107+
.createImage(imageRequest);
96108

97109
// Convert to org.springframework.ai.model derived ImageResponse data type
98-
return convertResponse(imageResponseEntity, openAiImageRequest);
99-
110+
return convertResponse(imageResponseEntity, imageRequest);
100111
});
101112
}
102113

@@ -117,8 +128,13 @@ private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageR
117128
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
118129
}
119130

120-
private OpenAiImageOptions updateImageOptions(ImageOptions runtimeImageOptions) {
121-
OpenAiImageOptionsBuilder openAiImageOptionsBuilder = OpenAiImageOptionsBuilder.builder();
131+
/**
132+
* Convert the {@link ImageOptions} into {@link OpenAiImageOptions}.
133+
* @param defaultOptions the image options to use.
134+
* @return the converted {@link OpenAiImageOptions}.
135+
*/
136+
private OpenAiImageOptions toOpenAiImageOptions(ImageOptions runtimeImageOptions) {
137+
OpenAiImageOptions.Builder openAiImageOptionsBuilder = OpenAiImageOptions.builder();
122138
if (runtimeImageOptions != null) {
123139
// Handle portable image options
124140
if (runtimeImageOptions.getN() != null) {
@@ -150,8 +166,7 @@ private OpenAiImageOptions updateImageOptions(ImageOptions runtimeImageOptions)
150166
}
151167
}
152168
}
153-
OpenAiImageOptions updatedOpenAiImageOptions = openAiImageOptionsBuilder.build();
154-
return updatedOpenAiImageOptions;
169+
return openAiImageOptionsBuilder.build();
155170
}
156171

157172
}
Lines changed: 229 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,238 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
116
package org.springframework.ai.openai;
217

18+
import com.fasterxml.jackson.annotation.JsonInclude;
19+
import com.fasterxml.jackson.annotation.JsonProperty;
20+
321
import org.springframework.ai.image.ImageOptions;
22+
import org.springframework.ai.openai.api.OpenAiImageApi;
23+
24+
/**
25+
* OpenAI Image API options. OpenAiImageOptions.java
26+
*
27+
* @author Mark Pollack
28+
* @author Christian Tzolov
29+
* @since 0.8.0
30+
*/
31+
@JsonInclude(JsonInclude.Include.NON_NULL)
32+
public class OpenAiImageOptions implements ImageOptions {
33+
34+
/**
35+
* The number of images to generate. Must be between 1 and 10. For dall-e-3, only n=1
36+
* is supported.
37+
*/
38+
@JsonProperty("n")
39+
private Integer n;
40+
41+
/**
42+
* The model to use for image generation.
43+
*/
44+
@JsonProperty("model")
45+
private String model = OpenAiImageApi.DEFAULT_IMAGE_MODEL;
46+
47+
/**
48+
* The quality of the image that will be generated. hd creates images with finer
49+
* details and greater consistency across the image. This param is only supported for
50+
* dall-e-3.
51+
*/
52+
@JsonProperty("quality")
53+
private String quality;
54+
55+
/**
56+
* The format in which the generated images are returned. Must be one of url or
57+
* b64_json.
58+
*/
59+
@JsonProperty("response_format")
60+
private String responseFormat;
61+
62+
/**
63+
* The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024 for
64+
* dall-e-2. Must be one of 1024x1024, 1792x1024, or 1024x1792 for dall-e-3 models.
65+
*/
66+
@JsonProperty("size")
67+
private String size;
68+
69+
/**
70+
* The width of the generated images. Must be one of 256, 512, or 1024 for dall-e-2.
71+
*/
72+
@JsonProperty("size_width")
73+
private Integer width;
74+
75+
/**
76+
* The height of the generated images. Must be one of 256, 512, or 1024 for dall-e-2.
77+
*/
78+
@JsonProperty("size_height")
79+
private Integer height;
80+
81+
/**
82+
* The style of the generated images. Must be one of vivid or natural. Vivid causes
83+
* the model to lean towards generating hyper-real and dramatic images. Natural causes
84+
* the model to produce more natural, less hyper-real looking images. This param is
85+
* only supported for dall-e-3.
86+
*/
87+
@JsonProperty("style")
88+
private String style;
89+
90+
/**
91+
* A unique identifier representing your end-user, which can help OpenAI to monitor
92+
* and detect abuse.
93+
*/
94+
@JsonProperty("user")
95+
private String user;
96+
97+
public static Builder builder() {
98+
return new Builder();
99+
}
100+
101+
public static class Builder {
102+
103+
private final OpenAiImageOptions options;
104+
105+
private Builder() {
106+
this.options = new OpenAiImageOptions();
107+
}
108+
109+
public Builder withN(Integer n) {
110+
options.setN(n);
111+
return this;
112+
}
113+
114+
public Builder withModel(String model) {
115+
options.setModel(model);
116+
return this;
117+
}
118+
119+
public Builder withQuality(String quality) {
120+
options.setQuality(quality);
121+
return this;
122+
}
123+
124+
public Builder withResponseFormat(String responseFormat) {
125+
options.setResponseFormat(responseFormat);
126+
return this;
127+
}
128+
129+
public Builder withWidth(Integer width) {
130+
options.setWidth(width);
131+
return this;
132+
}
133+
134+
public Builder withHeight(Integer height) {
135+
options.setHeight(height);
136+
return this;
137+
}
138+
139+
public Builder withStyle(String style) {
140+
options.setStyle(style);
141+
return this;
142+
}
143+
144+
public Builder withUser(String user) {
145+
options.setUser(user);
146+
return this;
147+
}
148+
149+
public OpenAiImageOptions build() {
150+
return options;
151+
}
152+
153+
}
154+
155+
@Override
156+
public Integer getN() {
157+
return this.n;
158+
}
159+
160+
public void setN(Integer n) {
161+
this.n = n;
162+
}
163+
164+
@Override
165+
public String getModel() {
166+
return this.model;
167+
}
168+
169+
public void setModel(String model) {
170+
this.model = model;
171+
}
172+
173+
public String getQuality() {
174+
return this.quality;
175+
}
176+
177+
public void setQuality(String quality) {
178+
this.quality = quality;
179+
}
180+
181+
@Override
182+
public String getResponseFormat() {
183+
return responseFormat;
184+
}
185+
186+
public void setResponseFormat(String responseFormat) {
187+
this.responseFormat = responseFormat;
188+
}
189+
190+
@Override
191+
public Integer getWidth() {
192+
return this.width;
193+
}
194+
195+
public void setWidth(Integer width) {
196+
this.width = width;
197+
this.size = this.width + "x" + this.height;
198+
}
199+
200+
@Override
201+
public Integer getHeight() {
202+
return this.height;
203+
}
204+
205+
public void setHeight(Integer height) {
206+
this.height = height;
207+
this.size = this.width + "x" + this.height;
208+
}
209+
210+
public String getStyle() {
211+
return this.style;
212+
}
213+
214+
public void setStyle(String style) {
215+
this.style = style;
216+
}
217+
218+
public String getUser() {
219+
return this.user;
220+
}
4221

5-
public interface OpenAiImageOptions extends ImageOptions {
222+
public void setUser(String user) {
223+
this.user = user;
224+
}
6225

7-
String getQuality();
226+
public void setSize(String size) {
227+
this.size = size;
228+
}
8229

9-
String getStyle();
230+
public String getSize() {
10231

11-
String getUser();
232+
if (this.size != null) {
233+
return this.size;
234+
}
235+
return (this.width != null && this.height != null) ? this.width + "x" + this.height : null;
236+
}
12237

13238
}

0 commit comments

Comments
 (0)