Skip to content

Commit 9a52182

Browse files
Kehrlann authored and tzolov committed
Propagate reactive Context to AsyncMcpToolCallback
- When calling tools while using ChatModel#stream, store the reactive context in a thread-local, so it can be used by downstream reactive tools.
- In AsyncMcpToolCallback, restore the reactive context so it can be accessed by the tool.

This will be useful for Spring Security OAuth2 support in reactive scenarios, because it relies on the context.

Signed-off-by: Daniel Garnier-Moiroux <git@garnier.wf>
1 parent af07517 commit 9a52182

File tree

13 files changed

+136
-31
lines changed

13 files changed

+136
-31
lines changed

mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.springframework.ai.chat.model.ToolContext;
2525
import org.springframework.ai.model.ModelOptionsUtils;
26+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
2627
import org.springframework.ai.tool.ToolCallback;
2728
import org.springframework.ai.tool.definition.DefaultToolDefinition;
2829
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -120,7 +121,7 @@ public String call(String functionInput) {
120121
new IllegalStateException("Error calling tool: " + response.content()));
121122
}
122123
return ModelOptionsUtils.toJsonString(response.content());
123-
}).block();
124+
}).contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())).block();
124125
}
125126

126127
@Override

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import org.springframework.ai.content.Media;
6565
import org.springframework.ai.model.ModelOptionsUtils;
6666
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
67+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6768
import org.springframework.ai.model.tool.ToolCallingChatOptions;
6869
import org.springframework.ai.model.tool.ToolCallingManager;
6970
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
@@ -265,8 +266,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
265266
if (chatResponse.hasFinishReasons(Set.of("tool_use"))) {
266267
// FIXME: bounded elastic needs to be used since tool calling
267268
// is currently only synchronous
268-
return Flux.defer(() -> {
269-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
269+
return Flux.deferContextual((ctx) -> {
270+
// TODO: factor out the tool execution logic with setting context into a utility.
271+
ToolExecutionResult toolExecutionResult;
272+
try {
273+
ToolCallReactiveContextHolder.setContext(ctx);
274+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
275+
} finally {
276+
ToolCallReactiveContextHolder.clearContext();
277+
}
270278
if (toolExecutionResult.returnDirect()) {
271279
// Return tool execution result directly to the client.
272280
return Flux.just(ChatResponse.builder().from(chatResponse)
@@ -279,6 +287,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
279287
chatResponse);
280288
}
281289
}).subscribeOn(Schedulers.boundedElastic());
290+
282291
} else {
283292
return Mono.empty();
284293
}

models/spring-ai-anthropic/src/test/resources/application-logging-test.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616

1717
logging.level.org.springframework.ai.chat.client.advisor=DEBUG
1818

19-
logging.level.org.springframework.ai.anthropic.api.AnthropicApi=DEBUG
19+
logging.level.org.springframework.ai.anthropic.api.AnthropicApi=INFO

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
import org.springframework.ai.model.tool.ToolCallingManager;
9696
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
9797
import org.springframework.ai.model.tool.ToolExecutionResult;
98+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
9899
import org.springframework.ai.observation.conventions.AiProvider;
99100
import org.springframework.ai.support.UsageCalculator;
100101
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -380,8 +381,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
380381
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
381382
// FIXME: bounded elastic needs to be used since tool calling
382383
// is currently only synchronous
383-
return Flux.defer(() -> {
384-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
384+
return Flux.deferContextual((ctx) -> {
385+
ToolExecutionResult toolExecutionResult;
386+
try {
387+
ToolCallReactiveContextHolder.setContext(ctx);
388+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
389+
}
390+
finally {
391+
ToolCallReactiveContextHolder.clearContext();
392+
}
385393
if (toolExecutionResult.returnDirect()) {
386394
// Return tool execution result directly to the client.
387395
return Flux.just(ChatResponse.builder()

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
import org.springframework.ai.model.tool.ToolCallingManager;
102102
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
103103
import org.springframework.ai.model.tool.ToolExecutionResult;
104+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
104105
import org.springframework.ai.observation.conventions.AiProvider;
105106
import org.springframework.ai.tool.definition.ToolDefinition;
106107
import org.springframework.util.Assert;
@@ -681,8 +682,15 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
681682

682683
// FIXME: bounded elastic needs to be used since tool calling
683684
// is currently only synchronous
684-
return Flux.defer(() -> {
685-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
685+
return Flux.deferContextual((ctx) -> {
686+
ToolExecutionResult toolExecutionResult;
687+
try {
688+
ToolCallReactiveContextHolder.setContext(ctx);
689+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
690+
}
691+
finally {
692+
ToolCallReactiveContextHolder.clearContext();
693+
}
686694

687695
if (toolExecutionResult.returnDirect()) {
688696
// Return tool execution result directly to the client.

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import org.springframework.ai.model.tool.ToolCallingManager;
6363
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
6464
import org.springframework.ai.model.tool.ToolExecutionResult;
65+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6566
import org.springframework.ai.retry.RetryUtils;
6667
import org.springframework.ai.support.UsageCalculator;
6768
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -286,10 +287,16 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
286287
// @formatter:off
287288
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
288289
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
289-
return Flux.defer(() -> {
290-
// FIXME: bounded elastic needs to be used since tool calling
291-
// is currently only synchronous
292-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
290+
// FIXME: bounded elastic needs to be used since tool calling
291+
// is currently only synchronous
292+
return Flux.deferContextual((ctx) -> {
293+
ToolExecutionResult toolExecutionResult;
294+
try {
295+
ToolCallReactiveContextHolder.setContext(ctx);
296+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
297+
} finally {
298+
ToolCallReactiveContextHolder.clearContext();
299+
}
293300
if (toolExecutionResult.returnDirect()) {
294301
// Return tool execution result directly to the client.
295302
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
import org.springframework.ai.model.tool.ToolCallingManager;
6666
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
6767
import org.springframework.ai.model.tool.ToolExecutionResult;
68+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6869
import org.springframework.ai.retry.RetryUtils;
6970
import org.springframework.ai.tool.definition.ToolDefinition;
7071
import org.springframework.http.ResponseEntity;
@@ -370,10 +371,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
370371

371372
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
372373
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) {
373-
return Flux.defer(() -> {
374-
// FIXME: bounded elastic needs to be used since tool calling
375-
// is currently only synchronous
376-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
374+
// FIXME: bounded elastic needs to be used since tool calling
375+
// is currently only synchronous
376+
return Flux.deferContextual((ctx) -> {
377+
ToolExecutionResult toolExecutionResult;
378+
try {
379+
ToolCallReactiveContextHolder.setContext(ctx);
380+
toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
381+
} finally {
382+
ToolCallReactiveContextHolder.clearContext();
383+
}
377384
if (toolExecutionResult.returnDirect()) {
378385
// Return tool execution result directly to the client.
379386
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import org.springframework.ai.model.tool.ToolCallingManager;
6565
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
6666
import org.springframework.ai.model.tool.ToolExecutionResult;
67+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6768
import org.springframework.ai.retry.RetryUtils;
6869
import org.springframework.ai.support.UsageCalculator;
6970
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -316,8 +317,14 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
316317
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
317318
// FIXME: bounded elastic needs to be used since tool calling
318319
// is currently only synchronous
319-
return Flux.defer(() -> {
320-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
320+
return Flux.deferContextual((ctx) -> {
321+
ToolExecutionResult toolExecutionResult;
322+
try {
323+
ToolCallReactiveContextHolder.setContext(ctx);
324+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
325+
} finally {
326+
ToolCallReactiveContextHolder.clearContext();
327+
}
321328
if (toolExecutionResult.returnDirect()) {
322329
// Return tool execution result directly to the client.
323330
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import org.springframework.ai.model.tool.ToolCallingManager;
5555
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
5656
import org.springframework.ai.model.tool.ToolExecutionResult;
57+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
5758
import org.springframework.ai.ollama.api.OllamaApi;
5859
import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
5960
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
@@ -351,8 +352,14 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
351352
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
352353
// FIXME: bounded elastic needs to be used since tool calling
353354
// is currently only synchronous
354-
return Flux.defer(() -> {
355-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
355+
return Flux.deferContextual((ctx) -> {
356+
ToolExecutionResult toolExecutionResult;
357+
try {
358+
ToolCallReactiveContextHolder.setContext(ctx);
359+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
360+
} finally {
361+
ToolCallReactiveContextHolder.clearContext();
362+
}
356363
if (toolExecutionResult.returnDirect()) {
357364
// Return tool execution result directly to the client.
358365
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.springframework.ai.model.tool.ToolCallingManager;
6262
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
6363
import org.springframework.ai.model.tool.ToolExecutionResult;
64+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6465
import org.springframework.ai.openai.api.OpenAiApi;
6566
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
6667
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
@@ -363,10 +364,16 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
363364
// @formatter:off
364365
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
365366
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
366-
return Flux.defer(() -> {
367-
// FIXME: bounded elastic needs to be used since tool calling
368-
// is currently only synchronous
369-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
367+
// FIXME: bounded elastic needs to be used since tool calling
368+
// is currently only synchronous
369+
return Flux.deferContextual((ctx) -> {
370+
ToolExecutionResult toolExecutionResult;
371+
try {
372+
ToolCallReactiveContextHolder.setContext(ctx);
373+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
374+
} finally {
375+
ToolCallReactiveContextHolder.clearContext();
376+
}
370377
if (toolExecutionResult.returnDirect()) {
371378
// Return tool execution result directly to the client.
372379
return Flux.just(ChatResponse.builder().from(response)

0 commit comments

Comments
 (0)