Skip to content

Commit d759fb2

Browse files
ThomasVitalemarkpollack
authored andcommitted
Modular RAG: Orchestration and Post-Retrieval
Pre-Retrieval: * Consolidated naming and documentation Retrieval: * Consolidated naming and documentation * Introduced DocumentJoiner sub-module and CompositionDocumentJoiner operator Post-Retrieval: * Introduced main interfaces for sub-modules. Implementation waiting for missing features in Document APIs Orchestration: * Introduced QueryRouter sub-module and AllDocumentRetrieversQueryRouter operator Generation: * Consolidated naming and documentation Advisor: * Introduced BaseAdvisor to reduce boilerplate when implementing Advisors * Extended RetrievalAugmentationAdvisor to include the new sub-modules Relates to #gh-1603
1 parent c783c6b commit d759fb2

File tree

44 files changed

+1181
-243
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1181
-243
lines changed

spring-ai-core/pom.xml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@
9696
<artifactId>micrometer-core</artifactId>
9797
</dependency>
9898

99+
<dependency>
100+
<groupId>io.micrometer</groupId>
101+
<artifactId>context-propagation</artifactId>
102+
</dependency>
103+
99104
<dependency>
100105
<groupId>io.micrometer</groupId>
101106
<artifactId>micrometer-tracing-bridge-otel</artifactId>
@@ -195,4 +200,4 @@
195200
</profiles>
196201

197202

198-
</project>
203+
</project>

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

Lines changed: 127 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -16,71 +16,82 @@
1616

1717
package org.springframework.ai.chat.client.advisor;
1818

19-
import java.util.ArrayList;
2019
import java.util.Arrays;
2120
import java.util.HashMap;
2221
import java.util.List;
2322
import java.util.Map;
24-
import java.util.function.Predicate;
23+
import java.util.concurrent.CompletableFuture;
24+
import java.util.stream.Collectors;
2525

26-
import reactor.core.publisher.Flux;
27-
import reactor.core.publisher.Mono;
28-
import reactor.core.scheduler.Schedulers;
26+
import reactor.core.scheduler.Scheduler;
2927

3028
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
3129
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
32-
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
33-
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
34-
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
35-
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
30+
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
3631
import org.springframework.ai.chat.model.ChatResponse;
3732
import org.springframework.ai.chat.prompt.PromptTemplate;
3833
import org.springframework.ai.document.Document;
3934
import org.springframework.ai.rag.Query;
40-
import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer;
41-
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor;
42-
import org.springframework.ai.rag.augmentation.QueryAugmentor;
35+
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
36+
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
37+
import org.springframework.ai.rag.orchestration.routing.AllRetrieversQueryRouter;
38+
import org.springframework.ai.rag.orchestration.routing.QueryRouter;
39+
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
40+
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
41+
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
42+
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
4343
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
44+
import org.springframework.core.task.TaskExecutor;
45+
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
4446
import org.springframework.lang.Nullable;
47+
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
4548
import org.springframework.util.Assert;
46-
import org.springframework.util.StringUtils;
4749

4850
/**
4951
* Advisor that implements common Retrieval Augmented Generation (RAG) flows using the
5052
* building blocks defined in the {@link org.springframework.ai.rag} package and following
5153
* the Modular RAG Architecture.
52-
* <p>
53-
* It's the successor of the {@link QuestionAnswerAdvisor}.
5454
*
5555
* @author Christian Tzolov
5656
* @author Thomas Vitale
5757
* @since 1.0.0
5858
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
5959
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
6060
*/
61-
public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
61+
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
6262

6363
public static final String DOCUMENT_CONTEXT = "rag_document_context";
6464

6565
private final List<QueryTransformer> queryTransformers;
6666

67-
private final DocumentRetriever documentRetriever;
67+
@Nullable
68+
private final QueryExpander queryExpander;
6869

69-
private final QueryAugmentor queryAugmentor;
70+
private final QueryRouter queryRouter;
7071

71-
private final boolean protectFromBlocking;
72+
private final DocumentJoiner documentJoiner;
73+
74+
private final QueryAugmenter queryAugmenter;
75+
76+
private final TaskExecutor taskExecutor;
77+
78+
private final Scheduler scheduler;
7279

7380
private final int order;
7481

