Skip to content

Commit ed52a3e

Browse files
committed
Add Bedrock Titan Chat Options + Docs
1 parent 4839a61 commit ed52a3e

File tree

12 files changed

+521
-189
lines changed

12 files changed

+521
-189
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ The Spring AI project provides a Spring-friendly API and abstractions for develo
66

77
Let's make your `@Beans` intelligent!
88

9+
For further information go to our [Spring AI documentation](https://docs.spring.io/spring-ai/reference/).
10+
911
## Project Update
1012

1113
:partying_face: The Spring AI project has graduated out of the repository!
Lines changed: 2 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,3 @@
1-
# 1. Bedrock Titan Chat
1+
# Bedrock Titan Chat
22

3-
## 1.1 TitanChatBedrockApi
4-
5-
[TitanChatBedrockApi](./src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java) provides is lightweight Java client on top of AWS Bedrock [Titan text models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html).
6-
7-
Following class diagram illustrates the Llama2ChatBedrockApi interface and building blocks:
8-
9-
![TitanChatBedrockApi Class Diagram](./src/test/resources/doc/Bedrock%20Titan%20Chat%20API.jpg)
10-
11-
The TitanChatBedrockApi supports the `amazon.titan-text-lite-v1` and `amazon.titan-text-express-v1` models for bot synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) responses.
12-
13-
Here is a simple snippet how to use the api programmatically:
14-
15-
```java
16-
TitanChatBedrockApi titanBedrockApi = new TitanChatBedrockApi(TitanChatCompletionModel.TITAN_TEXT_EXPRESS_V1.id(),
17-
Region.EU_CENTRAL_1.id());
18-
19-
TitanChatRequest titanChatRequest = TitanChatRequest.builder("Give me the names of 3 famous pirates?")
20-
.withTemperature(0.5f)
21-
.withTopP(0.9f)
22-
.withMaxTokenCount(100)
23-
.withStopSequences(List.of("|"))
24-
.build();
25-
26-
TitanChatResponse response = titanBedrockApi.chatCompletion(titanChatRequest);
27-
28-
assertThat(response.results()).hasSize(1);
29-
assertThat(response.results().get(0).outputText()).contains("Blackbeard");
30-
31-
Flux<TitanChatResponseChunk> response = titanBedrockApi.chatCompletionStream(titanChatRequest);
32-
33-
List<TitanChatResponseChunk> results = response.collectList().block();
34-
assertThat(results.stream().map(TitanChatResponseChunk::outputText).collect(Collectors.joining("\n")))
35-
.contains("Blackbeard");
36-
```
37-
38-
## 1.2 BedrockTitanChatClient
39-
40-
[BedrockTitanChatClient](./src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java) implements the Spring-Ai `ChatClient` and `StreamingChatClient` on top of the `TitanChatBedrockApi`.
41-
42-
You can use like this:
43-
44-
```java
45-
@Bean
46-
public TitanChatBedrockApi titanApi() {
47-
return new TitanChatBedrockApi(TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(),
48-
EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper());
49-
}
50-
51-
@Bean
52-
public BedrockTitanChatClient titanChatClient(TitanChatBedrockApi titanApi) {
53-
return new BedrockTitanChatClient(titanApi);
54-
}
55-
56-
```
57-
58-
or you can leverage the `spring-ai-bedrock-ai-spring-boot-starter` Boot starter. For this add the following dependency:
59-
60-
```xml
61-
<dependency>
62-
<artifactId>spring-ai-bedrock-ai-spring-boot-starter</artifactId>
63-
<groupId>org.springframework.ai</groupId>
64-
<version>0.8.0-SNAPSHOT</version>
65-
</dependency>
66-
```
67-
68-
**NOTE:** You have to enable the Bedrock Titan chat client with `spring.ai.bedrock.titan.chat.enabled=true`.
69-
By default the client is disabled.
70-
71-
Use the `BedrockTitanChatProperties` to configure the Bedrock Titan Chat client:
72-
73-
| Property | Description | Default |
74-
| ------------- | ------------- | ------------- |
75-
| spring.ai.bedrock.aws.region | AWS region to use. | us-east-1 |
76-
| spring.ai.bedrock.aws.accessKey | AWS credentials access key. | |
77-
| spring.ai.bedrock.aws.secretKey | AWS credentials secret key. | |
78-
| spring.ai.bedrock.titan.chat.enable | Enable Bedrock Titan chat client. Disabled by default | false |
79-
| spring.ai.bedrock.titan.chat.model | The model id to use. See the `TitanChatModel` for the supported models. | amazon.titan-text-express-v1 |
80-
| spring.ai.bedrock.titan.chat.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | 0.7 |
81-
| spring.ai.bedrock.titan.chat.topP | The maximum cumulative probability of tokens to consider when sampling. | AWS Bedrock default |
82-
| spring.ai.bedrock.titan.chat.maxTokenCount | Specify the maximum number of tokens to use in the generated response. | AWS Bedrock default |
83-
| spring.ai.bedrock.titan.chat.stopSequences | Configure up to four sequences that the model recognizes. | AWS Bedrock default |
3+
Visit the Spring AI [Bedrock Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/api/clients/bedrock/bedrock-titan.html).

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2023 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,20 +18,23 @@
1818

