Skip to content

Commit 577a605

Browse files
committed
Add Bedrock Llama2 Options Support
1 parent ee81558 commit 577a605

File tree

13 files changed

+448
-189
lines changed

13 files changed

+448
-189
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,20 @@
2222
import reactor.core.publisher.Flux;
2323

2424
import org.springframework.ai.bedrock.MessageToPromptConverter;
25+
import org.springframework.ai.bedrock.anthropic.AnthropicChatOptions;
26+
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest;
2527
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
2628
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest;
2729
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse;
2830
import org.springframework.ai.chat.ChatClient;
31+
import org.springframework.ai.chat.ChatOptions;
2932
import org.springframework.ai.chat.StreamingChatClient;
3033
import org.springframework.ai.chat.Generation;
3134
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3235
import org.springframework.ai.chat.metadata.Usage;
3336
import org.springframework.ai.chat.prompt.Prompt;
37+
import org.springframework.ai.model.ModelOptionsUtils;
38+
import org.springframework.util.Assert;
3439

3540
/**
3641
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama2 chat
@@ -43,40 +48,25 @@ public class BedrockLlama2ChatClient implements ChatClient, StreamingChatClient
4348

4449
private final Llama2ChatBedrockApi chatApi;
4550

46-
private Float temperature;
47-
48-
private Float topP;
49-
50-
private Integer maxGenLen;
51+
private final BedrockLlama2ChatOptions defaultOptions;
5152

5253
public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi) {
53-
this.chatApi = chatApi;
54+
this(chatApi,
55+
BedrockLlama2ChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
5456
}
5557

56-
public BedrockLlama2ChatClient withTemperature(Float temperature) {
57-
this.temperature = temperature;
58-
return this;
59-
}
60-
61-
public BedrockLlama2ChatClient withTopP(Float topP) {
62-
this.topP = topP;
63-
return this;
64-
}
58+
public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi, BedrockLlama2ChatOptions options) {
59+
Assert.notNull(chatApi, "Llama2ChatBedrockApi must not be null");
60+
Assert.notNull(options, "BedrockLlama2ChatOptions must not be null");
6561

66-
public BedrockLlama2ChatClient withMaxGenLen(Integer maxGenLen) {
67-
this.maxGenLen = maxGenLen;
68-
return this;
62+
this.chatApi = chatApi;
63+
this.defaultOptions = options;
6964
}
7065

7166
@Override
7267
public ChatResponse call(Prompt prompt) {
73-
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
7468

75-
var request = Llama2ChatRequest.builder(promptValue)
76-
.withTemperature(this.temperature)
77-
.withTopP(this.topP)
78-
.withMaxGenLen(this.maxGenLen)
79-
.build();
69+
var request = createRequest(prompt);
8070

8171
Llama2ChatResponse response = this.chatApi.chatCompletion(request);
8272

@@ -87,13 +77,7 @@ public ChatResponse call(Prompt prompt) {
8777
@Override
8878
public Flux<ChatResponse> stream(Prompt prompt) {
8979

90-
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
91-
92-
var request = Llama2ChatRequest.builder(promptValue)
93-
.withTemperature(this.temperature)
94-
.withTopP(this.topP)
95-
.withMaxGenLen(this.maxGenLen)
96-
.build();
80+
var request = createRequest(prompt);
9781

9882
Flux<Llama2ChatResponse> fluxResponse = this.chatApi.chatCompletionStream(request);
9983

@@ -119,4 +103,33 @@ public Long getGenerationTokens() {
119103
};
120104
}
121105

106+
/**
107+
* Accessible for testing.
108+
*/
109+
Llama2ChatRequest createRequest(Prompt prompt) {
110+
111+
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
112+
113+
Llama2ChatRequest request = Llama2ChatRequest.builder(promptValue).build();
114+
115+
if (this.defaultOptions != null) {
116+
request = ModelOptionsUtils.merge(request, this.defaultOptions, Llama2ChatRequest.class);
117+
}
118+
119+
if (prompt.getOptions() != null) {
120+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
121+
BedrockLlama2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
122+
ChatOptions.class, BedrockLlama2ChatOptions.class);
123+
124+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, Llama2ChatRequest.class);
125+
}
126+
else {
127+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
128+
+ prompt.getOptions().getClass().getSimpleName());
129+
}
130+
}
131+
132+
return request;
133+
}
134+
122135
}
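
models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatOptions.java

Lines changed: 114 additions & 0 deletions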
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.bedrock.llama2;
18+
19+
import com.fasterxml.jackson.annotation.JsonIgnore;
20+
import com.fasterxml.jackson.annotation.JsonInclude;
21+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
22+
import com.fasterxml.jackson.annotation.JsonProperty;
23+
24+
import org.springframework.ai.chat.ChatOptions;
25+
26+
/**
27+
* @author Christian Tzolov
28+
*/
29+
@JsonInclude(Include.NON_NULL)
30+
public class BedrockLlama2ChatOptions implements ChatOptions {
31+
32+
/**
33+
* The temperature value controls the randomness of the generated text. Use a lower
34+
* value to decrease randomness in the response.
35+
*/
36+
private @JsonProperty("temperature") Float temperature;
37+
38+
/**
39+
* The topP value controls the diversity of the generated text. Use a lower value to
40+
* ignore less probable options. Set to 0 or 1.0 to disable.
41+
*/
42+
private @JsonProperty("top_p") Float topP;
43+
44+
/**
45+
* The maximum length of the generated text.
46+
*/
47+
private @JsonProperty("max_gen_len") Integer maxGenLen;
48+
49+
public static Builder builder() {
50+
return new Builder();
51+
}
52+
53+
public static class Builder {
54+
55+
private BedrockLlama2ChatOptions options = new BedrockLlama2ChatOptions();
56+
57+
public Builder withTemperature(Float temperature) {
58+
this.options.setTemperature(temperature);
59+
return this;
60+
}
61+
62+
public Builder withTopP(Float topP) {
63+
this.options.setTopP(topP);
64+
return this;
65+
}
66+
67+
public Builder withMaxGenLen(Integer maxGenLen) {
68+
this.options.setMaxGenLen(maxGenLen);
69+
return this;
70+
}
71+
72+
public BedrockLlama2ChatOptions build() {
73+
return this.options;
74+
}
75+
76+
}
77+
78+
public Float getTemperature() {
79+
return this.temperature;
80+
}
81+
82+
public void setTemperature(Float temperature) {
83+
this.temperature = temperature;
84+
}
85+
86+
public Float getTopP() {
87+
return this.topP;
88+
}
89+
90+
public void setTopP(Float topP) {
91+
this.topP = topP;
92+
}
93+
94+
public Integer getMaxGenLen() {
95+
return this.maxGenLen;
96+
}
97+
98+
public void setMaxGenLen(Integer maxGenLen) {
99+
this.maxGenLen = maxGenLen;
100+
}
101+
102+
@Override
103+
@JsonIgnore
104+
public Integer getTopK() {
105+
throw new UnsupportedOperationException("Unsupported option: 'TopK'");
106+
}
107+
108+
@Override
109+
@JsonIgnore
110+
public void setTopK(Integer topK) {
111+
throw new UnsupportedOperationException("Unsupported option: 'TopK'");
112+
}
113+
114+
}

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ public Llama2ChatBedrockApi llama2Api() {
166166

167167
@Bean
168168
public BedrockLlama2ChatClient llama2ChatClient(Llama2ChatBedrockApi llama2Api) {
169-
return new BedrockLlama2ChatClient(llama2Api);
169+
return new BedrockLlama2ChatClient(llama2Api,
170+
BedrockLlama2ChatOptions.builder().withTemperature(0.5f).withMaxGenLen(100).withTopP(0.9f).build());
170171
}
171172

172173
}
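
models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2CreateRequestTests.java

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change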
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.bedrock.llama2;
18+
19+
import com.fasterxml.jackson.databind.ObjectMapper;
20+
import org.junit.jupiter.api.Test;
21+
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
22+
import software.amazon.awssdk.regions.Region;
23+
24+
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
25+
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel;
26+
import org.springframework.ai.chat.prompt.Prompt;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
30+
/**
31+
* @author Christian Tzolov
32+
*/
33+
public class BedrockLlama2CreateRequestTests {
34+
35+
private Llama2ChatBedrockApi api = new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_70B_CHAT_V1.id(),
36+
EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper());
37+
38+
@Test
39+
public void createRequestWithChatOptions() {
40+
41+
var client = new BedrockLlama2ChatClient(api,
42+
BedrockLlama2ChatOptions.builder().withTemperature(66.6f).withMaxGenLen(666).withTopP(0.66f).build());
43+
44+
var request = client.createRequest(new Prompt("Test message content"));
45+
46+
assertThat(request.prompt()).isNotEmpty();
47+
assertThat(request.temperature()).isEqualTo(66.6f);
48+
assertThat(request.topP()).isEqualTo(0.66f);
49+
assertThat(request.maxGenLen()).isEqualTo(666);
50+
51+
request = client.createRequest(new Prompt("Test message content",
52+
BedrockLlama2ChatOptions.builder().withTemperature(99.9f).withMaxGenLen(999).withTopP(0.99f).build()));
53+
54+
assertThat(request.prompt()).isNotEmpty();
55+
assertThat(request.temperature()).isEqualTo(99.9f);
56+
assertThat(request.topP()).isEqualTo(0.99f);
57+
assertThat(request.maxGenLen()).isEqualTo(999);
58+
}
59+
60+
}
Loading

spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*** xref:api/clients/azure-openai-chat.adoc[]
88
*** xref:api/clients/bedrock.adoc[]
99
**** xref:api/clients/bedrock/bedrock-anthropic.adoc[]
10+
**** xref:api/clients/bedrock/bedrock-llama2.adoc[]
1011
*** xref:api/clients/huggingface.adoc[]
1112
*** xref:api/clients/ollama-chat.adoc[]
1213
** xref:api/prompt.adoc[]

0 commit comments

Comments
 (0)