Skip to content

Commit 7fca784

Browse files
committed
Add Bedrock Anthropic Chat Options support
- Improve Anthropic tests - Add anthrpic docs - Restructure the docs for Azure OpenAI, OpenAI, Ollama, Bedrock Cohere and Bedrock Lllam2
1 parent 4ba9a3c commit 7fca784

File tree

30 files changed

+1053
-566
lines changed

30 files changed

+1053
-566
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Azure OpenAI
2+
3+
Visit the Spring AI [Azure OpenAI Chat Documentation](https://docs.spring.io/spring-ai/reference/api/clients/azure-openai-chat.html).

models/spring-ai-azure-openai/README_AZURE_OPENAI.md

Lines changed: 0 additions & 9 deletions
This file was deleted.

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,22 @@ public class AzureOpenAiChatClient implements ChatClient, StreamingChatClient {
7878
private final OpenAIClient openAIClient;
7979

8080
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
81+
this(microsoftOpenAiClient,
82+
AzureOpenAiChatOptions.builder().withModel(DEFAULT_MODEL).withTemperature(DEFAULT_TEMPERATURE).build());
83+
}
84+
85+
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
8186
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
87+
Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
8288
this.openAIClient = microsoftOpenAiClient;
83-
this.defaultOptions = AzureOpenAiChatOptions.builder()
84-
.withModel(DEFAULT_MODEL)
85-
.withTemperature(DEFAULT_TEMPERATURE)
86-
.build();
89+
this.defaultOptions = options;
8790
}
8891