1919
import java.util.List;
2020

21-
import org.springframework.ai.chat.ChatClient;
22-
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
2321
import reactor.core.publisher.Flux;
2422

2523
import org.springframework.ai.bedrock.MessageToPromptConverter;
2624
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
2725
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest;
2826
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse;
2927
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk;
28+
import org.springframework.ai.chat.ChatClient;
29+
import org.springframework.ai.chat.ChatOptions;
3030
import org.springframework.ai.chat.ChatResponse;
31-
import org.springframework.ai.chat.StreamingChatClient;
3231
import org.springframework.ai.chat.Generation;
32+
import org.springframework.ai.chat.StreamingChatClient;
33+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3334
import org.springframework.ai.chat.metadata.Usage;
3435
import org.springframework.ai.chat.prompt.Prompt;
36+
import org.springframework.ai.model.ModelOptionsUtils;
37+
import org.springframework.util.Assert;
3538

3639
/**
3740
* @author Christian Tzolov
@@ -41,41 +44,22 @@ public class BedrockTitanChatClient implements ChatClient, StreamingChatClient {
4144

4245
private final TitanChatBedrockApi chatApi;
4346

44-
private Float temperature;
45-
46-
private Float topP;
47-
48-
private Integer maxTokenCount;
49-
50-
private List<String> stopSequences;
47+
private final BedrockTitanChatOptions defaultOptions;
5148

5249
public BedrockTitanChatClient(TitanChatBedrockApi chatApi) {
53-
this.chatApi = chatApi;
50+
this(chatApi, BedrockTitanChatOptions.builder().withTemperature(0.8f).build());
5451
}
5552

56-
public BedrockTitanChatClient withTemperature(Float temperature) {
57-
this.temperature = temperature;
58-
return this;
59-
}
60-
61-
public BedrockTitanChatClient withTopP(Float topP) {
62-
this.topP = topP;
63-
return this;
64-
}
65-
66-
public BedrockTitanChatClient withMaxTokenCount(Integer maxTokens) {
67-
this.maxTokenCount = maxTokens;
68-
return this;
69-
}
70-
71-
public BedrockTitanChatClient withStopSequences(List<String> stopSequences) {
72-
this.stopSequences = stopSequences;
73-
return this;
53+
public BedrockTitanChatClient(TitanChatBedrockApi chatApi, BedrockTitanChatOptions defaultOptions) {
54+
Assert.notNull(chatApi, "ChatApi must not be null");
55+
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
56+
this.chatApi = chatApi;
57+
this.defaultOptions = defaultOptions;
7458
}
7559

7660
@Override
7761
public ChatResponse call(Prompt prompt) {
78-
TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false));
62+
TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt));
7963
List<Generation> generations = response.results().stream().map(result -> {
8064
return new Generation(result.outputText());
8165
}).toList();
@@ -85,7 +69,7 @@ public ChatResponse call(Prompt prompt) {
8569

8670
@Override
8771
public Flux<ChatResponse> stream(Prompt prompt) {
88-
return this.chatApi.chatCompletionStream(this.createRequest(prompt, true)).map(chunk -> {
72+
return this.chatApi.chatCompletionStream(this.createRequest(prompt)).map(chunk -> {
8973

9074
Generation generation = new Generation(chunk.outputText());
9175

@@ -104,15 +88,48 @@ else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount(
10488
});
10589
}
10690

107-
private TitanChatRequest createRequest(Prompt prompt, boolean stream) {
91+
/**
92+
* Test access.
93+
*/
94+
TitanChatRequest createRequest(Prompt prompt) {
10895
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
10996

110-
return TitanChatRequest.builder(promptValue)
111-
.withTemperature(this.temperature)
112-
.withTopP(this.topP)
113-
.withMaxTokenCount(this.maxTokenCount)
114-
.withStopSequences(this.stopSequences)
115-
.build();
97+
var requestBuilder = TitanChatRequest.builder(promptValue);
98+
99+
if (this.defaultOptions != null) {
100+
requestBuilder = update(requestBuilder, this.defaultOptions);
101+
}
102+
103+
if (prompt.getOptions() != null) {
104+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
105+
BedrockTitanChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
106+
ChatOptions.class, BedrockTitanChatOptions.class);
107+
108+
requestBuilder = update(requestBuilder, updatedRuntimeOptions);
109+
}
110+
else {
111+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
112+
+ prompt.getOptions().getClass().getSimpleName());
113+
}
114+
}
115+
116+
return requestBuilder.build();
117+
}
118+
119+
private TitanChatRequest.Builder update(TitanChatRequest.Builder builder, BedrockTitanChatOptions options) {
120+
if (options.getTemperature() != null) {
121+
builder.withTemperature(options.getTemperature());
122+
}
123+
if (options.getTopP() != null) {
124+
builder.withTopP(options.getTopP());
125+
}
126+
if (options.getMaxTokenCount() != null) {
127+
builder.withMaxTokenCount(options.getMaxTokenCount());
128+
}
129+
if (options.getStopSequences() != null) {
130+
builder.withStopSequences(options.getStopSequences());
131+
}
132+
return builder;
116133
}
117134

