16
16
17
17
package org .springframework .ai .chat .client .advisor ;
18
18
19
- import java .util .ArrayList ;
20
19
import java .util .Arrays ;
21
20
import java .util .HashMap ;
22
21
import java .util .List ;
23
22
import java .util .Map ;
24
- import java .util .function .Predicate ;
23
+ import java .util .concurrent .CompletableFuture ;
24
+ import java .util .stream .Collectors ;
25
25
26
- import reactor .core .publisher .Flux ;
27
- import reactor .core .publisher .Mono ;
28
- import reactor .core .scheduler .Schedulers ;
26
+ import reactor .core .scheduler .Scheduler ;
29
27
30
28
import org .springframework .ai .chat .client .advisor .api .AdvisedRequest ;
31
29
import org .springframework .ai .chat .client .advisor .api .AdvisedResponse ;
32
- import org .springframework .ai .chat .client .advisor .api .CallAroundAdvisor ;
33
- import org .springframework .ai .chat .client .advisor .api .CallAroundAdvisorChain ;
34
- import org .springframework .ai .chat .client .advisor .api .StreamAroundAdvisor ;
35
- import org .springframework .ai .chat .client .advisor .api .StreamAroundAdvisorChain ;
30
+ import org .springframework .ai .chat .client .advisor .api .BaseAdvisor ;
36
31
import org .springframework .ai .chat .model .ChatResponse ;
37
32
import org .springframework .ai .chat .prompt .PromptTemplate ;
38
33
import org .springframework .ai .document .Document ;
39
34
import org .springframework .ai .rag .Query ;
40
- import org .springframework .ai .rag .analysis .query .transformation .QueryTransformer ;
41
- import org .springframework .ai .rag .augmentation .ContextualQueryAugmentor ;
42
- import org .springframework .ai .rag .augmentation .QueryAugmentor ;
35
+ import org .springframework .ai .rag .generation .augmentation .ContextualQueryAugmenter ;
36
+ import org .springframework .ai .rag .generation .augmentation .QueryAugmenter ;
37
+ import org .springframework .ai .rag .orchestration .routing .AllRetrieversQueryRouter ;
38
+ import org .springframework .ai .rag .orchestration .routing .QueryRouter ;
39
+ import org .springframework .ai .rag .preretrieval .query .expansion .QueryExpander ;
40
+ import org .springframework .ai .rag .preretrieval .query .transformation .QueryTransformer ;
41
+ import org .springframework .ai .rag .retrieval .join .ConcatenationDocumentJoiner ;
42
+ import org .springframework .ai .rag .retrieval .join .DocumentJoiner ;
43
43
import org .springframework .ai .rag .retrieval .search .DocumentRetriever ;
44
+ import org .springframework .core .task .TaskExecutor ;
45
+ import org .springframework .core .task .support .ContextPropagatingTaskDecorator ;
44
46
import org .springframework .lang .Nullable ;
47
+ import org .springframework .scheduling .concurrent .ThreadPoolTaskExecutor ;
45
48
import org .springframework .util .Assert ;
46
- import org .springframework .util .StringUtils ;
47
49
48
50
/**
49
51
* Advisor that implements common Retrieval Augmented Generation (RAG) flows using the
50
52
* building blocks defined in the {@link org.springframework.ai.rag} package and following
51
53
* the Modular RAG Architecture.
52
- * <p>
53
- * It's the successor of the {@link QuestionAnswerAdvisor}.
54
54
*
55
55
* @author Christian Tzolov
56
56
* @author Thomas Vitale
57
57
* @since 1.0.0
58
58
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
59
59
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
60
60
*/
61
- public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor , StreamAroundAdvisor {
61
+ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
62
62
63
63
public static final String DOCUMENT_CONTEXT = "rag_document_context" ;
64
64
65
65
private final List <QueryTransformer > queryTransformers ;
66
66
67
- private final DocumentRetriever documentRetriever ;
67
+ @ Nullable
68
+ private final QueryExpander queryExpander ;
68
69
69
- private final QueryAugmentor queryAugmentor ;
70
+ private final QueryRouter queryRouter ;
70
71
71
- private final boolean protectFromBlocking ;
72
+ private final DocumentJoiner documentJoiner ;
73
+
74
+ private final QueryAugmenter queryAugmenter ;
75
+
76
+ private final TaskExecutor taskExecutor ;
77
+
78
+ private final Scheduler scheduler ;
72
79
73
80
private final int order ;
74
81
75
- public RetrievalAugmentationAdvisor (List <QueryTransformer > queryTransformers , DocumentRetriever documentRetriever ,
76
- @ Nullable QueryAugmentor queryAugmentor , @ Nullable Boolean protectFromBlocking , @ Nullable Integer order ) {
77
- Assert .notNull (queryTransformers , "queryTransformers cannot be null" );
82
+ public RetrievalAugmentationAdvisor (@ Nullable List <QueryTransformer > queryTransformers ,
83
+ @ Nullable QueryExpander queryExpander , QueryRouter queryRouter , @ Nullable DocumentJoiner documentJoiner ,
84
+ @ Nullable QueryAugmenter queryAugmenter , @ Nullable TaskExecutor taskExecutor , @ Nullable Scheduler scheduler ,
85
+ @ Nullable Integer order ) {
86
+ Assert .notNull (queryRouter , "queryRouter cannot be null" );
78
87
Assert .noNullElements (queryTransformers , "queryTransformers cannot contain null elements" );
79
- Assert .notNull (documentRetriever , "documentRetriever cannot be null" );
80
- this .queryTransformers = queryTransformers ;
81
- this .documentRetriever = documentRetriever ;
82
- this .queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor .builder ().build ();
83
- this .protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true ;
88
+ this .queryTransformers = queryTransformers != null ? queryTransformers : List .of ();
89
+ this .queryExpander = queryExpander ;
90
+ this .queryRouter = queryRouter ;
91
+ this .documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner ();
92
+ this .queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter .builder ().build ();
93
+ this .taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor ();
94
+ this .scheduler = scheduler != null ? scheduler : BaseAdvisor .DEFAULT_SCHEDULER ;
84
95
this .order = order != null ? order : 0 ;
85
96
}
86
97
@@ -89,41 +100,7 @@ public static Builder builder() {
89
100
}
90
101
91
102
@ Override
92
- public AdvisedResponse aroundCall (AdvisedRequest advisedRequest , CallAroundAdvisorChain chain ) {
93
- Assert .notNull (advisedRequest , "advisedRequest cannot be null" );
94
- Assert .notNull (chain , "chain cannot be null" );
95
-
96
- AdvisedRequest processedAdvisedRequest = before (advisedRequest );
97
- AdvisedResponse advisedResponse = chain .nextAroundCall (processedAdvisedRequest );
98
- return after (advisedResponse );
99
- }
100
-
101
- @ Override
102
- public Flux <AdvisedResponse > aroundStream (AdvisedRequest advisedRequest , StreamAroundAdvisorChain chain ) {
103
- Assert .notNull (advisedRequest , "advisedRequest cannot be null" );
104
- Assert .notNull (chain , "chain cannot be null" );
105
-
106
- // This can be executed by both blocking and non-blocking Threads
107
- // E.g. a command line or Tomcat blocking Thread implementation
108
- // or by a WebFlux dispatch in a non-blocking manner.
109
- Flux <AdvisedResponse > advisedResponses = (this .protectFromBlocking ) ?
110
- // @formatter:off
111
- Mono .just (advisedRequest )
112
- .publishOn (Schedulers .boundedElastic ())
113
- .map (this ::before )
114
- .flatMapMany (chain ::nextAroundStream )
115
- : chain .nextAroundStream (before (advisedRequest ));
116
- // @formatter:on
117
-
118
- return advisedResponses .map (ar -> {
119
- if (onFinishReason ().test (ar )) {
120
- ar = after (ar );
121
- }
122
- return ar ;
123
- });
124
- }
125
-
126
- private AdvisedRequest before (AdvisedRequest request ) {
103
+ public AdvisedRequest before (AdvisedRequest request ) {
127
104
Map <String , Object > context = new HashMap <>(request .adviseContext ());
128
105
129
106
// 0. Create a query from the user text and parameters.
@@ -135,17 +112,47 @@ private AdvisedRequest before(AdvisedRequest request) {
135
112
transformedQuery = queryTransformer .apply (transformedQuery );
136
113
}
137
114
138
- // 2. Retrieve similar documents for the original query.
139
- List <Document > documents = this .documentRetriever .retrieve (transformedQuery );
115
+ // 2. Expand query into one or multiple queries.
116
+ List <Query > expandedQueries = this .queryExpander != null ? this .queryExpander .expand (transformedQuery )
117
+ : List .of (transformedQuery );
118
+
119
+ // 3. Get similar documents for each query.
120
+ Map <Query , List <List <Document >>> documentsForQuery = expandedQueries .stream ()
121
+ .map (query -> CompletableFuture .supplyAsync (() -> getDocumentsForQuery (query ), this .taskExecutor ))
122
+ .toList ()
123
+ .stream ()
124
+ .map (CompletableFuture ::join )
125
+ .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
126
+
127
+ // 4. Combine documents retrieved based on multiple queries and from multiple data
128
+ // sources.
129
+ List <Document > documents = this .documentJoiner .join (documentsForQuery );
140
130
context .put (DOCUMENT_CONTEXT , documents );
141
131
142
- // 3 . Augment user query with the document contextual data.
143
- Query augmentedQuery = this .queryAugmentor .augment (transformedQuery , documents );
132
+ // 5 . Augment user query with the document contextual data.
133
+ Query augmentedQuery = this .queryAugmenter .augment (originalQuery , documents );
144
134
135
+ // 6. Update advised request with augmented prompt.
145
136
return AdvisedRequest .from (request ).withUserText (augmentedQuery .text ()).withAdviseContext (context ).build ();
146
137
}
147
138
148
- private AdvisedResponse after (AdvisedResponse advisedResponse ) {
139
+ /**
140
+ * Processes a single query by routing it to document retrievers and collecting
141
+ * documents.
142
+ */
143
+ private Map .Entry <Query , List <List <Document >>> getDocumentsForQuery (Query query ) {
144
+ List <DocumentRetriever > retrievers = this .queryRouter .route (query );
145
+ List <List <Document >> documents = retrievers .stream ()
146
+ .map (retriever -> CompletableFuture .supplyAsync (() -> retriever .retrieve (query ), this .taskExecutor ))
147
+ .toList ()
148
+ .stream ()
149
+ .map (CompletableFuture ::join )
150
+ .toList ();
151
+ return Map .entry (query , documents );
152
+ }
153
+
154
+ @ Override
155
+ public AdvisedResponse after (AdvisedResponse advisedResponse ) {
149
156
ChatResponse .Builder chatResponseBuilder ;
150
157
if (advisedResponse .response () == null ) {
151
158
chatResponseBuilder = ChatResponse .builder ();
@@ -157,66 +164,91 @@ private AdvisedResponse after(AdvisedResponse advisedResponse) {
157
164
return new AdvisedResponse (chatResponseBuilder .build (), advisedResponse .adviseContext ());
158
165
}
159
166
160
- private Predicate <AdvisedResponse > onFinishReason () {
161
- return advisedResponse -> {
162
- ChatResponse chatResponse = advisedResponse .response ();
163
- return chatResponse != null && chatResponse .getResults () != null
164
- && chatResponse .getResults ()
165
- .stream ()
166
- .anyMatch (result -> result != null && result .getMetadata () != null
167
- && StringUtils .hasText (result .getMetadata ().getFinishReason ()));
168
- };
169
- }
170
-
171
167
@ Override
172
- public String getName () {
173
- return this .getClass (). getSimpleName () ;
168
+ public Scheduler getScheduler () {
169
+ return this .scheduler ;
174
170
}
175
171
176
172
@ Override
177
173
public int getOrder () {
178
174
return this .order ;
179
175
}
180
176
177
+ private static TaskExecutor buildDefaultTaskExecutor () {
178
+ ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor ();
179
+ taskExecutor .setThreadNamePrefix ("ai-advisor-" );
180
+ taskExecutor .setCorePoolSize (4 );
181
+ taskExecutor .setMaxPoolSize (16 );
182
+ taskExecutor .setTaskDecorator (new ContextPropagatingTaskDecorator ());
183
+ taskExecutor .initialize ();
184
+ return taskExecutor ;
185
+ }
186
+
181
187
public static final class Builder {
182
188
183
- private final List <QueryTransformer > queryTransformers = new ArrayList <>() ;
189
+ private List <QueryTransformer > queryTransformers ;
184
190
185
- private DocumentRetriever documentRetriever ;
191
+ private QueryExpander queryExpander ;
186
192
187
- private QueryAugmentor queryAugmentor ;
193
+ private QueryRouter queryRouter ;
188
194
189
- private Boolean protectFromBlocking ;
195
+ private DocumentJoiner documentJoiner ;
196
+
197
+ private QueryAugmenter queryAugmenter ;
198
+
199
+ private TaskExecutor taskExecutor ;
200
+
201
+ private Scheduler scheduler ;
190
202
191
203
private Integer order ;
192
204
193
205
private Builder () {
194
206
}
195
207
196
208
public Builder queryTransformers (List <QueryTransformer > queryTransformers ) {
197
- Assert .notNull (queryTransformers , "queryTransformers cannot be null" );
198
- this .queryTransformers .addAll (queryTransformers );
209
+ this .queryTransformers = queryTransformers ;
199
210
return this ;
200
211
}
201
212
202
213
public Builder queryTransformers (QueryTransformer ... queryTransformers ) {
203
- Assert .notNull (queryTransformers , "queryTransformers cannot be null" );
204
- this .queryTransformers .addAll (Arrays .asList (queryTransformers ));
214
+ this .queryTransformers = Arrays .asList (queryTransformers );
215
+ return this ;
216
+ }
217
+
218
+ public Builder queryExpander (QueryExpander queryExpander ) {
219
+ this .queryExpander = queryExpander ;
220
+ return this ;
221
+ }
222
+
223
+ public Builder queryRouter (QueryRouter queryRouter ) {
224
+ Assert .isNull (this .queryRouter , "Cannot set both documentRetriever and queryRouter" );
225
+ this .queryRouter = queryRouter ;
205
226
return this ;
206
227
}
207
228
208
229
public Builder documentRetriever (DocumentRetriever documentRetriever ) {
209
- this .documentRetriever = documentRetriever ;
230
+ Assert .isNull (this .queryRouter , "Cannot set both documentRetriever and queryRouter" );
231
+ this .queryRouter = AllRetrieversQueryRouter .builder ().documentRetrievers (documentRetriever ).build ();
232
+ return this ;
233
+ }
234
+
235
+ public Builder documentJoiner (DocumentJoiner documentJoiner ) {
236
+ this .documentJoiner = documentJoiner ;
237
+ return this ;
238
+ }
239
+
240
+ public Builder queryAugmenter (QueryAugmenter queryAugmenter ) {
241
+ this .queryAugmenter = queryAugmenter ;
210
242
return this ;
211
243
}
212
244
213
- public Builder queryAugmentor ( QueryAugmentor queryAugmentor ) {
214
- this .queryAugmentor = queryAugmentor ;
245
+ public Builder taskExecutor ( TaskExecutor taskExecutor ) {
246
+ this .taskExecutor = taskExecutor ;
215
247
return this ;
216
248
}
217
249
218
- public Builder protectFromBlocking ( Boolean protectFromBlocking ) {
219
- this .protectFromBlocking = protectFromBlocking ;
250
+ public Builder scheduler ( Scheduler scheduler ) {
251
+ this .scheduler = scheduler ;
220
252
return this ;
221
253
}
222
254
@@ -226,8 +258,8 @@ public Builder order(Integer order) {
226
258
}
227
259
228
260
public RetrievalAugmentationAdvisor build () {
229
- return new RetrievalAugmentationAdvisor (this .queryTransformers , this .documentRetriever , this .queryAugmentor ,
230
- this .protectFromBlocking , this .order );
261
+ return new RetrievalAugmentationAdvisor (this .queryTransformers , this .queryExpander , this .queryRouter ,
262
+ this .documentJoiner , this . queryAugmenter , this . taskExecutor , this . scheduler , this .order );
231
263
}
232
264
233
265
}
0 commit comments