92+
/**
93+
* @deprecated since 0.8.0, use
94+
* {@link #AzureOpenAiChatClient(OpenAIClient, AzureOpenAiChatOptions)} instead.
95+
*/
96+
@Deprecated(forRemoval = true, since = "0.8.0")
8997
public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
9098
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
9199
this.defaultOptions = defaultOptions;

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class AzureChatCompletionsOptionsTests {
3333
public void createRequestWithChatOptions() {
3434

3535
OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
36-
var client = new AzureOpenAiChatClient(mockClient).withDefaultOptions(
36+
var client = new AzureOpenAiChatClient(mockClient,
3737
AzureOpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
3838

3939
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content"));

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ public OpenAIClient openAIClient() {
180180

181181
@Bean
182182
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient) {
183-
return new AzureOpenAiChatClient(openAIClient).withDefaultOptions(
183+
return new AzureOpenAiChatClient(openAIClient,
184184
AzureOpenAiChatOptions.builder().withModel("gpt-35-turbo").withMaxTokens(200).build());
185185

186186
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,11 @@ public enum AnthropicChatModel {
194194
/**
195195
* anthropic.claude-v2
196196
*/
197-
CLAUDE_V2("anthropic.claude-v2");
197+
CLAUDE_V2("anthropic.claude-v2"),
198+
/**
199+
* anthropic.claude-v2:1
200+
*/
201+
CLAUDE_V21("anthropic.claude-v2:1");
198202

199203
private final String id;
200204

Lines changed: 42 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2023 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,23 +18,23 @@
1818

1919
import java.util.List;
2020

21-
import org.springframework.ai.chat.ChatResponse;
22-
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
2321
import reactor.core.publisher.Flux;
2422

2523
import org.springframework.ai.bedrock.BedrockUsage;
2624
import org.springframework.ai.bedrock.MessageToPromptConverter;
2725
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
2826
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest;
2927
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse;
30-
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.LogitBias;
31-
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.ReturnLikelihoods;
32-
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate;
3328
import org.springframework.ai.chat.ChatClient;
34-
import org.springframework.ai.chat.StreamingChatClient;
29+
import org.springframework.ai.chat.ChatOptions;
30+
import org.springframework.ai.chat.ChatResponse;
3531
import org.springframework.ai.chat.Generation;
32+
import org.springframework.ai.chat.StreamingChatClient;
33+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3634
import org.springframework.ai.chat.metadata.Usage;
3735
import org.springframework.ai.chat.prompt.Prompt;
36+
import org.springframework.ai.model.ModelOptionsUtils;
37+
import org.springframework.util.Assert;
3838

3939
/**
4040
* @author Christian Tzolov
@@ -44,71 +44,18 @@ public class BedrockCohereChatClient implements ChatClient, StreamingChatClient
4444

4545
private final CohereChatBedrockApi chatApi;
4646

47-
private Float temperature;
48-
49-
private Float topP;
50-
51-
private Integer topK;
52-
53-
private Integer maxTokens;
54-
55-
private List<String> stopSequences;
56-
57-
private ReturnLikelihoods returnLikelihoods;
58-
59-
private Integer numGenerations;
60-
61-
private LogitBias logitBias;
62-
63-
private Truncate truncate;
47+
private final BedrockCohereChatOptions defaultOptions;
6448

6549
public BedrockCohereChatClient(CohereChatBedrockApi chatApi) {
66-
this.chatApi = chatApi;
67-
}
68-
69-
public BedrockCohereChatClient withTemperature(Float temperature) {
70-
this.temperature = temperature;
71-
return this;
50+
this(chatApi, BedrockCohereChatOptions.builder().build());
7251
}
7352

74-
public BedrockCohereChatClient withTopP(Float topP) {
75-
this.topP = topP;
76-
return this;
77-
}
53+
public BedrockCohereChatClient(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options) {
54+
Assert.notNull(chatApi, "CohereChatBedrockApi must not be null");
55+
Assert.notNull(options, "BedrockCohereChatOptions must not be null");
7856

79-
public BedrockCohereChatClient withTopK(Integer topK) {
80-
this.topK = topK;
81-
return this;
82-
}
83-
84-
public BedrockCohereChatClient withMaxTokens(Integer maxTokens) {
85-
this.maxTokens = maxTokens;
86-
return this;
87-
}
88-
89-
public BedrockCohereChatClient withStopSequences(List<String> stopSequences) {
90-
this.stopSequences = stopSequences;
91-
return this;
92-
}
93-
94-
public BedrockCohereChatClient withReturnLikelihoods(ReturnLikelihoods returnLikelihoods) {
95-
this.returnLikelihoods = returnLikelihoods;
96-
return this;
97-
}
98-
99-
public BedrockCohereChatClient withNumGenerations(Integer numGenerations) {
100-
this.numGenerations = numGenerations;
101-
return this;
102-
}
103-
104-
public BedrockCohereChatClient withLogitBias(LogitBias logitBias) {
105-
this.logitBias = logitBias;
106-
return this;
107-
}
108-
109-
public BedrockCohereChatClient withTruncate(Truncate truncate) {
110-
this.truncate = truncate;
111-
return this;
57+
this.chatApi = chatApi;
58+
this.defaultOptions = options;
11259
}
11360

11461
@Override
@@ -134,21 +81,38 @@ public Flux<ChatResponse> stream(Prompt prompt) {
13481
});
13582
}
13683

137-
private CohereChatRequest createRequest(Prompt prompt, boolean stream) {
84+
/**
85+
* Test access.
86+
*/
87+
CohereChatRequest createRequest(Prompt prompt, boolean stream) {
13888
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
13989

140-
return CohereChatRequest.builder(promptValue)
141-
.withTemperature(this.temperature)
142-
.withTopP(this.topP)
143-
.withTopK(this.topK)
144-
.withMaxTokens(this.maxTokens)
145-
.withStopSequences(this.stopSequences)
146-
.withReturnLikelihoods(this.returnLikelihoods)
90+
var request = CohereChatRequest.builder(promptValue)
91+
.withTemperature(this.defaultOptions.getTemperature())
92+
.withTopP(this.defaultOptions.getTopP())
93+
.withTopK(this.defaultOptions.getTopK())
94+
.withMaxTokens(this.defaultOptions.getMaxTokens())
95+
.withStopSequences(this.defaultOptions.getStopSequences())
96+
.withReturnLikelihoods(this.defaultOptions.getReturnLikelihoods())
14797
.withStream(stream)
148-
.withNumGenerations(this.numGenerations)
149-
.withLogitBias(this.logitBias)
150-
.withTruncate(this.truncate)
98+
.withNumGenerations(this.defaultOptions.getNumGenerations())
99+
.withLogitBias(this.defaultOptions.getLogitBias())
100+
.withTruncate(this.defaultOptions.getTruncate())
151101
.build();
102+
103+
if (prompt.getOptions() != null) {
104+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
105+
BedrockCohereChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
106+
ChatOptions.class, BedrockCohereChatOptions.class);
107+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, CohereChatRequest.class);
108+
}
109+
else {
110+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
111+
+ prompt.getOptions().getClass().getSimpleName());
112+
}
113+
}
114+
115+
return request;
152116
}
153117

154118
}

0 commit comments

Comments
 (0)