Skip to content

Commit d538e00

Browse files
tzolovmarkpollack
authored andcommitted
Replace the Embedding format from List<Double> to float[]
- Adjust all affected classes including the Document. - Update docs. Related to #405
1 parent 656fa8b commit d538e00

File tree

67 files changed

+442
-412
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+442
-412
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,21 @@
2929
import org.springframework.ai.embedding.EmbeddingRequest;
3030
import org.springframework.ai.embedding.EmbeddingResponse;
3131
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
32+
import org.springframework.ai.model.EmbeddingUtils;
3233
import org.springframework.util.Assert;
34+
import org.springframework.util.CollectionUtils;
3335

3436
import java.util.ArrayList;
3537
import java.util.List;
3638

39+
/**
40+
* Azure Open AI Embedding Model implementation.
41+
*
42+
* @author Mark Pollack
43+
* @author Christian Tzolov
44+
* @author Thomas Vitale
45+
* @since 1.0.0
46+
*/
3747
public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {
3848

3949
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class);
@@ -64,13 +74,17 @@ public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode me
6474
}
6575

6676
@Override
67-
public List<Double> embed(Document document) {
77+
public float[] embed(Document document) {
6878
logger.debug("Retrieving embeddings");
6979

7080
EmbeddingResponse response = this
7181
.call(new EmbeddingRequest(List.of(document.getFormattedContent(this.metadataMode)), null));
7282
logger.debug("Embeddings retrieved");
73-
return response.getResults().stream().map(embedding -> embedding.getOutput()).flatMap(List::stream).toList();
83+
84+
if (CollectionUtils.isEmpty(response.getResults())) {
85+
return new float[0];
86+
}
87+
return response.getResults().get(0).getOutput();
7488
}
7589

7690
@Override
@@ -108,8 +122,7 @@ private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
108122
for (EmbeddingItem nativeDatum : nativeData) {
109123
List<Float> nativeDatumEmbedding = nativeDatum.getEmbedding();
110124
int nativeIndex = nativeDatum.getPromptIndex();
111-
Embedding embedding = new Embedding(nativeDatumEmbedding.stream().map(f -> f.doubleValue()).toList(),
112-
nativeIndex);
125+
Embedding embedding = new Embedding(EmbeddingUtils.toPrimitive(nativeDatumEmbedding), nativeIndex);
113126
data.add(embedding);
114127
}
115128
return data;

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -66,33 +66,8 @@ public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedr
6666
this.defaultOptions = options;
6767
}
6868

