17
17
18
18
import java .util .ArrayList ;
19
19
import java .util .Base64 ;
20
+ import java .util .HashSet ;
20
21
import java .util .List ;
21
22
import java .util .Map ;
23
+ import java .util .Set ;
22
24
import java .util .concurrent .atomic .AtomicReference ;
23
25
import java .util .stream .Collectors ;
24
26
28
30
29
31
import org .springframework .ai .anthropic .api .AnthropicApi ;
30
32
import org .springframework .ai .anthropic .api .AnthropicApi .ChatCompletion ;
31
- import org .springframework .ai .anthropic .api .AnthropicApi .RequestMessage ;
32
- import org .springframework .ai .anthropic .api .AnthropicApi .MediaContent ;
33
33
import org .springframework .ai .anthropic .api .AnthropicApi .ChatCompletionRequest ;
34
+ import org .springframework .ai .anthropic .api .AnthropicApi .MediaContent ;
35
+ import org .springframework .ai .anthropic .api .AnthropicApi .MediaContent .Type ;
36
+ import org .springframework .ai .anthropic .api .AnthropicApi .RequestMessage ;
34
37
import org .springframework .ai .anthropic .api .AnthropicApi .Role ;
35
38
import org .springframework .ai .anthropic .api .AnthropicApi .StreamResponse ;
36
39
import org .springframework .ai .anthropic .api .AnthropicApi .Usage ;
37
- import org .springframework .ai .anthropic .api .AnthropicApi .MediaContent .Type ;
38
40
import org .springframework .ai .anthropic .metadata .AnthropicChatResponseMetadata ;
39
41
import org .springframework .ai .chat .ChatClient ;
40
42
import org .springframework .ai .chat .ChatResponse ;
45
47
import org .springframework .ai .chat .prompt .ChatOptions ;
46
48
import org .springframework .ai .chat .prompt .Prompt ;
47
49
import org .springframework .ai .model .ModelOptionsUtils ;
50
+ import org .springframework .ai .model .function .AbstractFunctionCallSupport ;
51
+ import org .springframework .ai .model .function .FunctionCallbackContext ;
48
52
import org .springframework .ai .retry .RetryUtils ;
49
53
import org .springframework .http .ResponseEntity ;
50
54
import org .springframework .retry .support .RetryTemplate ;
57
61
* @author Christian Tzolov
58
62
* @since 1.0.0
59
63
*/
60
- public class AnthropicChatClient implements ChatClient , StreamingChatClient {
64
+ public class AnthropicChatClient extends
65
+ AbstractFunctionCallSupport <AnthropicApi .RequestMessage , AnthropicApi .ChatCompletionRequest , ResponseEntity <AnthropicApi .ChatCompletion >>
66
+ implements ChatClient , StreamingChatClient {
61
67
62
68
private static final Logger logger = LoggerFactory .getLogger (AnthropicChatClient .class );
63
69
@@ -112,6 +118,22 @@ public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defau
112
118
*/
113
119
public AnthropicChatClient (AnthropicApi anthropicApi , AnthropicChatOptions defaultOptions ,
114
120
RetryTemplate retryTemplate ) {
121
+ this (anthropicApi , defaultOptions , retryTemplate , null );
122
+ }
123
+
124
+ /**
125
+ * Construct a new {@link AnthropicChatClient} instance.
126
+ * @param anthropicApi the lower-level API for the Anthropic service.
127
+ * @param defaultOptions the default options used for the chat completion requests.
128
+ * @param retryTemplate the retry template used to retry the Anthropic API calls.
129
+ * @param functionCallbackContext the function callback context used to store the
130
+ * state of the function calls.
131
+ */
132
+ public AnthropicChatClient (AnthropicApi anthropicApi , AnthropicChatOptions defaultOptions ,
133
+ RetryTemplate retryTemplate , FunctionCallbackContext functionCallbackContext ) {
134
+
135
+ super (functionCallbackContext );
136
+
115
137
Assert .notNull (anthropicApi , "AnthropicApi must not be null" );
116
138
Assert .notNull (defaultOptions , "DefaultOptions must not be null" );
117
139
Assert .notNull (retryTemplate , "RetryTemplate must not be null" );
@@ -127,7 +149,7 @@ public ChatResponse call(Prompt prompt) {
127
149
ChatCompletionRequest request = createRequest (prompt , false );
128
150
129
151
return this .retryTemplate .execute (ctx -> {
130
- ResponseEntity <ChatCompletion > completionEntity = this .anthropicApi . chatCompletionEntity (request );
152
+ ResponseEntity <ChatCompletion > completionEntity = this .callWithFunctionSupport (request );
131
153
return toChatResponse (completionEntity .getBody ());
132
154
});
133
155
}
@@ -229,6 +251,8 @@ else if (mediaData instanceof String text) {
229
251
230
252
ChatCompletionRequest createRequest (Prompt prompt , boolean stream ) {
231
253
254
+ Set <String > functionsForThisRequest = new HashSet <>();
255
+
232
256
List <RequestMessage > userMessages = prompt .getInstructions ()
233
257
.stream ()
234
258
.filter (m -> m .getMessageType () != MessageType .SYSTEM )
@@ -260,6 +284,10 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
260
284
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils .copyToTarget (runtimeOptions ,
261
285
ChatOptions .class , AnthropicChatOptions .class );
262
286
287
+ Set <String > promptEnabledFunctions = this .handleFunctionCallbackConfigurations (updatedRuntimeOptions ,
288
+ IS_RUNTIME_CALL );
289
+ functionsForThisRequest .addAll (promptEnabledFunctions );
290
+
263
291
request = ModelOptionsUtils .merge (updatedRuntimeOptions , request , ChatCompletionRequest .class );
264
292
}
265
293
else {
@@ -269,12 +297,32 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
269
297
}
270
298
271
299
if (this .defaultOptions != null ) {
300
+ Set <String > defaultEnabledFunctions = this .handleFunctionCallbackConfigurations (this .defaultOptions ,
301
+ !IS_RUNTIME_CALL );
302
+ functionsForThisRequest .addAll (defaultEnabledFunctions );
303
+
272
304
request = ModelOptionsUtils .merge (request , this .defaultOptions , ChatCompletionRequest .class );
273
305
}
274
306
307
+ if (!CollectionUtils .isEmpty (functionsForThisRequest )) {
308
+
309
+ List <AnthropicApi .Tool > tools = getFunctionTools (functionsForThisRequest );
310
+
311
+ request = ChatCompletionRequest .from (request ).withTools (tools ).build ();
312
+ }
313
+
275
314
return request ;
276
315
}
277
316
317
+ private List <AnthropicApi .Tool > getFunctionTools (Set <String > functionNames ) {
318
+ return this .resolveFunctionCallbacks (functionNames ).stream ().map (functionCallback -> {
319
+ var description = functionCallback .getDescription ();
320
+ var name = functionCallback .getName ();
321
+ String inputSchema = functionCallback .getInputTypeSchema ();
322
+ return new AnthropicApi .Tool (name , description , ModelOptionsUtils .jsonToMap (inputSchema ));
323
+ }).toList ();
324
+ }
325
+
278
326
private static class ChatCompletionBuilder {
279
327
280
328
private String type ;
@@ -343,4 +391,63 @@ public ChatCompletion build() {
343
391
344
392
}
345
393
394
+ @ Override
395
+ protected ChatCompletionRequest doCreateToolResponseRequest (ChatCompletionRequest previousRequest ,
396
+ RequestMessage responseMessage , List <RequestMessage > conversationHistory ) {
397
+
398
+ List <MediaContent > toolToUseList = responseMessage .content ()
399
+ .stream ()
400
+ .filter (c -> c .type () == MediaContent .Type .TOOL_USE )
401
+ .toList ();
402
+
403
+ List <MediaContent > toolResults = new ArrayList <>();
404
+
405
+ for (MediaContent toolToUse : toolToUseList ) {
406
+
407
+ var functionCallId = toolToUse .id ();
408
+ var functionName = toolToUse .name ();
409
+ var functionArguments = toolToUse .input ();
410
+
411
+ if (!this .functionCallbackRegister .containsKey (functionName )) {
412
+ throw new IllegalStateException ("No function callback found for function name: " + functionName );
413
+ }
414
+
415
+ String functionResponse = this .functionCallbackRegister .get (functionName )
416
+ .call (ModelOptionsUtils .toJsonString (functionArguments ));
417
+
418
+ toolResults .add (new MediaContent (Type .TOOL_RESULT , functionCallId , functionResponse ));
419
+ }
420
+
421
+ // Add the function response to the conversation.
422
+ conversationHistory .add (new RequestMessage (toolResults , Role .USER ));
423
+
424
+ // Recursively call chatCompletionWithTools until the model doesn't call a
425
+ // functions anymore.
426
+ return ChatCompletionRequest .from (previousRequest ).withMessages (conversationHistory ).build ();
427
+ }
428
+
429
+ @ Override
430
+ protected List <RequestMessage > doGetUserMessages (ChatCompletionRequest request ) {
431
+ return request .messages ();
432
+ }
433
+
434
+ @ Override
435
+ protected RequestMessage doGetToolResponseMessage (ResponseEntity <ChatCompletion > response ) {
436
+ return new RequestMessage (response .getBody ().content (), Role .ASSISTANT );
437
+ }
438
+
439
+ @ Override
440
+ protected ResponseEntity <ChatCompletion > doChatCompletion (ChatCompletionRequest request ) {
441
+ return this .anthropicApi .chatCompletionEntity (request );
442
+ }
443
+
444
+ @ SuppressWarnings ("null" )
445
+ @ Override
446
+ protected boolean isToolFunctionCall (ResponseEntity <ChatCompletion > response ) {
447
+ if (response == null || response .getBody () == null || CollectionUtils .isEmpty (response .getBody ().content ())) {
448
+ return false ;
449
+ }
450
+ return response .getBody ().content ().stream ().anyMatch (content -> content .type () == MediaContent .Type .TOOL_USE );
451
+ }
452
+
346
453
}
0 commit comments