75-
public RetrievalAugmentationAdvisor(List<QueryTransformer> queryTransformers, DocumentRetriever documentRetriever,
76-
@Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
77-
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
82+
public RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
83+
@Nullable QueryExpander queryExpander, QueryRouter queryRouter, @Nullable DocumentJoiner documentJoiner,
84+
@Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor, @Nullable Scheduler scheduler,
85+
@Nullable Integer order) {
86+
Assert.notNull(queryRouter, "queryRouter cannot be null");
7887
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
79-
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
80-
this.queryTransformers = queryTransformers;
81-
this.documentRetriever = documentRetriever;
82-
this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build();
83-
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true;
88+
this.queryTransformers = queryTransformers != null ? queryTransformers : List.of();
89+
this.queryExpander = queryExpander;
90+
this.queryRouter = queryRouter;
91+
this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
92+
this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
93+
this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
94+
this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
8495
this.order = order != null ? order : 0;
8596
}
8697

@@ -89,41 +100,7 @@ public static Builder builder() {
89100
}
90101

91102
@Override
92-
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
93-
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
94-
Assert.notNull(chain, "chain cannot be null");
95-
96-
AdvisedRequest processedAdvisedRequest = before(advisedRequest);
97-
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
98-
return after(advisedResponse);
99-
}
100-
101-
@Override
102-
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
103-
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
104-
Assert.notNull(chain, "chain cannot be null");
105-
106-
// This can be executed by both blocking and non-blocking Threads
107-
// E.g. a command line or Tomcat blocking Thread implementation
108-
// or by a WebFlux dispatch in a non-blocking manner.
109-
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
110-
// @formatter:off
111-
Mono.just(advisedRequest)
112-
.publishOn(Schedulers.boundedElastic())
113-
.map(this::before)
114-
.flatMapMany(chain::nextAroundStream)
115-
: chain.nextAroundStream(before(advisedRequest));
116-
// @formatter:on
117-
118-
return advisedResponses.map(ar -> {
119-
if (onFinishReason().test(ar)) {
120-
ar = after(ar);
121-
}
122-
return ar;
123-
});
124-
}
125-
126-
private AdvisedRequest before(AdvisedRequest request) {
103+
public AdvisedRequest before(AdvisedRequest request) {
127104
Map<String, Object> context = new HashMap<>(request.adviseContext());
128105

129106
// 0. Create a query from the user text and parameters.
@@ -135,17 +112,47 @@ private AdvisedRequest before(AdvisedRequest request) {
135112
transformedQuery = queryTransformer.apply(transformedQuery);
136113
}
137114

138-
// 2. Retrieve similar documents for the original query.
139-
List<Document> documents = this.documentRetriever.retrieve(transformedQuery);
115+
// 2. Expand query into one or multiple queries.
116+
List<Query> expandedQueries = this.queryExpander != null ? this.queryExpander.expand(transformedQuery)
117+
: List.of(transformedQuery);
118+
119+
// 3. Get similar documents for each query.
120+
Map<Query, List<List<Document>>> documentsForQuery = expandedQueries.stream()
121+
.map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), this.taskExecutor))
122+
.toList()
123+
.stream()
124+
.map(CompletableFuture::join)
125+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
126+
127+
// 4. Combine documents retrieved based on multiple queries and from multiple data
128+
// sources.
129+
List<Document> documents = this.documentJoiner.join(documentsForQuery);
140130
context.put(DOCUMENT_CONTEXT, documents);
141131

142-
// 3. Augment user query with the document contextual data.
143-
Query augmentedQuery = this.queryAugmentor.augment(transformedQuery, documents);
132+
// 5. Augment user query with the document contextual data.
133+
Query augmentedQuery = this.queryAugmenter.augment(originalQuery, documents);
144134

135+
// 6. Update advised request with augmented prompt.
145136
return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build();
146137
}
147138

