Skip to content

Commit f249e64

Browse files
tzolovmarkpollack
authored andcommitted
Add Function Calling Support for Anthropic Features
- expanded the AnthropicApi to include Tool, facilitating request and response abstractions. - extended AnthropicChatClient to inherit AbstractFunctionCallSupport, with implementation of all necessary methods and function registration protocols. - implemented FunctionCallingOptions interface in AnthropicChatOptions. - added tools integration tests for AnthropoicApi and AnthropicChatClient. - extended the auto-configuration with functional calling functionality. - added ITs for tools auto-config. - updated documentation on anthropic function calling and relevant pages for comprehensive coverage.
1 parent e268975 commit f249e64

File tree

15 files changed

+1112
-82
lines changed

15 files changed

+1112
-82
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
import java.util.ArrayList;
1919
import java.util.Base64;
20+
import java.util.HashSet;
2021
import java.util.List;
2122
import java.util.Map;
23+
import java.util.Set;
2224
import java.util.concurrent.atomic.AtomicReference;
2325
import java.util.stream.Collectors;
2426

@@ -28,13 +30,13 @@
2830

2931
import org.springframework.ai.anthropic.api.AnthropicApi;
3032
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletion;
31-
import org.springframework.ai.anthropic.api.AnthropicApi.RequestMessage;
32-
import org.springframework.ai.anthropic.api.AnthropicApi.MediaContent;
3333
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
34+
import org.springframework.ai.anthropic.api.AnthropicApi.MediaContent;
35+
import org.springframework.ai.anthropic.api.AnthropicApi.MediaContent.Type;
36+
import org.springframework.ai.anthropic.api.AnthropicApi.RequestMessage;
3437
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
3538
import org.springframework.ai.anthropic.api.AnthropicApi.StreamResponse;
3639
import org.springframework.ai.anthropic.api.AnthropicApi.Usage;
37-
import org.springframework.ai.anthropic.api.AnthropicApi.MediaContent.Type;
3840
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
3941
import org.springframework.ai.chat.ChatClient;
4042
import org.springframework.ai.chat.ChatResponse;
@@ -45,6 +47,8 @@
4547
import org.springframework.ai.chat.prompt.ChatOptions;
4648
import org.springframework.ai.chat.prompt.Prompt;
4749
import org.springframework.ai.model.ModelOptionsUtils;
50+
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
51+
import org.springframework.ai.model.function.FunctionCallbackContext;
4852
import org.springframework.ai.retry.RetryUtils;
4953
import org.springframework.http.ResponseEntity;
5054
import org.springframework.retry.support.RetryTemplate;
@@ -57,7 +61,9 @@
5761
* @author Christian Tzolov
5862
* @since 1.0.0
5963
*/
60-
public class AnthropicChatClient implements ChatClient, StreamingChatClient {
64+
public class AnthropicChatClient extends
65+
AbstractFunctionCallSupport<AnthropicApi.RequestMessage, AnthropicApi.ChatCompletionRequest, ResponseEntity<AnthropicApi.ChatCompletion>>
66+
implements ChatClient, StreamingChatClient {
6167

6268
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatClient.class);
6369

@@ -112,6 +118,22 @@ public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defau
112118
*/
113119
public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
114120
RetryTemplate retryTemplate) {
121+
this(anthropicApi, defaultOptions, retryTemplate, null);
122+
}
123+
124+
/**
125+
* Construct a new {@link AnthropicChatClient} instance.
126+
* @param anthropicApi the lower-level API for the Anthropic service.
127+
* @param defaultOptions the default options used for the chat completion requests.
128+
* @param retryTemplate the retry template used to retry the Anthropic API calls.
129+
* @param functionCallbackContext the function callback context used to store the
130+
* state of the function calls.
131+
*/
132+
public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
133+
RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext) {
134+
135+
super(functionCallbackContext);
136+
115137
Assert.notNull(anthropicApi, "AnthropicApi must not be null");
116138
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
117139
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
@@ -127,7 +149,7 @@ public ChatResponse call(Prompt prompt) {
127149
ChatCompletionRequest request = createRequest(prompt, false);
128150

129151
return this.retryTemplate.execute(ctx -> {
130-
ResponseEntity<ChatCompletion> completionEntity = this.anthropicApi.chatCompletionEntity(request);
152+
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
131153
return toChatResponse(completionEntity.getBody());
132154
});
133155
}
@@ -229,6 +251,8 @@ else if (mediaData instanceof String text) {
229251

230252
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
231253

254+
Set<String> functionsForThisRequest = new HashSet<>();
255+
232256
List<RequestMessage> userMessages = prompt.getInstructions()
233257
.stream()
234258
.filter(m -> m.getMessageType() != MessageType.SYSTEM)
@@ -260,6 +284,10 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
260284
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
261285
ChatOptions.class, AnthropicChatOptions.class);
262286

287+
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
288+
IS_RUNTIME_CALL);
289+
functionsForThisRequest.addAll(promptEnabledFunctions);
290+
263291
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
264292
}
265293
else {
@@ -269,12 +297,32 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
269297
}
270298

