Skip to content

Commit 1abfd9a

Browse files
sobychackomarkpollack
authored andcommitted
Add truncation support for Cohere embeddings
Fixes: #1753 #1753 - Add character-based truncation (max 2048 chars) for Cohere embedding requests - Support both START and END truncation strategies - Add unit tests verifying truncation behavior for both strategies Truncation is applied before sending requests to Bedrock API to avoid ValidationException when text exceeds maximum length. The END strategy (default) keeps the first 2048 characters while START keeps the last 2048 characters.
1 parent 551206f commit 1abfd9a

File tree

3 files changed

+133
-4
lines changed

3 files changed

+133
-4
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.List;
2020
import java.util.concurrent.atomic.AtomicInteger;
21+
import java.util.stream.Collectors;
2122

2223
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
2324
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest;
@@ -37,10 +38,13 @@
3738
* this API. If this change in the future we will add it as metadata.
3839
*
3940
* @author Christian Tzolov
41+
* @author Soby Chacko
4042
* @since 0.8.0
4143
*/
4244
public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel {
4345

46+
private static final int COHERE_MAX_CHARACTERS = 2048;
47+
4448
private final CohereEmbeddingBedrockApi embeddingApi;
4549

4650
private final BedrockCohereEmbeddingOptions defaultOptions;
@@ -74,11 +78,34 @@ public float[] embed(Document document) {
7478

7579
@Override
7680
public EmbeddingResponse call(EmbeddingRequest request) {
77-
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
81+
82+
List<String> instructions = request.getInstructions();
83+
Assert.notEmpty(instructions, "At least one text is required!");
7884

7985
final BedrockCohereEmbeddingOptions optionsToUse = this.mergeOptions(request.getOptions());
8086

81-
var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), optionsToUse.getInputType(),
87+
List<String> truncatedInstructions = instructions.stream().map(text -> {
88+
if (text == null || text.isEmpty()) {
89+
return text;
90+
}
91+
92+
if (text.length() <= COHERE_MAX_CHARACTERS) {
93+
return text;
94+
}
95+
96+
// Handle truncation based on option
97+
return switch (optionsToUse.getTruncate()) {
98+
case END -> text.substring(0, COHERE_MAX_CHARACTERS); // Keep first 2048 chars
99+
case START -> text.substring(text.length() - COHERE_MAX_CHARACTERS); // Keep
100+
// last
101+
// 2048
102+
// chars
103+
default -> text.substring(0, COHERE_MAX_CHARACTERS); // Default to END
104+
// behavior
105+
};
106+
}).collect(Collectors.toList());
107+
108+
var apiRequest = new CohereEmbeddingRequest(truncatedInstructions, optionsToUse.getInputType(),
82109
optionsToUse.getTruncate());
83110
CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest);
84111
var indexCounter = new AtomicInteger(0);

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.fasterxml.jackson.databind.ObjectMapper;
2323
import org.junit.jupiter.api.Test;
2424
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
25+
import org.mockito.ArgumentCaptor;
2526
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
2627
import software.amazon.awssdk.regions.Region;
2728

@@ -31,11 +32,14 @@
3132
import org.springframework.ai.embedding.EmbeddingRequest;
3233
import org.springframework.ai.embedding.EmbeddingResponse;
3334
import org.springframework.beans.factory.annotation.Autowired;
35+
import org.springframework.beans.factory.annotation.Qualifier;
3436
import org.springframework.boot.SpringBootConfiguration;
3537
import org.springframework.boot.test.context.SpringBootTest;
38+
import org.springframework.boot.test.mock.mockito.SpyBean;
3639
import org.springframework.context.annotation.Bean;
3740

3841
import static org.assertj.core.api.Assertions.assertThat;
42+
import static org.mockito.Mockito.verify;
3943

4044
@SpringBootTest
4145
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@@ -45,6 +49,13 @@ class BedrockCohereEmbeddingModelIT {
4549
@Autowired
4650
private BedrockCohereEmbeddingModel embeddingModel;
4751

52+
@SpyBean
53+
private CohereEmbeddingBedrockApi embeddingApi;
54+
55+
@Autowired
56+
@Qualifier("embeddingModelStartTruncate")
57+
private BedrockCohereEmbeddingModel embeddingModelStartTruncate;
58+
4859
@Test
4960
void singleEmbedding() {
5061
assertThat(this.embeddingModel).isNotNull();
@@ -54,6 +65,77 @@ void singleEmbedding() {
5465
assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
5566
}
5667

68+
@Test
69+
void truncatesLongText() {
70+
String longText = "Hello World".repeat(300);
71+
assertThat(longText.length()).isGreaterThan(2048);
72+
73+
EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(longText));
74+
75+
assertThat(embeddingResponse.getResults()).hasSize(1);
76+
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
77+
assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
78+
}
79+
80+
@Test
81+
void truncatesMultipleLongTexts() {
82+
String longText1 = "Hello World".repeat(300);
83+
String longText2 = "Another Text".repeat(300);
84+
85+
EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(longText1, longText2));
86+
87+
assertThat(embeddingResponse.getResults()).hasSize(2);
88+
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
89+
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
90+
assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
91+
}
92+
93+
@Test
94+
void verifyExactTruncationLength() {
95+
String longText = "x".repeat(3000);
96+
97+
ArgumentCaptor<CohereEmbeddingBedrockApi.CohereEmbeddingRequest> requestCaptor = ArgumentCaptor
98+
.forClass(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.class);
99+
100+
EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of(longText));
101+
102+
verify(embeddingApi).embedding(requestCaptor.capture());
103+
CohereEmbeddingBedrockApi.CohereEmbeddingRequest capturedRequest = requestCaptor.getValue();
104+
105+
assertThat(capturedRequest.texts()).hasSize(1);
106+
assertThat(capturedRequest.texts().get(0).length()).isLessThanOrEqualTo(2048);
107+
108+
assertThat(embeddingResponse.getResults()).hasSize(1);
109+
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
110+
}
111+
112+
@Test
113+
void truncatesLongTextFromStart() {
114+
String startMarker = "START_MARKER_";
115+
String endMarker = "_END_MARKER";
116+
String middlePadding = "x".repeat(2500); // Long enough to force truncation
117+
String longText = startMarker + middlePadding + endMarker;
118+
119+
assertThat(longText.length()).isGreaterThan(2048);
120+
121+
ArgumentCaptor<CohereEmbeddingBedrockApi.CohereEmbeddingRequest> requestCaptor = ArgumentCaptor
122+
.forClass(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.class);
123+
124+
EmbeddingResponse embeddingResponse = this.embeddingModelStartTruncate.embedForResponse(List.of(longText));
125+
126+
// Verify truncation behavior
127+
verify(embeddingApi).embedding(requestCaptor.capture());
128+
String truncatedText = requestCaptor.getValue().texts().get(0);
129+
assertThat(truncatedText.length()).isLessThanOrEqualTo(2048);
130+
assertThat(truncatedText).doesNotContain(startMarker);
131+
assertThat(truncatedText).endsWith(endMarker);
132+
133+
// Verify embedding response
134+
assertThat(embeddingResponse.getResults()).hasSize(1);
135+
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
136+
assertThat(this.embeddingModelStartTruncate.dimensions()).isEqualTo(1024);
137+
}
138+
57139
@Test
58140
void batchEmbedding() {
59141
assertThat(this.embeddingModel).isNotNull();
@@ -93,9 +175,27 @@ public CohereEmbeddingBedrockApi cohereEmbeddingApi() {
93175
Duration.ofMinutes(2));
94176
}
95177

