Skip to content

Commit d858327

Browse files
jitokimtzolov
authored andcommitted
fix callback merging in AzureOpenAiChatModel constructor
and minor fixes Signed-off-by: jitokim <pigberger70@gmail.com>
1 parent f5b00a0 commit d858327

File tree

6 files changed

+96
-15
lines changed

6 files changed

+96
-15
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public class AnthropicApi {
8080

8181
private final StreamHelper streamHelper = new StreamHelper();
8282

83-
private WebClient webClient;
83+
private final WebClient webClient;
8484

8585
/**
8686
* Create a new client api with DEFAULT_BASE_URL
@@ -261,7 +261,7 @@ public enum ChatModel implements ChatModelDescription {
261261
/**
262262
* The CLAUDE_INSTANT_1_2
263263
*/
264-
CLAUDE_INSTANT_1_2("claude-instant-1.2");
264+
@Deprecated CLAUDE_INSTANT_1_2("claude-instant-1.2");
265265
// @formatter:on
266266

267267
private final String value;
@@ -366,7 +366,7 @@ public enum EventType {
366366
/**
367367
* Artificially created event to aggregate tool use events.
368368
*/
369-
TOOL_USE_AGGREATE
369+
TOOL_USE_AGGREGATE
370370

371371
}
372372

@@ -889,7 +889,7 @@ public static class ToolUseAggregationEvent implements StreamEvent {
889889

890890
@Override
891891
public EventType type() {
892-
return EventType.TOOL_USE_AGGREATE;
892+
return EventType.TOOL_USE_AGGREGATE;
893893
}
894894

895895
/**

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
*
4949
* @author Mariusz Bernacki
5050
* @author Christian Tzolov
51+
* @author Jihoon Kim
5152
* @since 1.0.0
5253
*/
5354
public class StreamHelper {
@@ -85,10 +86,10 @@ public StreamEvent mergeToolUseEvents(StreamEvent previousEvent, StreamEvent eve
8586
}
8687
}
8788
else if (event.type() == EventType.CONTENT_BLOCK_DELTA) {
88-
ContentBlockDeltaEvent contentBolckDelta = (ContentBlockDeltaEvent) event;
89-
if (ContentBlock.Type.INPUT_JSON_DELTA.getValue().equals(contentBolckDelta.delta().type())) {
89+
ContentBlockDeltaEvent contentBlockDelta = (ContentBlockDeltaEvent) event;
90+
if (ContentBlock.Type.INPUT_JSON_DELTA.getValue().equals(contentBlockDelta.delta().type())) {
9091
return eventAggregator
91-
.appendPartialJson(((ContentBlockDeltaJson) contentBolckDelta.delta()).partialJson());
92+
.appendPartialJson(((ContentBlockDeltaJson) contentBlockDelta.delta()).partialJson());
9293
}
9394
}
9495
else if (event.type() == EventType.CONTENT_BLOCK_STOP) {
@@ -119,7 +120,7 @@ public ChatCompletionResponse eventToChatCompletionResponse(StreamEvent event,
119120
.withUsage(messageStartEvent.message().usage())
120121
.withContent(new ArrayList<>());
121122
}
122-
else if (event.type().equals(EventType.TOOL_USE_AGGREATE)) {
123+
else if (event.type().equals(EventType.TOOL_USE_AGGREGATE)) {
123124
ToolUseAggregationEvent eventToolUseBuilder = (ToolUseAggregationEvent) event;
124125

125126
if (!CollectionUtils.isEmpty(eventToolUseBuilder.getToolContentBlocks())) {

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
* @author luocongqiu
105105
* @author timostark
106106
* @author Soby Chacko
107+
* @author Jihoon Kim
107108
* @see ChatModel
108109
* @see com.azure.ai.openai.OpenAIClient
109110
* @since 1.0.0
@@ -160,7 +161,7 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi
160161

161162
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
162163
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks) {
163-
this(openAIClientBuilder, options, functionCallbackContext, List.of(), ObservationRegistry.NOOP);
164+
this(openAIClientBuilder, options, functionCallbackContext, toolFunctionCallbacks, ObservationRegistry.NOOP);
164165
}
165166

166167
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
@@ -235,10 +236,6 @@ public Flux<ChatResponse> stream(Prompt prompt) {
235236
Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
236237
.getChatCompletionsStream(options.getModel(), options);
237238

238-
// For chunked responses, only the first chunk contains the choice role.
239-
// The rest of the chunks with same ID share the same role.
240-
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
241-
242239
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
243240
.prompt(prompt)
244241
.provider(AiProvider.AZURE_OPENAI.value())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package org.springframework.ai.azure.openai;
2+
3+
import com.azure.ai.openai.OpenAIClientBuilder;
4+
5+
import java.util.List;
6+
7+
import org.junit.jupiter.api.Test;
8+
import org.junit.jupiter.api.extension.ExtendWith;
9+
import org.mockito.Mock;
10+
import org.mockito.junit.jupiter.MockitoExtension;
11+
12+
import org.springframework.ai.model.function.FunctionCallback;
13+
import org.springframework.ai.model.function.FunctionCallbackContext;
14+
15+
/**
16+
* @author Jihoon Kim
17+
*/
18+
@ExtendWith(MockitoExtension.class)
19+
public class AzureOpenAiChatModelTests {
20+
21+
@Mock
22+
OpenAIClientBuilder mockClient;
23+
24+
@Mock
25+
FunctionCallbackContext functionCallbackContext;
26+
27+
@Test
28+
public void createAzureOpenAiChatModelTest() {
29+
String callbackFromChatOptions = "callbackFromChatOptions";
30+
String callbackFromConstructorParam = "callbackFromConstructorParam";
31+
32+
AzureOpenAiChatOptions chatOptions = AzureOpenAiChatOptions.builder()
33+
.withFunctionCallbacks(List.of(new TestFunctionCallback(callbackFromChatOptions)))
34+
.build();
35+
36+
List<FunctionCallback> functionCallbacks = List.of(new TestFunctionCallback(callbackFromConstructorParam));
37+
38+
AzureOpenAiChatModel openAiChatModel = new AzureOpenAiChatModel(mockClient, chatOptions,
39+
functionCallbackContext, functionCallbacks);
40+
41+
assert 2 == openAiChatModel.getFunctionCallbackRegister().size();
42+
43+
assert callbackFromChatOptions == openAiChatModel.getFunctionCallbackRegister()
44+
.get(callbackFromChatOptions)
45+
.getName();
46+
47+
assert callbackFromConstructorParam == openAiChatModel.getFunctionCallbackRegister()
48+
.get(callbackFromConstructorParam)
49+
.getName();
50+
}
51+
52+
private class TestFunctionCallback implements FunctionCallback {
53+
54+
private final String name;
55+
56+
public TestFunctionCallback(String name) {
57+
this.name = name;
58+
}
59+
60+
@Override
61+
public String getName() {
62+
return name;
63+
}
64+
65+
@Override
66+
public String getDescription() {
67+
return null;
68+
}
69+
70+
@Override
71+
public String getInputTypeSchema() {
72+
return null;
73+
}
74+
75+
@Override
76+
public String call(String functionInput) {
77+
return null;
78+
}
79+
80+
}
81+
82+
}

spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
* @author Christian Tzolov
4242
* @author Grogdunn
4343
* @author Thomas Vitale
44+
* @author Jihoon Kim
4445
* @since 1.0.0
4546
*/
4647
public abstract class AbstractToolCallSupport {
@@ -85,7 +86,7 @@ private static List<FunctionCallback> merge(FunctionCallingOptions functionOptio
8586

8687
if (!CollectionUtils.isEmpty(functionOptions.getFunctionCallbacks())) {
8788
toolFunctionCallbacksCopy.addAll(functionOptions.getFunctionCallbacks());
88-
// Make sure that that function callbacks are are registered directly to the
89+
// Make sure that that function callbacks are registered directly to the
8990
// functionCallbackRegister and not passed in the default options.
9091
functionOptions.setFunctionCallbacks(List.of());
9192
}

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ public record Request(String location, Unit unit) {}
112112

113113
It is a best practice to annotate the request object with information such that the generated JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke.
114114

115-
The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java.java[FunctionCallWithFunctionBeanIT.java] demonstrates this approach.
115+
The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java[FunctionCallWithFunctionBeanIT.java] demonstrates this approach.
116116

117117

118118
==== FunctionCallback

0 commit comments

Comments
 (0)