Skip to content

Commit a50969e

Browse files
author
wmz7year
committed
Bedrock Titan embedding client adds BedrockTitanEmbeddingOptions to support dynamic embedding request types.
1 parent 5f9ecdd commit a50969e

File tree

5 files changed

+113
-7
lines changed

5 files changed

+113
-7
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions;
2929
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
3030
import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions;
31+
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingOptions;
3132
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
3233
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
3334
import org.springframework.aot.hint.MemberCategory;
@@ -43,6 +44,7 @@
4344
* @author Josh Long
4445
* @author Christian Tzolov
4546
* @author Mark Pollack
47+
* @author Wei Jiang
4648
*/
4749
public class BedrockRuntimeHints implements RuntimeHintsRegistrar {
4850

@@ -72,6 +74,8 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
7274
hints.reflection().registerType(tr, mcs);
7375
for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanChatOptions.class))
7476
hints.reflection().registerType(tr, mcs);
77+
for (var tr : findJsonAnnotatedClassesInPackage(BedrockTitanEmbeddingOptions.class))
78+
hints.reflection().registerType(tr, mcs);
7579
for (var tr : findJsonAnnotatedClassesInPackage(TitanEmbeddingBedrockApi.class))
7680
hints.reflection().registerType(tr, mcs);
7781

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.ai.document.Document;
2929
import org.springframework.ai.embedding.AbstractEmbeddingClient;
3030
import org.springframework.ai.embedding.Embedding;
31+
import org.springframework.ai.embedding.EmbeddingOptions;
3132
import org.springframework.ai.embedding.EmbeddingRequest;
3233
import org.springframework.ai.embedding.EmbeddingResponse;
3334
import org.springframework.util.Assert;
@@ -40,6 +41,7 @@
4041
* Note: Titan Embedding does not support batch embedding.
4142
*
4243
* @author Christian Tzolov
44+
* @author Wei Jiang
4345
* @since 0.8.0
4446
*/
4547
public class BedrockTitanEmbeddingClient extends AbstractEmbeddingClient {
@@ -87,9 +89,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
8789

8890
List<List<Double>> embeddingList = new ArrayList<>();
8991
for (String inputContent : request.getInstructions()) {
90-
var apiRequest = (this.inputType == InputType.IMAGE)
91-
? new TitanEmbeddingRequest.Builder().withInputImage(inputContent).build()
92-
: new TitanEmbeddingRequest.Builder().withInputText(inputContent).build();
92+
var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions());
9393
TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest);
9494
embeddingList.add(response.embedding());
9595
}
@@ -100,6 +100,18 @@ public EmbeddingResponse call(EmbeddingRequest request) {
100100
return new EmbeddingResponse(embeddings);
101101
}
102102

