Skip to content

Commit 9e865f0

Browse files
wmz7yeartzolov
authored andcommitted
Add Bedrock Meta LLama3 AI model support.
- re-enable llama structured output tests
1 parent b0add71 commit 9e865f0

File tree

21 files changed

+244
-230
lines changed

21 files changed

+244
-230
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i
8484
Spring AI supports many AI models. For an overview see here. Specific models currently supported are
8585
* OpenAI
8686
* Azure OpenAI
87-
* Amazon Bedrock (Anthropic, Llama2, Cohere, Titan, Jurassic2)
87+
* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2)
8888
* HuggingFace
8989
* Google VertexAI (PaLM2, Gemini)
9090
* Mistral AI

models/spring-ai-bedrock/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
- [Anthropic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-anthropic.html)
55
- [Cohere Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-cohere.html)
66
- [Cohere Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-cohere-embedding.html)
7-
- [Llama2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama2.html)
7+
- [Llama Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-llama.html)
88
- [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html)
99
- [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html)
1010
- [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html)

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/aot/BedrockRuntimeHints.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
2626
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
2727
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
28-
import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatOptions;
29-
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
28+
import org.springframework.ai.bedrock.llama.BedrockLlamaChatOptions;
29+
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
3030
import org.springframework.ai.bedrock.titan.BedrockTitanChatOptions;
3131
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
3232
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
@@ -63,9 +63,9 @@ public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
6363
for (var tr : findJsonAnnotatedClassesInPackage(BedrockCohereEmbeddingOptions.class))
6464
hints.reflection().registerType(tr, mcs);
6565

66-
for (var tr : findJsonAnnotatedClassesInPackage(Llama2ChatBedrockApi.class))
66+
for (var tr : findJsonAnnotatedClassesInPackage(LlamaChatBedrockApi.class))
6767
hints.reflection().registerType(tr, mcs);
68-
for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlama2ChatOptions.class))
68+
for (var tr : findJsonAnnotatedClassesInPackage(BedrockLlamaChatOptions.class))
6969
hints.reflection().registerType(tr, mcs);
7070

7171
for (var tr : findJsonAnnotatedClassesInPackage(TitanChatBedrockApi.class))
Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
package org.springframework.ai.bedrock.llama2;
16+
package org.springframework.ai.bedrock.llama;
1717

1818
import java.util.List;
1919

2020
import reactor.core.publisher.Flux;
2121

2222
import org.springframework.ai.bedrock.MessageToPromptConverter;
23-
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
24-
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest;
25-
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse;
23+
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
24+
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
25+
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
2626
import org.springframework.ai.chat.ChatClient;
2727
import org.springframework.ai.chat.prompt.ChatOptions;
2828
import org.springframework.ai.chat.ChatResponse;
@@ -35,26 +35,27 @@
3535
import org.springframework.util.Assert;
3636

3737
/**
38-
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama2 chat
38+
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama chat
3939
* generative.
4040
*
4141
* @author Christian Tzolov
42+
* @author Wei Jiang
4243
* @since 0.8.0
4344
*/
44-
public class BedrockLlama2ChatClient implements ChatClient, StreamingChatClient {
45+
public class BedrockLlamaChatClient implements ChatClient, StreamingChatClient {
4546

46-
private final Llama2ChatBedrockApi chatApi;
47+
private final LlamaChatBedrockApi chatApi;
4748

48-
private final BedrockLlama2ChatOptions defaultOptions;
49+
private final BedrockLlamaChatOptions defaultOptions;
4950

50-
public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi) {
51+
public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi) {
5152
this(chatApi,
52-
BedrockLlama2ChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
53+
BedrockLlamaChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
5354
}
5455

55-
public BedrockLlama2ChatClient(Llama2ChatBedrockApi chatApi, BedrockLlama2ChatOptions options) {
56-
Assert.notNull(chatApi, "Llama2ChatBedrockApi must not be null");
57-
Assert.notNull(options, "BedrockLlama2ChatOptions must not be null");
56+
public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) {
57+
Assert.notNull(chatApi, "LlamaChatBedrockApi must not be null");
58+
Assert.notNull(options, "BedrockLlamaChatOptions must not be null");
5859

5960
this.chatApi = chatApi;
6061
this.defaultOptions = options;
@@ -65,7 +66,7 @@ public ChatResponse call(Prompt prompt) {
6566

6667
var request = createRequest(prompt);
6768

68-
Llama2ChatResponse response = this.chatApi.chatCompletion(request);
69+
LlamaChatResponse response = this.chatApi.chatCompletion(request);
6970

7071
return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata(
7172
ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response)))));
@@ -76,7 +77,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
7677