96-
@Bean
178+
@Bean("embeddingModel")
97179
public BedrockCohereEmbeddingModel cohereAiEmbedding(CohereEmbeddingBedrockApi cohereEmbeddingApi) {
98-
return new BedrockCohereEmbeddingModel(cohereEmbeddingApi);
180+
// custom model that uses the END truncation strategy, instead of the default
181+
// NONE.
182+
return new BedrockCohereEmbeddingModel(cohereEmbeddingApi,
183+
BedrockCohereEmbeddingOptions.builder()
184+
.withInputType(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT)
185+
.withTruncate(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.Truncate.END)
186+
.build());
187+
}
188+
189+
@Bean("embeddingModelStartTruncate")
190+
public BedrockCohereEmbeddingModel cohereAiEmbeddingStartTruncate(
191+
CohereEmbeddingBedrockApi cohereEmbeddingApi) {
192+
// custom model that uses the START truncation strategy, instead of the
193+
// default NONE.
194+
return new BedrockCohereEmbeddingModel(cohereEmbeddingApi,
195+
BedrockCohereEmbeddingOptions.builder()
196+
.withInputType(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT)
197+
.withTruncate(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.Truncate.START)
198+
.build());
99199
}
100200

101201
}

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-cohere-embedding.adoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ The prefix `spring.ai.bedrock.cohere.embedding` (defined in `BedrockCohereEmbedd
7373
| spring.ai.bedrock.cohere.embedding.options.truncate | Specifies how the API handles inputs longer than the maximum token length. If you specify LEFT or RIGHT, the model discards the input until the remaining input is exactly the maximum input token length for the model. | NONE
7474
|====
7575

76+
NOTE: When accessing Cohere via Amazon Bedrock, the functionality of truncating is not available. This is an issue with Amazon Bedrock. The Spring AI class `BedrockCohereEmbeddingModel` will truncate to 2048 character length, which is the maximum supported by the model.
77+
7678
Look at the https://github.com/spring-projects/spring-ai/blob/056b95a00efa5b014a1f488329fbd07a46c02378/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java#L150[CohereEmbeddingModel] for other model IDs.
7779
Supported values are: `cohere.embed-multilingual-v3` and `cohere.embed-english-v3`.
7880
Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs].

0 commit comments

Comments
 (0)