69-
// /**
70-
// * Cohere Embedding API input types.
71-
// * @param inputType the input type to use.
72-
// * @return this client.
73-
// */
74-
// public BedrockCohereEmbeddingModel withInputType(CohereEmbeddingRequest.InputType
75-
// inputType) {
76-
// this.inputType = inputType;
77-
// return this;
78-
// }
79-
80-
// /**
81-
// * Specifies how the API handles inputs longer than the maximum token length. If you
82-
// specify LEFT or RIGHT, the
83-
// * model discards the input until the remaining input is exactly the maximum input
84-
// token length for the model.
85-
// * @param truncate the truncate option to use.
86-
// * @return this client.
87-
// */
88-
// public BedrockCohereEmbeddingModel withTruncate(CohereEmbeddingRequest.Truncate
89-
// truncate) {
90-
// this.truncate = truncate;
91-
// return this;
92-
// }
93-
9469
@Override
95-
public List<Double> embed(Document document) {
70+
public float[] embed(Document document) {
9671
return embed(document.getContent());
9772
}
9873

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ public enum Truncate {
183183
@JsonInclude(Include.NON_NULL)
184184
public record CohereEmbeddingResponse(
185185
@JsonProperty("id") String id,
186-
@JsonProperty("embeddings") List<List<Double>> embeddings,
186+
@JsonProperty("embeddings") List<float[]> embeddings,
187187
@JsonProperty("texts") List<String> texts,
188188
@JsonProperty("response_type") String responseType,
189189
// For future use: Currently bedrock doesn't return invocationMetrics for the cohere embedding model.

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public BedrockTitanEmbeddingModel withInputType(InputType inputType) {
7575
}
7676

7777
@Override
78-
public List<Double> embed(Document document) {
78+
public float[] embed(Document document) {
7979
return embed(document.getContent());
8080
}
8181

@@ -87,16 +87,13 @@ public EmbeddingResponse call(EmbeddingRequest request) {
8787
"Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
8888
}
8989

90-
List<List<Double>> embeddingList = new ArrayList<>();
90+
List<Embedding> embeddings = new ArrayList<>();
91+
var indexCounter = new AtomicInteger(0);
9192
for (String inputContent : request.getInstructions()) {
9293
var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions());
9394
TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest);
94-
embeddingList.add(response.embedding());
95+
embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement()));
9596
}
96-
var indexCounter = new AtomicInteger(0);
97-
List<Embedding> embeddings = embeddingList.stream()
98-
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
99-
.toList();
10097
return new EmbeddingResponse(embeddings);
10198
}
10299

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public TitanEmbeddingRequest build() {
137137
*/
138138
@JsonInclude(Include.NON_NULL)
139139
public record TitanEmbeddingResponse(
140-
@JsonProperty("embedding") List<Double> embedding,
140+
@JsonProperty("embedding") float[] embedding,
141141
@JsonProperty("inputTextTokenCount") Integer inputTextTokenCount,
142142
@JsonProperty("message") Object message) {
143143
}

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode, M
105105
}
106106

