Skip to content

Commit 08ccc10

Browse files
ThomasVitaletzolov
authored andcommitted
Streamline EmbeddingOptions
* Add model and dimensions to option abstraction * Use abstraction in Observations directly instead of dedicated implementation * Clean-up the merge of runtime and default embedding options in OpenAI Relates to #gh-1148 Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent 17ba1fc commit 08ccc10

File tree

29 files changed

+353
-250
lines changed

29 files changed

+353
-250
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
* The configuration information for the embedding requests.
2424
*
2525
* @author Christian Tzolov
26+
* @author Thomas Vitale
2627
* @since 0.8.0
2728
*/
2829
public class AzureOpenAiEmbeddingOptions implements EmbeddingOptions {
@@ -123,6 +124,11 @@ public AzureOpenAiEmbeddingOptions build() {
123124

124125
}
125126

127+
@Override
128+
public String getModel() {
129+
return getDeploymentName();
130+
}
131+
126132
public String getUser() {
127133
return this.user;
128134
}
@@ -147,6 +153,7 @@ public void setInputType(String inputType) {
147153
this.inputType = inputType;
148154
}
149155

156+
@Override
150157
public Integer getDimensions() {
151158
return this.dimensions;
152159
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingOptions.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
/**
2727
* @author Christian Tzolov
28+
* @author Thomas Vitale
2829
*/
2930
@JsonInclude(Include.NON_NULL)
3031
public class BedrockCohereEmbeddingOptions implements EmbeddingOptions {
@@ -86,4 +87,14 @@ public void setTruncate(Truncate truncate) {
8687
this.truncate = truncate;
8788
}
8889

90+
@Override
91+
public String getModel() {
92+
return null;
93+
}
94+
95+
@Override
96+
public Integer getDimensions() {
97+
return null;
98+
}
99+
89100
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
/**
2626
* @author Wei Jiang
27+
* @author Thomas Vitale
2728
*/
2829
@JsonInclude(Include.NON_NULL)
2930
public class BedrockTitanEmbeddingOptions implements EmbeddingOptions {
@@ -62,4 +63,14 @@ public void setInputType(InputType inputType) {
6263
this.inputType = inputType;
6364
}
6465

66+
@Override
67+
public String getModel() {
68+
return null;
69+
}
70+
71+
@Override
72+
public Integer getDimensions() {
73+
return null;
74+
}
75+
6576
}

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingOptions.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
* This class represents the options for MiniMax embedding.
2525
*
2626
* @author Geng Rong
27+
* @author Thomas Vitale
2728
* @since 1.0.0 M1
2829
*/
2930
@JsonInclude(Include.NON_NULL)
@@ -59,6 +60,7 @@ public MiniMaxEmbeddingOptions build() {
5960

6061
}
6162

63+
@Override
6264
public String getModel() {
6365
return this.model;
6466
}
@@ -67,4 +69,9 @@ public void setModel(String model) {
6769
this.model = model;
6870
}
6971

72+
@Override
73+
public Integer getDimensions() {
74+
return null;
75+
}
76+
7077
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingOptions.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
/**
2424
* @author Ricken Bazolo
25+
* @author Thomas Vitale
2526
* @since 0.8.1
2627
*/
2728
@JsonInclude(Include.NON_NULL)
@@ -41,6 +42,7 @@ public static Builder builder() {
4142
return new Builder();
4243
}
4344

45+
@Override
4446
public String getModel() {
4547
return this.model;
4648
}
@@ -57,6 +59,11 @@ public void setEncodingFormat(String encodingFormat) {
5759
this.encodingFormat = encodingFormat;
5860
}
5961

62+
@Override
63+
public Integer getDimensions() {
64+
return null;
65+
}
66+
6067
public static class Builder {
6168

6269
protected MistralAiEmbeddingOptions options;

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ public OllamaOptions withFunction(String functionName) {
498498
// -------------------
499499
// Getters and Setters
500500
// -------------------
501+
@Override
501502
public String getModel() {
502503
return model;
503504
}
@@ -762,6 +763,11 @@ public void setTruncate(Boolean truncate) {
762763
this.truncate = truncate;
763764
}
764765

766+
@Override
767+
public Integer getDimensions() {
768+
return null;
769+
}
770+
765771
@Override
766772
public List<FunctionCallback> getFunctionCallbacks() {
767773
return this.functionCallbacks;

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
3131
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
3232
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
33-
import org.springframework.ai.embedding.observation.EmbeddingModelRequestOptions;
3433
import org.springframework.ai.model.ModelOptionsUtils;
3534
import org.springframework.ai.observation.AiOperationMetadata;
3635
import org.springframework.ai.observation.conventions.AiOperationType;
@@ -39,9 +38,9 @@
3938
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
4039
import org.springframework.ai.openai.metadata.OpenAiUsage;
4140
import org.springframework.ai.retry.RetryUtils;
41+
import org.springframework.lang.Nullable;
4242
import org.springframework.retry.support.RetryTemplate;
4343
import org.springframework.util.Assert;
44-
import org.springframework.util.StringUtils;
4544

4645
import java.util.List;
4746

@@ -126,7 +125,7 @@ public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, Open
126125
*/
127126
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options,
128127
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
129-
Assert.notNull(openAiApi, "OpenAiService must not be null");
128+
Assert.notNull(openAiApi, "openAiApi must not be null");
130129
Assert.notNull(metadataMode, "metadataMode must not be null");
131130
Assert.notNull(options, "options must not be null");
132131
Assert.notNull(retryTemplate, "retryTemplate must not be null");
@@ -147,12 +146,13 @@ public List<Double> embed(Document document) {
147146

148147
@Override
149148
public EmbeddingResponse call(EmbeddingRequest request) {
150-
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = createRequest(request);
149+
OpenAiEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions);
150+
OpenAiApi.EmbeddingRequest<List<String>> apiRequest = createRequest(request, requestOptions);
151151

152152
var observationContext = EmbeddingModelObservationContext.builder()
153153
.embeddingRequest(request)
154154
.operationMetadata(buildOperationMetadata())
155-
.requestOptions(buildRequestOptions(apiRequest))
155+
.requestOptions(requestOptions)
156156
.build();
157157

158158
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
@@ -183,21 +183,31 @@ public EmbeddingResponse call(EmbeddingRequest request) {
183183
});
184184
}
185185

186-
@SuppressWarnings("unchecked")
187-
private OpenAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request) {
188-
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = (this.defaultOptions != null)
189-
? new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
190-
this.defaultOptions.getModel(), this.defaultOptions.getEncodingFormat(),
191-
this.defaultOptions.getDimensions(), this.defaultOptions.getUser())
192-
: new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
193-
OpenAiApi.DEFAULT_EMBEDDING_MODEL);
194-
195-
if (request.getOptions() != null && !EmbeddingOptions.EMPTY.equals(request.getOptions())) {
196-
apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest,
197-
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class);
186+
private OpenAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request,
187+
OpenAiEmbeddingOptions requestOptions) {
188+
return new OpenAiApi.EmbeddingRequest<>(request.getInstructions(), requestOptions.getModel(),
189+
requestOptions.getEncodingFormat(), requestOptions.getDimensions(), requestOptions.getUser());
190+
}
191+
192+
/**
193+
* Merge runtime and default {@link EmbeddingOptions} to compute the final options to
194+
* use in the request.
195+
*/
196+
private OpenAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions,
197+
OpenAiEmbeddingOptions defaultOptions) {
198+
if (runtimeOptions == null) {
199+
return defaultOptions;
198200
}
199201

200-
return apiRequest;
202+
return OpenAiEmbeddingOptions.builder()
203+
// Handle portable embedding options
204+
.withModel(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), defaultOptions.getModel()))
205+
.withDimensions(
206+
ModelOptionsUtils.mergeOption(runtimeOptions.getDimensions(), defaultOptions.getDimensions()))
207+
// Handle OpenAI specific embedding options
208+
.withEncodingFormat(defaultOptions.getEncodingFormat())
209+
.withUser(defaultOptions.getUser())
210+
.build();
201211
}
202212