271299
if (this.defaultOptions != null) {
300+
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions,
301+
!IS_RUNTIME_CALL);
302+
functionsForThisRequest.addAll(defaultEnabledFunctions);
303+
272304
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
273305
}
274306

307+
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
308+
309+
List<AnthropicApi.Tool> tools = getFunctionTools(functionsForThisRequest);
310+
311+
request = ChatCompletionRequest.from(request).withTools(tools).build();
312+
}
313+
275314
return request;
276315
}
277316

317+
private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
318+
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
319+
var description = functionCallback.getDescription();
320+
var name = functionCallback.getName();
321+
String inputSchema = functionCallback.getInputTypeSchema();
322+
return new AnthropicApi.Tool(name, description, ModelOptionsUtils.jsonToMap(inputSchema));
323+
}).toList();
324+
}
325+
278326
private static class ChatCompletionBuilder {
279327

280328
private String type;
@@ -343,4 +391,63 @@ public ChatCompletion build() {
343391

344392
}
345393

394+
@Override
395+
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
396+
RequestMessage responseMessage, List<RequestMessage> conversationHistory) {
397+
398+
List<MediaContent> toolToUseList = responseMessage.content()
399+
.stream()
400+
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
401+
.toList();
402+
403+
List<MediaContent> toolResults = new ArrayList<>();
404+
405+
for (MediaContent toolToUse : toolToUseList) {
406+
407+
var functionCallId = toolToUse.id();
408+
var functionName = toolToUse.name();
409+
var functionArguments = toolToUse.input();
410+
411+
if (!this.functionCallbackRegister.containsKey(functionName)) {
412+
throw new IllegalStateException("No function callback found for function name: " + functionName);
413+
}
414+
415+
String functionResponse = this.functionCallbackRegister.get(functionName)
416+
.call(ModelOptionsUtils.toJsonString(functionArguments));
417+
418+
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
419+
}
420+
421+
// Add the function response to the conversation.
422+
conversationHistory.add(new RequestMessage(toolResults, Role.USER));
423+
424+
// Recursively call chatCompletionWithTools until the model doesn't call a
425+
// functions anymore.
426+
return ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
427+
}
428+
429+
@Override
430+
protected List<RequestMessage> doGetUserMessages(ChatCompletionRequest request) {
431+
return request.messages();
432+
}
433+
434+
@Override
435+
protected RequestMessage doGetToolResponseMessage(ResponseEntity<ChatCompletion> response) {
436+
return new RequestMessage(response.getBody().content(), Role.ASSISTANT);
437+
}
438+
439+
@Override
440+
protected ResponseEntity<ChatCompletion> doChatCompletion(ChatCompletionRequest request) {
441+
return this.anthropicApi.chatCompletionEntity(request);
442+
}
443+
444+
@SuppressWarnings("null")
445+
@Override
446+
protected boolean isToolFunctionCall(ResponseEntity<ChatCompletion> response) {
447+
if (response == null || response.getBody() == null || CollectionUtils.isEmpty(response.getBody().content())) {
448+
return false;
449+
}
450+
return response.getBody().content().stream().anyMatch(content -> content.type() == MediaContent.Type.TOOL_USE);
451+
}
452+
346453
}

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,22 @@
1515
*/
1616
package org.springframework.ai.anthropic;
1717