7778
var request = createRequest(prompt);
7879

79-
Flux<Llama2ChatResponse> fluxResponse = this.chatApi.chatCompletionStream(request);
80+
Flux<LlamaChatResponse> fluxResponse = this.chatApi.chatCompletionStream(request);
8081

8182
return fluxResponse.map(response -> {
8283
String stopReason = response.stopReason() != null ? response.stopReason().name() : null;
@@ -85,7 +86,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
8586
});
8687
}
8788

88-
private Usage extractUsage(Llama2ChatResponse response) {
89+
private Usage extractUsage(LlamaChatResponse response) {
8990
return new Usage() {
9091

9192
@Override
@@ -103,22 +104,22 @@ public Long getGenerationTokens() {
103104
/**
104105
* Accessible for testing.
105106
*/
106-
Llama2ChatRequest createRequest(Prompt prompt) {
107+
LlamaChatRequest createRequest(Prompt prompt) {
107108

108109
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
109110

110-
Llama2ChatRequest request = Llama2ChatRequest.builder(promptValue).build();
111+
LlamaChatRequest request = LlamaChatRequest.builder(promptValue).build();
111112

112113
if (this.defaultOptions != null) {
113-
request = ModelOptionsUtils.merge(request, this.defaultOptions, Llama2ChatRequest.class);
114+
request = ModelOptionsUtils.merge(request, this.defaultOptions, LlamaChatRequest.class);
114115
}
115116

116117
if (prompt.getOptions() != null) {
117118
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
118-
BedrockLlama2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
119-
ChatOptions.class, BedrockLlama2ChatOptions.class);
119+
BedrockLlamaChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
120+
ChatOptions.class, BedrockLlamaChatOptions.class);
120121

121-
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, Llama2ChatRequest.class);
122+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, LlamaChatRequest.class);
122123
}
123124
else {
124125
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
package org.springframework.ai.bedrock.llama2;
16+
package org.springframework.ai.bedrock.llama;
1717

1818
import com.fasterxml.jackson.annotation.JsonIgnore;
1919
import com.fasterxml.jackson.annotation.JsonInclude;
@@ -26,7 +26,7 @@
2626
* @author Christian Tzolov
2727
*/
2828
@JsonInclude(Include.NON_NULL)
29-
public class BedrockLlama2ChatOptions implements ChatOptions {
29+
public class BedrockLlamaChatOptions implements ChatOptions {
3030

3131
/**
3232
* The temperature value controls the randomness of the generated text. Use a lower
@@ -51,7 +51,7 @@ public static Builder builder() {
5151

5252
public static class Builder {
5353

54-
private BedrockLlama2ChatOptions options = new BedrockLlama2ChatOptions();
54+
private BedrockLlamaChatOptions options = new BedrockLlamaChatOptions();
5555

5656
public Builder withTemperature(Float temperature) {
5757
this.options.setTemperature(temperature);
@@ -68,7 +68,7 @@ public Builder withMaxGenLen(Integer maxGenLen) {
6868
return this;
6969
}
7070

71-
public BedrockLlama2ChatOptions build() {
71+
public BedrockLlamaChatOptions build() {
7272
return this.options;
7373
}
7474

0 commit comments

Comments
 (0)