203213
private AiOperationMetadata buildOperationMetadata() {
@@ -207,14 +217,6 @@ private AiOperationMetadata buildOperationMetadata() {
207217
.build();
208218
}
209219

210-
private EmbeddingModelRequestOptions buildRequestOptions(OpenAiApi.EmbeddingRequest<List<String>> request) {
211-
return EmbeddingModelRequestOptions.builder()
212-
.model(StringUtils.hasText(request.model()) ? request.model() : "unknown")
213-
.dimensions(request.dimensions())
214-
.encodingFormat(request.encodingFormat())
215-
.build();
216-
}
217-
218220
/**
219221
* Use the provided convention for reporting observation data
220222
* @param observationConvention The provided convention

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public OpenAiEmbeddingOptions build() {
8585

8686
}
8787

88+
@Override
8889
public String getModel() {
8990
return this.model;
9091
}
@@ -101,6 +102,7 @@ public void setEncodingFormat(String encodingFormat) {
101102
this.encodingFormat = encodingFormat;
102103
}
103104

105+
@Override
104106
public Integer getDimensions() {
105107
return this.dimensions;
106108
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageOptions.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,18 @@ public void setResponseFormat(String responseFormat) {
190190

191191
@Override
192192
public Integer getWidth() {
193-
return this.width;
193+
if (this.width != null) {
194+
return this.width;
195+
}
196+
else if (this.size != null) {
197+
try {
198+
return Integer.parseInt(this.size.split("x")[0]);
199+
}
200+
catch (NumberFormatException ex) {
201+
return null;
202+
}
203+
}
204+
return null;
194205
}
195206

196207
public void setWidth(Integer width) {
@@ -200,7 +211,18 @@ public void setWidth(Integer width) {
200211

201212
@Override
202213
public Integer getHeight() {
203-
return this.height;
214+
if (this.height != null) {
215+
return this.height;
216+
}
217+
else if (this.size != null) {
218+
try {
219+
return Integer.parseInt(this.size.split("x")[1]);
220+
}
221+
catch (NumberFormatException ex) {
222+
return null;
223+
}
224+
}
225+
return null;
204226
}
205227

206228
public void setHeight(Integer height) {
@@ -230,7 +252,6 @@ public void setSize(String size) {
230252
}
231253

232254
public String getSize() {
233-
234255
if (this.size != null) {
235256
return this.size;
236257
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai;
17+
18+
import org.junit.jupiter.api.Test;
19+
20+
import static org.assertj.core.api.Assertions.assertThat;
21+
22+
/**
23+
* Unit tests for {@link OpenAiImageOptions}.
24+
*
25+
* @author Thomas Vitale
26+
*/
27+
class OpenAiImageOptionsTests {
28+
29+
@Test
30+
void whenImageDimensionsAreAllUnset() {
31+
OpenAiImageOptions options = new OpenAiImageOptions();
32+
assertThat(options.getHeight()).isEqualTo(null);
33+
assertThat(options.getWidth()).isEqualTo(null);
34+
assertThat(options.getSize()).isEqualTo(null);
35+
}
36+
37+
@Test
38+
void whenSizeIsSet() {
39+
OpenAiImageOptions options = new OpenAiImageOptions();
40+
options.setSize("1920x1080");
41+
assertThat(options.getHeight()).isEqualTo(1080);
42+
assertThat(options.getWidth()).isEqualTo(1920);
43+
assertThat(options.getSize()).isEqualTo("1920x1080");
44+
}
45+
46+
@Test
47+
void whenWidthAndHeightAreSet() {
48+
OpenAiImageOptions options = new OpenAiImageOptions();
49+
options.setWidth(1920);
50+
options.setHeight(1080);
51+
assertThat(options.getHeight()).isEqualTo(1080);
52+
assertThat(options.getWidth()).isEqualTo(1920);
53+
assertThat(options.getSize()).isEqualTo("1920x1080");
54+
}
55+
56+
@Test
57+
void whenWidthIsSet() {
58+
OpenAiImageOptions options = new OpenAiImageOptions();
59+
options.setWidth(1920);
60+
assertThat(options.getHeight()).isEqualTo(null);
61+
assertThat(options.getWidth()).isEqualTo(1920);
62+
// This is because "setWidth()" computes "size" without checking for null values.
63+
assertThat(options.getSize()).isEqualTo("1920xnull");
64+
}
65+
66+
@Test
67+
void whenHeightIsSet() {
68+
OpenAiImageOptions options = new OpenAiImageOptions();
69+
options.setHeight(1080);
70+
assertThat(options.getHeight()).isEqualTo(1080);
71+
assertThat(options.getWidth()).isEqualTo(null);
72+
// This is because "setHeight()" computes "size" without checking for null values.
73+
assertThat(options.getSize()).isEqualTo("nullx1080");
74+
}
75+
76+
}

0 commit comments

Comments
 (0)