From ba17f202de15f86541b362cf551259d0e746dc47 Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Fri, 14 Mar 2025 07:42:23 +0100 Subject: [PATCH] Support userText rendering strategy in RetrievalAugmentationAdvisor The input userText to the advisor might contain the templating special characters for purposes other than templating (e.g. when including source code). A new UserTextProcessor functional interface allows customizing how the userText is processed when it is passed as input to the RetrievalAugmentationAdvisor. By default, the PromptTemplateUserTextProcessor implementation is used. If you want to disable the rendering step altogether, you can use the NoOpUserTextProcessor implementation. Signed-off-by: Thomas Vitale --- .../client/advisor/NoOpUserTextProcessor.java | 18 +++++ .../PromptTemplateUserTextProcessor.java | 42 +++++++++++ .../advisor/RetrievalAugmentationAdvisor.java | 29 ++++++-- .../client/advisor/UserTextProcessor.java | 38 ++++++++++ .../advisor/NoOpUserTextProcessorTests.java | 22 ++++++ .../PromptTemplateUserTextProcessorTests.java | 73 +++++++++++++++++++ .../api/retrieval-augmented-generation.adoc | 18 +++++ .../RetrievalAugmentationAdvisorIT.java | 31 +++++++- 8 files changed, 260 insertions(+), 11 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessor.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessor.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/UserTextProcessor.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessorTests.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessorTests.java diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessor.java 
b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessor.java new file mode 100644 index 00000000000..07168c35b62 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessor.java @@ -0,0 +1,18 @@ +package org.springframework.ai.chat.client.advisor; + +import java.util.Map; + +/** + * A {@link UserTextProcessor} that returns the user text as is. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class NoOpUserTextProcessor implements UserTextProcessor { + + @Override + public String process(String userText, Map userParams) { + return userText; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessor.java new file mode 100644 index 00000000000..9764f044cf1 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessor.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.chat.client.advisor; + +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.util.Assert; + +import java.util.Map; + +/** + * Processes the advised user text with the given user parameters using a + * {@link PromptTemplate}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class PromptTemplateUserTextProcessor implements UserTextProcessor { + + @Override + public String process(String userText, Map userParams) { + Assert.hasText(userText, "userText cannot be null or empty"); + Assert.notNull(userParams, "userParams cannot be null"); + Assert.noNullElements(userParams.keySet(), "userParams keys cannot be null"); + + return new PromptTemplate(userText, userParams).render(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java index f2dfbde24aa..eab6d8507f5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java @@ -29,7 +29,6 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.BaseAdvisor; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.rag.Query; import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter; @@ -61,6 +60,8 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor { public static final String DOCUMENT_CONTEXT = "rag_document_context"; + private final UserTextProcessor userTextProcessor; + private final List queryTransformers; @Nullable @@ -78,10 +79,12 @@ public 
final class RetrievalAugmentationAdvisor implements BaseAdvisor { private final int order; - public RetrievalAugmentationAdvisor(@Nullable List queryTransformers, - @Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever, - @Nullable DocumentJoiner documentJoiner, @Nullable QueryAugmenter queryAugmenter, - @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, @Nullable Integer order) { + public RetrievalAugmentationAdvisor(@Nullable UserTextProcessor userTextProcessor, + @Nullable List queryTransformers, @Nullable QueryExpander queryExpander, + DocumentRetriever documentRetriever, @Nullable DocumentJoiner documentJoiner, + @Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler, + @Nullable Integer order) { + this.userTextProcessor = userTextProcessor != null ? userTextProcessor : new PromptTemplateUserTextProcessor(); Assert.notNull(documentRetriever, "documentRetriever cannot be null"); Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements"); this.queryTransformers = queryTransformers != null ? queryTransformers : List.of(); @@ -102,9 +105,11 @@ public static Builder builder() { public AdvisedRequest before(AdvisedRequest request) { Map context = new HashMap<>(request.adviseContext()); + String processedUserText = this.userTextProcessor.apply(request.userText(), request.userParams()); + // 0. Create a query from the user text, parameters, and conversation history. 
Query originalQuery = Query.builder() - .text(new PromptTemplate(request.userText(), request.userParams()).render()) + .text(processedUserText) .history(request.messages()) .context(context) .build(); @@ -183,6 +188,8 @@ private static TaskExecutor buildDefaultTaskExecutor() { public static final class Builder { + private UserTextProcessor userTextProcessor; + private List queryTransformers; private QueryExpander queryExpander; @@ -202,6 +209,11 @@ public static final class Builder { private Builder() { } + public Builder userTextProcessor(UserTextProcessor userTextProcessor) { + this.userTextProcessor = userTextProcessor; + return this; + } + public Builder queryTransformers(List queryTransformers) { this.queryTransformers = queryTransformers; return this; @@ -248,8 +260,9 @@ public Builder order(Integer order) { } public RetrievalAugmentationAdvisor build() { - return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.documentRetriever, - this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order); + return new RetrievalAugmentationAdvisor(this.userTextProcessor, this.queryTransformers, this.queryExpander, + this.documentRetriever, this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, + this.order); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/UserTextProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/UserTextProcessor.java new file mode 100644 index 00000000000..92d5fba303a --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/UserTextProcessor.java @@ -0,0 +1,38 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Processes the advised user text with the given user parameters. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +@FunctionalInterface +public interface UserTextProcessor extends BiFunction, String> { + + String process(String userText, Map userParams); + + @Override + default String apply(String userText, Map userParams) { + return process(userText, userParams); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessorTests.java new file mode 100644 index 00000000000..4eb6100f2c9 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/NoOpUserTextProcessorTests.java @@ -0,0 +1,22 @@ +package org.springframework.ai.chat.client.advisor; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Unit tests for {@link NoOpUserTextProcessor}. 
+ * + * @author Thomas Vitale + */ +class NoOpUserTextProcessorTests { + + @Test + void process() { + NoOpUserTextProcessor processor = new NoOpUserTextProcessor(); + String userText = "Hello, {World}!"; + String processedText = processor.process(userText, null); + assertEquals(userText, processedText); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessorTests.java new file mode 100644 index 00000000000..161b14be649 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/PromptTemplateUserTextProcessorTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Unit tests for {@link PromptTemplateUserTextProcessor}. 
+ * + * @author Thomas Vitale + */ +class PromptTemplateUserTextProcessorTests { + + @ParameterizedTest + @NullAndEmptySource + void processWithNullOrEmptyUserText(String userText) { + PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor(); + Map userParams = Map.of("name", "William"); + assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams)) + .withMessage("userText cannot be null or empty"); + } + + @Test + void processWithNullUserParams() { + PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor(); + String userText = "Hello, {name}!"; + Map userParams = null; + assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams)) + .withMessage("userParams cannot be null"); + } + + @Test + void processWithNullUserParamsKeys() { + PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor(); + String userText = "Hello, {name}!"; + Map userParams = new HashMap<>(); + userParams.put(null, "William"); + assertThatIllegalArgumentException().isThrownBy(() -> processor.process(userText, userParams)) + .withMessage("userParams keys cannot be null"); + } + + @Test + void process() { + PromptTemplateUserTextProcessor processor = new PromptTemplateUserTextProcessor(); + String userText = "Hello, {name}!"; + Map userParams = Map.of("name", "William"); + String processedText = processor.process(userText, userParams); + assertThat(processedText).isEqualTo("Hello, William!"); + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc index 44d10874283..d874f157117 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/retrieval-augmented-generation.adoc @@ -138,6 +138,24 @@ String answer = 
chatClient.prompt() See xref:api/retrieval-augmented-generation.adoc#_vectorstoredocumentretriever for more information. +By default, the `RetrievalAugmentationAdvisor` processes the input user text with a `PromptTemplate`, ensuring that any template placeholder is correctly rendered before using the text for the retrieval process. +If you want to customize the processing logic, you can provide a custom `UserTextProcessor` to the advisor, either as a lambda or a class. +For example, in case you want to skip the rendering step, you can provide a `NoOpUserTextProcessor`. That is useful if you're planning to use the templating special characters in the user text for other purposes. + +[source,java] +---- +Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore().build()) + .userTextProcessor(new NoOpUserTextProcessor()) + .build(); + +String answer = chatClient.prompt() + .advisors(retrievalAugmentationAdvisor) + .user(question) + .call() + .content(); +---- + ===== Advanced RAG [source,java] diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java index e7170c68b75..db090141710 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java @@ -16,16 +16,14 @@ package org.springframework.ai.integration.tests.client.advisor; -import java.util.List; - import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; - 
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.NoOpUserTextProcessor; import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; import org.springframework.ai.chat.memory.InMemoryChatMemory; import org.springframework.ai.chat.model.ChatResponse; @@ -49,6 +47,8 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.Resource; +import java.util.List; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -131,6 +131,31 @@ void ragWithRequestFilter() { .isNull(); } + @Test + void ragWithCustomUserTextProcessor() { + String question = "Where does the adventure of {Anacletus} and {Birba} take place?"; + + RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build()) + .userTextProcessor(new NoOpUserTextProcessor()) + .build(); + + ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel) + .build() + .prompt(question) + .advisors(ragAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + + String response = chatResponse.getResult().getOutput().getText(); + System.out.println(response); + assertThat(response).containsIgnoringCase("Highlands"); + + evaluateRelevancy(question, chatResponse); + } + @Test void ragWithCompression() { MessageChatMemoryAdvisor memoryAdvisor = MessageChatMemoryAdvisor.builder(new InMemoryChatMemory()).build();