Skip to content

Commit 3497a1e

Browse files
wmz7yeartzolov
authored andcommitted
Fix Bedrock Cohere embedding truncate type types
- fix compilation errors and javadoc
1 parent f955fd7 commit 3497a1e

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232

3333
/**
3434
* Cohere Embedding API.
35-
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere.html#model-parameters-embed
35+
* <a href="https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere.html#model-parameters-embed">AWS Bedrock Cohere Embedding API</a>
36+
* Based on the <a href="https://docs.cohere.com/reference/embed">Cohere Embedding API</a>
3637
*
3738
* @author Christian Tzolov
3839
* @author Wei Jiang
@@ -151,22 +152,21 @@ public enum InputType {
151152
}
152153

153154
/**
154-
* Specifies how the API handles inputs longer than the maximum token length. If you specify LEFT or RIGHT, the
155-
* model discards the input until the remaining input is exactly the maximum input token length for the model.
155+
* Specifies how the API handles inputs longer than the maximum token length. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
156156
*/
157157
public enum Truncate {
158158
/**
159-
* (Default) Returns an error when the input exceeds the maximum input token length.
159+
* Returns an error when the input exceeds the maximum input token length.
160160
*/
161161
NONE,
162162
/**
163-
* Discard the start of the input.
163+
* Discards the start of the input.
164164
*/
165-
LEFT,
165+
START,
166166
/**
167-
* Discards the end of the input.
167+
* (default) Discards the end of the input.
168168
*/
169-
RIGHT
169+
END
170170
}
171171
}
172172

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
/**
3434
* @author Christian Tzolov
35+
* @author Wei Jiang
3536
*/
3637
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
3738
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
@@ -56,4 +57,29 @@ public void embedText() {
5657
assertThat(response.embeddings().get(0)).hasSize(1024);
5758
}
5859

60+
@Test
61+
public void embedTextWithTruncate() {
62+
63+
CohereEmbeddingRequest request = new CohereEmbeddingRequest(
64+
List.of("I like to eat apples", "I like to eat oranges"),
65+
CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT, CohereEmbeddingRequest.Truncate.START);
66+
67+
CohereEmbeddingResponse response = api.embedding(request);
68+
69+
assertThat(response).isNotNull();
70+
assertThat(response.texts()).isEqualTo(request.texts());
71+
assertThat(response.embeddings()).hasSize(2);
72+
assertThat(response.embeddings().get(0)).hasSize(1024);
73+
74+
request = new CohereEmbeddingRequest(List.of("I like to eat apples", "I like to eat oranges"),
75+
CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT, CohereEmbeddingRequest.Truncate.END);
76+
77+
response = api.embedding(request);
78+
79+
assertThat(response).isNotNull();
80+
assertThat(response.texts()).isEqualTo(request.texts());
81+
assertThat(response.embeddings()).hasSize(2);
82+
assertThat(response.embeddings().get(0)).hasSize(1024);
83+
}
84+
5985
}

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public void propertiesTest() {
9090
"spring.ai.bedrock.aws.region=" + Region.EU_CENTRAL_1.id(),
9191
"spring.ai.bedrock.cohere.embedding.model=MODEL_XYZ",
9292
"spring.ai.bedrock.cohere.embedding.options.inputType=CLASSIFICATION",
93-
"spring.ai.bedrock.cohere.embedding.options.truncate=RIGHT")
93+
"spring.ai.bedrock.cohere.embedding.options.truncate=START")
9494
.withConfiguration(AutoConfigurations.of(BedrockCohereEmbeddingAutoConfiguration.class))
9595
.run(context -> {
9696
var properties = context.getBean(BedrockCohereEmbeddingProperties.class);
@@ -101,7 +101,7 @@ public void propertiesTest() {
101101
assertThat(properties.getModel()).isEqualTo("MODEL_XYZ");
102102

103103
assertThat(properties.getOptions().getInputType()).isEqualTo(InputType.CLASSIFICATION);
104-
assertThat(properties.getOptions().getTruncate()).isEqualTo(CohereEmbeddingRequest.Truncate.RIGHT);
104+
assertThat(properties.getOptions().getTruncate()).isEqualTo(CohereEmbeddingRequest.Truncate.START);
105105

106106
assertThat(awsProperties.getAccessKey()).isEqualTo("ACCESS_KEY");
107107
assertThat(awsProperties.getSecretKey()).isEqualTo("SECRET_KEY");

0 commit comments

Comments
 (0)