Skip to content

Commit 90cab21

Browse files
ThomasVitale authored and markpollack committed
Inner advisor should handle output format instructions
Output format instructions should not be included until the very last advisor runs, otherwise there's a risk of templating failure if more than one advisor tries to render the prompt template. This change guarantees the output format instructions are always included right before calling the chat model, without the risk of previous advisors interfering with it.

Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent f4d4ddd commit 90cab21

File tree

5 files changed

+157
-9
lines changed

5 files changed

+157
-9
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@
2121
*
2222
* @author Thomas Vitale
2323
* @since 1.0.0
24-
* @deprecated only introduced to smooth the transition to the new APIs and ensure
25-
* backward compatibility
2624
*/
27-
@Deprecated
2825
public enum ChatClientAttributes {
2926

3027
//@formatter:off
@@ -33,7 +30,6 @@ public enum ChatClientAttributes {
3330
ADVISORS("spring.ai.chat.client.advisors"),
3431
@Deprecated // Only for backward compatibility until the next release.
3532
CHAT_MODEL("spring.ai.chat.client.model"),
36-
@Deprecated // Only for backward compatibility until the next release.
3733
OUTPUT_FORMAT("spring.ai.chat.client.output.format"),
3834
@Deprecated // Only for backward compatibility until the next release.
3935
USER_PARAMS("spring.ai.chat.client.user.params"),

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -496,11 +496,13 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c
496496

497497
private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest,
498498
@Nullable String outputFormat) {
499-
ChatClientRequest formattedChatClientRequest = StringUtils.hasText(outputFormat)
500-
? augmentPromptWithFormatInstructions(chatClientRequest, outputFormat) : chatClientRequest;
499+
500+
if (outputFormat != null) {
501+
chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), outputFormat);
502+
}
501503

502504
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
503-
.request(formattedChatClientRequest)
505+
.request(chatClientRequest)
504506
.advisors(advisorChain.getCallAdvisors())
505507
.stream(false)
506508
.withFormat(outputFormat)
@@ -510,7 +512,7 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c
510512
DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry);
511513
var chatClientResponse = observation.observe(() -> {
512514
// Apply the advisor chain that terminates with the ChatModelCallAdvisor.
513-
return advisorChain.nextCall(formattedChatClientRequest);
515+
return advisorChain.nextCall(chatClientRequest);
514516
});
515517
return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build();
516518
}

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616

1717
package org.springframework.ai.chat.client.advisor;
1818

19+
import org.springframework.ai.chat.client.ChatClientAttributes;
1920
import org.springframework.ai.chat.client.ChatClientRequest;
2021
import org.springframework.ai.chat.client.ChatClientResponse;
2122
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
2223
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
2324
import org.springframework.ai.chat.model.ChatModel;
2425
import org.springframework.ai.chat.model.ChatResponse;
26+
import org.springframework.ai.chat.prompt.Prompt;
2527
import org.springframework.core.Ordered;
2628
import org.springframework.util.Assert;
29+
import org.springframework.util.StringUtils;
2730

2831
import java.util.Map;
2932

@@ -46,9 +49,29 @@ private ChatModelCallAdvisor(ChatModel chatModel) {
4649
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) {
4750
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
4851

49-
ChatResponse chatResponse = chatModel.call(chatClientRequest.prompt());
52+
ChatClientRequest formattedChatClientRequest = augmentWithFormatInstructions(chatClientRequest);
53+
54+
ChatResponse chatResponse = chatModel.call(formattedChatClientRequest.prompt());
5055
return ChatClientResponse.builder()
5156
.chatResponse(chatResponse)
57+
.context(Map.copyOf(formattedChatClientRequest.context()))
58+
.build();
59+
}
60+
61+
private static ChatClientRequest augmentWithFormatInstructions(ChatClientRequest chatClientRequest) {
62+
String outputFormat = (String) chatClientRequest.context().get(ChatClientAttributes.OUTPUT_FORMAT.getKey());
63+
64+
if (!StringUtils.hasText(outputFormat)) {
65+
return chatClientRequest;
66+
}
67+
68+
Prompt augmentedPrompt = chatClientRequest.prompt()
69+
.augmentUserMessage(userMessage -> userMessage.mutate()
70+
.text(userMessage.getText() + System.lineSeparator() + outputFormat)
71+
.build());
72+
73+
return ChatClientRequest.builder()
74+
.prompt(augmentedPrompt)
5275
.context(Map.copyOf(chatClientRequest.context()))
5376
.build();
5477
}

spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/QuestionAnswerAdvisorIT.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,28 @@ void qaCustomPromptTemplate() {
163163
evaluateRelevancy(question, chatResponse);
164164
}
165165

166+
@Test
167+
void qaOutputConverter() {
168+
String question = "Where does the adventure of Anacletus and Birba take place?";
169+
170+
QuestionAnswerAdvisor qaAdvisor = QuestionAnswerAdvisor.builder(this.pgVectorStore).build();
171+
172+
Answer answer = ChatClient.builder(this.openAiChatModel)
173+
.build()
174+
.prompt(question)
175+
.advisors(qaAdvisor)
176+
.call()
177+
.entity(Answer.class);
178+
179+
assertThat(answer).isNotNull();
180+
181+
System.out.println(answer);
182+
assertThat(answer.content()).containsIgnoringCase("Highlands");
183+
}
184+
185+
private record Answer(String content) {
186+
}
187+
166188
private void evaluateRelevancy(String question, ChatResponse chatResponse) {
167189
EvaluationRequest evaluationRequest = new EvaluationRequest(question,
168190
chatResponse.getMetadata().get(QuestionAnswerAdvisor.RETRIEVED_DOCUMENTS),
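
spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/QuestionAnswerAdvisorStreamIT.java

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change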
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.integration.tests.client.advisor;
18+
19+
import org.junit.jupiter.api.AfterEach;
20+
import org.junit.jupiter.api.BeforeEach;
21+
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
23+
import org.springframework.ai.chat.client.ChatClient;
24+
import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor;
25+
import org.springframework.ai.chat.model.ChatResponse;
26+
import org.springframework.ai.converter.BeanOutputConverter;
27+
import org.springframework.ai.document.Document;
28+
import org.springframework.ai.document.DocumentReader;
29+
import org.springframework.ai.integration.tests.TestApplication;
30+
import org.springframework.ai.openai.OpenAiChatModel;
31+
import org.springframework.ai.openai.OpenAiChatOptions;
32+
import org.springframework.ai.reader.markdown.MarkdownDocumentReader;
33+
import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig;
34+
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
35+
import org.springframework.beans.factory.annotation.Autowired;
36+
import org.springframework.beans.factory.annotation.Value;
37+
import org.springframework.boot.test.context.SpringBootTest;
38+
import org.springframework.core.io.Resource;
39+
import reactor.core.publisher.Flux;
40+
41+
import java.util.List;
42+
import java.util.stream.Collectors;
43+
44+
import static org.assertj.core.api.Assertions.assertThat;
45+
46+
/**
47+
* Integration tests for {@link QuestionAnswerAdvisor} with streaming responses.
48+
*
49+
*/
50+
@SpringBootTest(classes = TestApplication.class)
51+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
52+
public class QuestionAnswerAdvisorStreamIT {
53+
54+
private List<Document> knowledgeBaseDocuments;
55+
56+
@Autowired
57+
OpenAiChatModel openAiChatModel;
58+
59+
@Autowired
60+
PgVectorStore pgVectorStore;
61+
62+
@Value("${classpath:documents/knowledge-base.md}")
63+
Resource knowledgeBaseResource;
64+
65+
@BeforeEach
66+
void setUp() {
67+
DocumentReader markdownReader = new MarkdownDocumentReader(this.knowledgeBaseResource,
68+
MarkdownDocumentReaderConfig.defaultConfig());
69+
this.knowledgeBaseDocuments = markdownReader.read();
70+
this.pgVectorStore.add(this.knowledgeBaseDocuments);
71+
}
72+
73+
@AfterEach
74+
void tearDown() {
75+
this.pgVectorStore.delete(this.knowledgeBaseDocuments.stream().map(Document::getId).toList());
76+
}
77+
78+
@Test
79+
void qaStreamBasic() {
80+
String question = "Where does the adventure of Anacletus and Birba take place?";
81+
82+
QuestionAnswerAdvisor qaAdvisor = QuestionAnswerAdvisor.builder(this.pgVectorStore).build();
83+
84+
// Test streaming with the QuestionAnswerAdvisor
85+
// This verifies the fix works in the streaming context too
86+
Flux<String> responseFlux = ChatClient.builder(this.openAiChatModel)
87+
.build()
88+
.prompt(question)
89+
.advisors(qaAdvisor)
90+
.options(OpenAiChatOptions.builder().streamUsage(true).build())
91+
.stream()
92+
.content();
93+
94+
// Collect the streamed responses
95+
String response = responseFlux.collectList().block().stream().collect(Collectors.joining());
96+
97+
// Verify the response contains the expected content
98+
assertThat(response).isNotEmpty();
99+
assertThat(response).containsIgnoringCase("Highlands");
100+
}
101+
102+
private record Answer(String content) {
103+
}
104+
105+
}

0 commit comments

Comments
 (0)