18+
import java.util.ArrayList;
19+
import java.util.HashSet;
1820
import java.util.List;
21+
import java.util.Set;
1922

23+
import com.fasterxml.jackson.annotation.JsonIgnore;
2024
import com.fasterxml.jackson.annotation.JsonInclude;
2125
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2226
import com.fasterxml.jackson.annotation.JsonProperty;
2327

2428
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
2529
import org.springframework.ai.chat.prompt.ChatOptions;
30+
import org.springframework.ai.model.function.FunctionCallback;
31+
import org.springframework.ai.model.function.FunctionCallingOptions;
32+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
33+
import org.springframework.util.Assert;
2634

2735
/**
2836
* The options to be used when sending a chat request to the Anthropic API.
@@ -31,7 +39,7 @@
3139
* @since 1.0.0
3240
*/
3341
@JsonInclude(Include.NON_NULL)
34-
public class AnthropicChatOptions implements ChatOptions {
42+
public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions {
3543

3644
// @formatter:off
3745
private @JsonProperty("model") String model;
@@ -41,6 +49,32 @@ public class AnthropicChatOptions implements ChatOptions {
4149
private @JsonProperty("temperature") Float temperature;
4250
private @JsonProperty("top_p") Float topP;
4351
private @JsonProperty("top_k") Integer topK;
52+
53+
/**
54+
* Tool Function Callbacks to register with the ChatClient. For Prompt
55+
* Options the functionCallbacks are automatically enabled for the duration of the
56+
* prompt execution. For Default Options the functionCallbacks are registered but
57+
* disabled by default. Use the enableFunctions to set the functions from the registry
58+
* to be used by the ChatClient chat completion requests.
59+
*/
60+
@NestedConfigurationProperty
61+
@JsonIgnore
62+
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
63+
64+
/**
65+
* List of functions, identified by their names, to configure for function calling in
66+
* the chat completion requests. Functions with those names must exist in the
67+
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
68+
* are automatically enabled for the duration of the prompt execution.
69+
*
70+
* Note that function enabled with the default options are enabled for all chat
71+
* completion requests. This could impact the token count and the billing. If the
72+
* functions is set in a prompt options, then the enabled functions are only active
73+
* for the duration of this prompt execution.
74+
*/
75+
@NestedConfigurationProperty
76+
@JsonIgnore
77+
private Set<String> functions = new HashSet<>();
4478
// @formatter:on
4579

4680
public static Builder builder() {
@@ -86,6 +120,23 @@ public Builder withTopK(Integer topK) {
86120
return this;
87121
}
88122

123+
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
124+
this.options.functionCallbacks = functionCallbacks;
125+
return this;
126+
}
127+
128+
public Builder withFunctions(Set<String> functionNames) {
129+
Assert.notNull(functionNames, "Function names must not be null");
130+
this.options.functions = functionNames;
131+
return this;
132+
}
133+
134+
public Builder withFunction(String functionName) {
135+
Assert.hasText(functionName, "Function name must not be empty");
136+
this.options.functions.add(functionName);
137+
return this;
138+
}
139+
89140
public AnthropicChatOptions build() {
90141
return this.options;
91142
}
@@ -150,4 +201,26 @@ public void setTopK(Integer topK) {
150201
this.topK = topK;
151202
}
152203

204+
@Override
205+
public List<FunctionCallback> getFunctionCallbacks() {
206+
return this.functionCallbacks;
207+
}
208+
209+
@Override
210+
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
211+
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
212+
this.functionCallbacks = functionCallbacks;
213+
}
214+
215+
@Override
216+
public Set<String> getFunctions() {
217+
return this.functions;
218+
}
219+
220+
@Override
221+
public void setFunctions(Set<String> functions) {
222+
Assert.notNull(functions, "Function must not be null");
223+
this.functions = functions;
224+
}
225+
153226
}

0 commit comments

Comments
 (0)