107107
@Override
108-
public List<Double> embed(Document document) {
108+
public float[] embed(Document document) {
109109
Assert.notNull(document, "Document must not be null");
110110
return this.embed(document.getFormattedContent(this.metadataMode));
111111
}
@@ -137,7 +137,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
137137

138138
List<Embedding> embeddings = new ArrayList<>();
139139
for (int i = 0; i < apiEmbeddingResponse.vectors().size(); i++) {
140-
List<Double> vector = apiEmbeddingResponse.vectors().get(i);
140+
float[] vector = apiEmbeddingResponse.vectors().get(i);
141141
embeddings.add(new Embedding(vector, i));
142142
}
143143
return new EmbeddingResponse(embeddings, metadata);

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ public EmbeddingRequest(List<String> texts, EmbeddingType type) {
865865
*/
866866
@JsonInclude(Include.NON_NULL)
867867
public record EmbeddingList(
868-
@JsonProperty("vectors") List<List<Double>> vectors,
868+
@JsonProperty("vectors") List<float[]> vectors,
869869
@JsonProperty("model") String model,
870870
@JsonProperty("total_tokens") Integer totalTokens) {
871871
}

models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ public void miniMaxChatStreamNonTransientError() {
157157
@Test
158158
public void miniMaxEmbeddingTransientError() {
159159

160-
EmbeddingList expectedEmbeddings = new EmbeddingList(List.of(List.of(9.9, 8.8)), "model", 10);
160+
EmbeddingList expectedEmbeddings = new EmbeddingList(List.of(new float[] { 9.9f, 8.8f }), "model", 10);
161161

162162
when(miniMaxApi.embeddings(isA(EmbeddingRequest.class)))
163163
.thenThrow(new TransientAiException("Transient Error 1"))
@@ -168,7 +168,7 @@ public void miniMaxEmbeddingTransientError() {
168168
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
169169

170170
assertThat(result).isNotNull();
171-
assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8));
171+
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
172172
assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
173173
assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
174174
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
116116
}
117117

118118
@Override
119-
public List<Double> embed(Document document) {
119+
public float[] embed(Document document) {
120120
Assert.notNull(document, "Document must not be null");
121121
return this.embed(document.getFormattedContent(this.metadataMode));
122122
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ public record Usage(
196196
public record Embedding(
197197
// @formatter:off
198198
@JsonProperty("index") Integer index,
199-
@JsonProperty("embedding") List<Double> embedding,
199+
@JsonProperty("embedding") float[] embedding,
200200
@JsonProperty("object") String object) {
201201
// @formatter:on
202202

@@ -207,7 +207,7 @@ public record Embedding(
207207
* @param embedding The embedding vector, which is a list of floats. The length of
208208
* vector depends on the model.
209209
*/
210-
public Embedding(Integer index, List<Double> embedding) {
210+
public Embedding(Integer index, float[] embedding) {
211211
this(index, embedding, "embedding");
212212
}
213213
}

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
*/
1616
package org.springframework.ai.mistralai;
1717

18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.junit.jupiter.api.Assertions.assertThrows;
20+
import static org.mockito.ArgumentMatchers.isA;
21+
import static org.mockito.Mockito.when;
22+
1823
import java.util.List;
1924
import java.util.Optional;
2025

@@ -23,8 +28,6 @@
2328
import org.junit.jupiter.api.extension.ExtendWith;
2429
import org.mockito.Mock;
2530
import org.mockito.junit.jupiter.MockitoExtension;
26-
import reactor.core.publisher.Flux;
27-
2831
import org.springframework.ai.chat.prompt.Prompt;
2932
import org.springframework.ai.document.MetadataMode;
3033
import org.springframework.ai.mistralai.api.MistralAiApi;
@@ -45,10 +48,7 @@
4548
import org.springframework.retry.RetryListener;
4649
import org.springframework.retry.support.RetryTemplate;
4750

48-
import static org.assertj.core.api.Assertions.assertThat;
49-
import static org.junit.jupiter.api.Assertions.assertThrows;
50-
import static org.mockito.ArgumentMatchers.isA;
51-
import static org.mockito.Mockito.when;
51+
import reactor.core.publisher.Flux;
5252

5353
/**
5454
* @author Christian Tzolov
@@ -166,7 +166,7 @@ public void mistralAiChatStreamNonTransientError() {
166166
public void mistralAiEmbeddingTransientError() {
167167

168168
EmbeddingList<Embedding> expectedEmbeddings = new EmbeddingList<>("list",
169-
List.of(new Embedding(0, List.of(9.9, 8.8))), "model", new MistralAiApi.Usage(10, 10, 10));
169+
List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new MistralAiApi.Usage(10, 10, 10));
170170

171171
when(mistralAiApi.embeddings(isA(EmbeddingRequest.class)))
172172
.thenThrow(new TransientAiException("Transient Error 1"))
@@ -177,7 +177,7 @@ public void mistralAiEmbeddingTransientError() {
177177
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
178178

179179
assertThat(result).isNotNull();
180-
assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8));
180+
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
181181
assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
182182
assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
183183
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
7676
}
7777

7878
@Override
79-
public List<Double> embed(Document document) {
79+
public float[] embed(Document document) {
8080
return embed(document.getContent());
8181
}
8282

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ public EmbeddingRequest(String model, String prompt) {
751751
@Deprecated(since = "1.0.0-M2", forRemoval = true)
752752
@JsonInclude(Include.NON_NULL)
753753
public record EmbeddingResponse(
754-
@JsonProperty("embedding") List<Double> embedding) {
754+
@JsonProperty("embedding") List<Float> embedding) {
755755
}
756756

757757

@@ -764,7 +764,7 @@ public record EmbeddingResponse(
764764
@JsonInclude(Include.NON_NULL)
765765
public record EmbeddingsResponse(
766766
@JsonProperty("model") String model,
767-
@JsonProperty("embeddings") List<List<Double>> embeddings) {
767+
@JsonProperty("embeddings") List<float[]> embeddings) {
768768
}
769769

770770
/**

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ public void options() {
5555

5656
when(ollamaApi.embed(embeddingsRequestCaptor.capture()))
5757
.thenReturn(
58-
new EmbeddingsResponse("RESPONSE_MODEL_NAME", List.of(List.of(1d, 2d, 3d), List.of(4d, 5d, 6d))))
58+
new EmbeddingsResponse("RESPONSE_MODEL_NAME", List.of(new float[]{1f, 2f, 3f}, new float[]{4f, 5f, 6f})))
5959
.thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME2",
60-
List.of(List.of(7d, 8d, 9d), List.of(10d, 11d, 12d))));
60+
List.of(new float[]{7f, 8f, 9f}, new float[]{10f, 11f, 12f})));
6161

6262
// Tests default options
6363
var defaultOptions = OllamaOptions.builder().withModel("DEFAULT_MODEL").build();
@@ -69,10 +69,10 @@ public void options() {
6969

7070
assertThat(response.getResults()).hasSize(2);
7171
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
72-
assertThat(response.getResults().get(0).getOutput()).isEqualTo(List.of(1d, 2d, 3d));
72+
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[]{1f, 2f, 3f});
7373
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
7474
assertThat(response.getResults().get(1).getIndex()).isEqualTo(1);
75-
assertThat(response.getResults().get(1).getOutput()).isEqualTo(List.of(4d, 5d, 6d));
75+
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[]{4f, 5f, 6f});
7676
assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
7777
assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME");
7878

@@ -94,10 +94,10 @@ public void options() {
9494

9595
assertThat(response.getResults()).hasSize(2);
9696
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
97-
assertThat(response.getResults().get(0).getOutput()).isEqualTo(List.of(7d, 8d, 9d));
97+
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[]{7f, 8f, 9f});
9898
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
9999
assertThat(response.getResults().get(1).getIndex()).isEqualTo(1);
100-
assertThat(response.getResults().get(1).getOutput()).isEqualTo(List.of(10d, 11d, 12d));
100+
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[]{10f, 11f, 12f});
101101
assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
102102
assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME2");
103103

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, Open
137137
}
138138

139139
@Override
140-
public List<Double> embed(Document document) {
140+
public float[] embed(Document document) {
141141
Assert.notNull(document, "Document must not be null");
142142
return this.embed(document.getFormattedContent(this.metadataMode));
143143
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,7 @@ public String getValue() {
11021102
@JsonInclude(Include.NON_NULL)
11031103
public record Embedding(// @formatter:off
11041104
@JsonProperty("index") Integer index,
1105-
@JsonProperty("embedding") List<Double> embedding,
1105+
@JsonProperty("embedding") float[] embedding,
11061106
@JsonProperty("object") String object) {// @formatter:on
11071107

11081108
/**
@@ -1112,7 +1112,7 @@ public record Embedding(// @formatter:off
11121112
* @param embedding The embedding vector, which is a list of floats. The length of
11131113
* vector depends on the model.
11141114
*/
1115-
public Embedding(Integer index, List<Double> embedding) {
1115+
public Embedding(Integer index, float[] embedding) {
11161116
this(index, embedding, "embedding");
11171117
}
11181118
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ public void openAiChatStreamNonTransientError() {
197197
public void openAiEmbeddingTransientError() {
198198

199199
EmbeddingList<Embedding> expectedEmbeddings = new EmbeddingList<>("list",
200-
List.of(new Embedding(0, List.of(9.9, 8.8))), "model", new OpenAiApi.Usage(10, 10, 10));
200+
List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new OpenAiApi.Usage(10, 10, 10));
201201

202202
when(openAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new TransientAiException("Transient Error 1"))
203203
.thenThrow(new TransientAiException("Transient Error 2"))
@@ -207,7 +207,7 @@ public void openAiEmbeddingTransientError() {
207207
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
208208

209209
assertThat(result).isNotNull();
210-
assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8));
210+
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
211211
assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
212212
assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
213213
}

0 commit comments

Comments
 (0)