118135
private Usage extractUsage(TitanChatResponseChunk response) {
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.bedrock.titan;
18+
19+
import java.util.List;
20+
21+
import com.fasterxml.jackson.annotation.JsonInclude;
22+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
23+
24+
import org.springframework.ai.chat.ChatOptions;
25+
26+
import com.fasterxml.jackson.annotation.JsonProperty;
27+
28+
/**
29+
* @author Christian Tzolov
30+
* @since 0.8.0
31+
*/
32+
@JsonInclude(Include.NON_NULL)
33+
public class BedrockTitanChatOptions implements ChatOptions {
34+
35+
// @formatter:off
36+
/**
37+
* The temperature value controls the randomness of the generated text.
38+
*/
39+
private @JsonProperty("temperature") Float temperature;
40+
41+
/**
42+
* The topP value controls the diversity of the generated text. Use a lower value to ignore less probable options.
43+
*/
44+
private @JsonProperty("topP") Float topP;
45+
46+
/**
47+
* Maximum number of tokens to generate.
48+
*/
49+
private @JsonProperty("maxTokenCount") Integer maxTokenCount;
50+
51+
/**
52+
* A list of tokens that the model should stop generating after.
53+
*/
54+
private @JsonProperty("stopSequences") List<String> stopSequences;
55+
// @formatter:on
56+
57+
public static Builder builder() {
58+
return new Builder();
59+
}
60+
61+
public static class Builder {
62+
63+
private BedrockTitanChatOptions options = new BedrockTitanChatOptions();
64+
65+
public Builder withTemperature(Float temperature) {
66+
this.options.temperature = temperature;
67+
return this;
68+
}
69+
70+
public Builder withTopP(Float topP) {
71+
this.options.topP = topP;
72+
return this;
73+
}
74+
75+
public Builder withMaxTokenCount(Integer maxTokenCount) {
76+
this.options.maxTokenCount = maxTokenCount;
77+
return this;
78+
}
79+
80+
public Builder withStopSequences(List<String> stopSequences) {
81+
this.options.stopSequences = stopSequences;
82+
return this;
83+
}
84+
85+
public BedrockTitanChatOptions build() {
86+
return this.options;
87+
}
88+
89+
}
90+
91+
public Float getTemperature() {
92+
return temperature;
93+
}
94+
95+
public void setTemperature(Float temperature) {
96+
this.temperature = temperature;
97+
}
98+
99+
public Float getTopP() {
100+
return topP;
101+
}
102+
103+
public void setTopP(Float topP) {
104+
this.topP = topP;
105+
}
106+
107+
public Integer getMaxTokenCount() {
108+
return maxTokenCount;
109+
}
110+
111+
public void setMaxTokenCount(Integer maxTokenCount) {
112+
this.maxTokenCount = maxTokenCount;
113+
}
114+
115+
public List<String> getStopSequences() {
116+
return stopSequences;
117+
}
118+
119+
public void setStopSequences(List<String> stopSequences) {
120+
this.stopSequences = stopSequences;
121+
}
122+
123+
@Override
124+
public Integer getTopK() {
125+
throw new UnsupportedOperationException("Bedrock Titian Chat does not support the 'TopK' option.");
126+
}
127+
128+
@Override
129+
public void setTopK(Integer topK) {
130+
throw new UnsupportedOperationException("Bedrock Titian Chat does not support the 'TopK' option.'");
131+
}
132+
133+
}

0 commit comments

Comments
 (0)