|
21 | 21 | import java.util.List;
|
22 | 22 | import java.util.Map;
|
23 | 23 |
|
| 24 | +import reactor.core.publisher.Flux; |
| 25 | +import reactor.core.publisher.Mono; |
24 | 26 | import reactor.core.scheduler.Scheduler;
|
25 | 27 |
|
| 28 | +import org.springframework.ai.chat.client.ChatClientMessageAggregator; |
26 | 29 | import org.springframework.ai.chat.client.ChatClientRequest;
|
27 | 30 | import org.springframework.ai.chat.client.ChatClientResponse;
|
28 | 31 | import org.springframework.ai.chat.client.advisor.api.Advisor;
|
29 | 32 | import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
|
30 | 33 | import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
|
31 | 34 | import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
|
| 35 | +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; |
32 | 36 | import org.springframework.ai.chat.memory.ChatMemory;
|
33 | 37 | import org.springframework.ai.chat.messages.AssistantMessage;
|
34 | 38 | import org.springframework.ai.chat.messages.Message;
|
@@ -167,6 +171,20 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
|
167 | 171 | return chatClientResponse;
|
168 | 172 | }
|
169 | 173 |
|
| 174 | + @Override |
| 175 | + public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, |
| 176 | + StreamAdvisorChain streamAdvisorChain) { |
| 177 | + // Get the scheduler from BaseAdvisor |
| 178 | + Scheduler scheduler = this.getScheduler(); |
| 179 | + // Process the request with the before method |
| 180 | + return Mono.just(chatClientRequest) |
| 181 | + .publishOn(scheduler) |
| 182 | + .map(request -> this.before(request, streamAdvisorChain)) |
| 183 | + .flatMapMany(streamAdvisorChain::nextStream) |
| 184 | + .transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux, |
| 185 | + response -> this.after(response, streamAdvisorChain))); |
| 186 | + } |
| 187 | + |
170 | 188 | private List<Document> toDocuments(List<Message> messages, String conversationId) {
|
171 | 189 | List<Document> docs = messages.stream()
|
172 | 190 | .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
|
|
0 commit comments