103+
private TitanEmbeddingRequest createTitanEmbeddingRequest(String inputContent, EmbeddingOptions requestOptions) {
104+
InputType inputType = this.inputType;
105+
106+
if (requestOptions != null
107+
&& requestOptions instanceof BedrockTitanEmbeddingOptions bedrockTitanEmbeddingOptions) {
108+
inputType = bedrockTitanEmbeddingOptions.getInputType();
109+
}
110+
111+
return (inputType == InputType.IMAGE) ? new TitanEmbeddingRequest.Builder().withInputImage(inputContent).build()
112+
: new TitanEmbeddingRequest.Builder().withInputText(inputContent).build();
113+
}
114+
103115
@Override
104116
public int dimensions() {
105117
if (this.inputType == InputType.IMAGE) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock.titan;
17+
18+
import com.fasterxml.jackson.annotation.JsonInclude;
19+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
20+
21+
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType;
22+
import org.springframework.ai.embedding.EmbeddingOptions;
23+
import org.springframework.util.Assert;
24+
25+
/**
26+
* @author Wei Jiang
27+
*/
28+
@JsonInclude(Include.NON_NULL)
29+
public class BedrockTitanEmbeddingOptions implements EmbeddingOptions {
30+
31+
/**
32+
* Titan Embedding API input types. Could be either text or image (encoded in base64).
33+
*/
34+
private InputType inputType;
35+
36+
public static Builder builder() {
37+
return new Builder();
38+
}
39+
40+
public static class Builder {
41+
42+
private BedrockTitanEmbeddingOptions options = new BedrockTitanEmbeddingOptions();
43+
44+
public Builder withInputType(InputType inputType) {
45+
Assert.notNull(inputType, "input type can not be null.");
46+
47+
this.options.setInputType(inputType);
48+
return this;
49+
}
50+
51+
public BedrockTitanEmbeddingOptions build() {
52+
return this.options;
53+
}
54+
55+
}
56+
57+
public InputType getInputType() {
58+
return this.inputType;
59+
}
60+
61+
public void setInputType(InputType inputType) {
62+
this.inputType = inputType;
63+
}
64+
65+
}

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,23 @@
2222

2323
import org.junit.jupiter.api.Test;
2424
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
25+
26+
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
2527
import software.amazon.awssdk.regions.Region;
2628

29+
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType;
2730
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
2831
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel;
32+
import org.springframework.ai.embedding.EmbeddingRequest;
2933
import org.springframework.ai.embedding.EmbeddingResponse;
3034
import org.springframework.beans.factory.annotation.Autowired;
3135
import org.springframework.boot.SpringBootConfiguration;
3236
import org.springframework.boot.test.context.SpringBootTest;
3337
import org.springframework.context.annotation.Bean;
3438
import org.springframework.core.io.DefaultResourceLoader;
3539

40+
import com.fasterxml.jackson.databind.ObjectMapper;
41+
3642
import static org.assertj.core.api.Assertions.assertThat;
3743

3844
@SpringBootTest
@@ -46,7 +52,8 @@ class BedrockTitanEmbeddingClientIT {
4652
@Test
4753
void singleEmbedding() {
4854
assertThat(embeddingClient).isNotNull();
49-
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
55+
EmbeddingResponse embeddingResponse = embeddingClient.call(new EmbeddingRequest(List.of("Hello World"),
56+
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.TEXT).build()));
5057
assertThat(embeddingResponse.getResults()).hasSize(1);
5158
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
5259
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
@@ -59,7 +66,8 @@ void imageEmbedding() throws IOException {
5966
.getContentAsByteArray();
6067

6168
EmbeddingResponse embeddingResponse = embeddingClient
62-
.embedForResponse(List.of(Base64.getEncoder().encodeToString(image)));
69+
.call(new EmbeddingRequest(List.of(Base64.getEncoder().encodeToString(image)),
70+
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.IMAGE).build()));
6371
assertThat(embeddingResponse.getResults()).hasSize(1);
6472
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
6573
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
@@ -70,7 +78,8 @@ public static class TestConfiguration {
7078

7179
@Bean
7280
public TitanEmbeddingBedrockApi titanEmbeddingApi() {
73-
return new TitanEmbeddingBedrockApi(TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), Region.US_EAST_1.id(),
81+
return new TitanEmbeddingBedrockApi(TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(),
82+
EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(),
7483
Duration.ofMinutes(2));
7584
}
7685

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/bedrock-titan-embedding.adoc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,22 @@ The prefix `spring.ai.bedrock.titan.embedding` (defined in `BedrockTitanEmbeddin
8181
Supported values are: `amazon.titan-embed-image-v1` and `amazon.titan-embed-text-v1`.
8282
Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs].
8383

84+
== Runtime Options [[embedding-options]]
85+
86+
The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingOptions.java[BedrockTitanEmbeddingOptions.java] provides model configurations, such as `input-type`.
87+
On start-up, the default options can be configured with the `BedrockTitanEmbeddingClient(api).withInputType(type)` method or the `spring.ai.bedrock.titan.embedding.input-type` property.
88+
89+
At run-time you can override the default options by adding new, request-specific options to the `EmbeddingRequest` call.
90+
For example, to override the default input type for a specific request:
91+
92+
[source,java]
93+
----
94+
EmbeddingResponse embeddingResponse = embeddingClient.call(
95+
new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"),
96+
BedrockTitanEmbeddingOptions.builder()
97+
.withInputType(InputType.TEXT)
98+
.build()));
99+
----
84100

85101
== Sample Controller
86102

@@ -154,7 +170,7 @@ Next, create an https://github.com/spring-projects/spring-ai/blob/main/models/sp
154170
var titanEmbeddingApi = new TitanEmbeddingBedrockApi(
155171
TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), Region.US_EAST_1.id());
156172
157-
var embeddingClient new BedrockTitanEmbeddingClient(titanEmbeddingApi);
173+
var embeddingClient = new BedrockTitanEmbeddingClient(titanEmbeddingApi);
158174
159175
EmbeddingResponse embeddingResponse = embeddingClient
160176
.embedForResponse(List.of("Hello World")); // NOTE titan does not support batch embedding.

0 commit comments

Comments
 (0)