148-
private AdvisedResponse after(AdvisedResponse advisedResponse) {
139+
/**
140+
* Processes a single query by routing it to document retrievers and collecting
141+
* documents.
142+
*/
143+
private Map.Entry<Query, List<List<Document>>> getDocumentsForQuery(Query query) {
144+
List<DocumentRetriever> retrievers = this.queryRouter.route(query);
145+
List<List<Document>> documents = retrievers.stream()
146+
.map(retriever -> CompletableFuture.supplyAsync(() -> retriever.retrieve(query), this.taskExecutor))
147+
.toList()
148+
.stream()
149+
.map(CompletableFuture::join)
150+
.toList();
151+
return Map.entry(query, documents);
152+
}
153+
154+
@Override
155+
public AdvisedResponse after(AdvisedResponse advisedResponse) {
149156
ChatResponse.Builder chatResponseBuilder;
150157
if (advisedResponse.response() == null) {
151158
chatResponseBuilder = ChatResponse.builder();
@@ -157,66 +164,91 @@ private AdvisedResponse after(AdvisedResponse advisedResponse) {
157164
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
158165
}
159166

160-
private Predicate<AdvisedResponse> onFinishReason() {
161-
return advisedResponse -> {
162-
ChatResponse chatResponse = advisedResponse.response();
163-
return chatResponse != null && chatResponse.getResults() != null
164-
&& chatResponse.getResults()
165-
.stream()
166-
.anyMatch(result -> result != null && result.getMetadata() != null
167-
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
168-
};
169-
}
170-
171167
@Override
172-
public String getName() {
173-
return this.getClass().getSimpleName();
168+
public Scheduler getScheduler() {
169+
return this.scheduler;
174170
}
175171

176172
@Override
177173
public int getOrder() {
178174
return this.order;
179175
}
180176

177+
private static TaskExecutor buildDefaultTaskExecutor() {
178+
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
179+
taskExecutor.setThreadNamePrefix("ai-advisor-");
180+
taskExecutor.setCorePoolSize(4);
181+
taskExecutor.setMaxPoolSize(16);
182+
taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
183+
taskExecutor.initialize();
184+
return taskExecutor;
185+
}
186+
181187
public static final class Builder {
182188

183-
private final List<QueryTransformer> queryTransformers = new ArrayList<>();
189+
private List<QueryTransformer> queryTransformers;
184190

185-
private DocumentRetriever documentRetriever;
191+
private QueryExpander queryExpander;
186192

187-
private QueryAugmentor queryAugmentor;
193+
private QueryRouter queryRouter;
188194

189-
private Boolean protectFromBlocking;
195+
private DocumentJoiner documentJoiner;
196+
197+
private QueryAugmenter queryAugmenter;
198+
199+
private TaskExecutor taskExecutor;
200+
201+
private Scheduler scheduler;
190202

191203
private Integer order;
192204

193205
private Builder() {
194206
}
195207

196208
public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
197-
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
198-
this.queryTransformers.addAll(queryTransformers);
209+
this.queryTransformers = queryTransformers;
199210
return this;
200211
}
201212

202213
public Builder queryTransformers(QueryTransformer... queryTransformers) {
203-
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
204-
this.queryTransformers.addAll(Arrays.asList(queryTransformers));
214+
this.queryTransformers = Arrays.asList(queryTransformers);
215+
return this;
216+
}
217+
218+
public Builder queryExpander(QueryExpander queryExpander) {
219+
this.queryExpander = queryExpander;
220+
return this;
221+
}
222+
223+
public Builder queryRouter(QueryRouter queryRouter) {
224+
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter");
225+
this.queryRouter = queryRouter;
205226
return this;
206227
}
207228

208229
public Builder documentRetriever(DocumentRetriever documentRetriever) {
209-
this.documentRetriever = documentRetriever;
230+
Assert.isNull(this.queryRouter, "Cannot set both documentRetriever and queryRouter");
231+
this.queryRouter = AllRetrieversQueryRouter.builder().documentRetrievers(documentRetriever).build();
232+
return this;
233+
}
234+
235+
public Builder documentJoiner(DocumentJoiner documentJoiner) {
236+
this.documentJoiner = documentJoiner;
237+
return this;
238+
}
239+
240+
public Builder queryAugmenter(QueryAugmenter queryAugmenter) {
241+
this.queryAugmenter = queryAugmenter;
210242
return this;
211243
}
212244

213-
public Builder queryAugmentor(QueryAugmentor queryAugmentor) {
214-
this.queryAugmentor = queryAugmentor;
245+
public Builder taskExecutor(TaskExecutor taskExecutor) {
246+
this.taskExecutor = taskExecutor;
215247
return this;
216248
}
217249

218-
public Builder protectFromBlocking(Boolean protectFromBlocking) {
219-
this.protectFromBlocking = protectFromBlocking;
250+
public Builder scheduler(Scheduler scheduler) {
251+
this.scheduler = scheduler;
220252
return this;
221253
}
222254

@@ -226,8 +258,8 @@ public Builder order(Integer order) {
226258
}
227259

228260
public RetrievalAugmentationAdvisor build() {
229-
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.documentRetriever, this.queryAugmentor,
230-
this.protectFromBlocking, this.order);
261+
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander, this.queryRouter,
262+
this.documentJoiner, this.queryAugmenter, this.taskExecutor, this.scheduler, this.order);
231263
}
232264

233265
}

0 commit comments

Comments
 (0)