Skip to content

Commit e4357ba

Browse files
ghdcksgml1ilayaperumalg
authored andcommitted
Refact onFinishReason method to utility class
Test AdvisedResponseStreamUtils Add java docs Signed-off-by: ghdcksgml1 <ghdcksgml2@naver.com>
1 parent d25d37a commit e4357ba

File tree

4 files changed

+116
-34
lines changed

4 files changed

+116
-34
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,13 @@
1919
import java.util.HashMap;
2020
import java.util.List;
2121
import java.util.Map;
22-
import java.util.function.Predicate;
2322
import java.util.stream.Collectors;
2423

24+
import org.springframework.ai.chat.client.advisor.api.*;
2525
import reactor.core.publisher.Flux;
2626
import reactor.core.publisher.Mono;
2727
import reactor.core.scheduler.Schedulers;
2828

29-
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
30-
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
31-
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
32-
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
33-
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
34-
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
3529
import org.springframework.ai.chat.model.ChatResponse;
3630
import org.springframework.ai.chat.prompt.PromptTemplate;
3731
import org.springframework.ai.document.Document;
@@ -201,7 +195,7 @@ public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamA
201195
// @formatter:on
202196

203197
return advisedResponses.map(ar -> {
204-
if (onFinishReason().test(ar)) {
198+
if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
205199
ar = after(ar);
206200
}
207201
return ar;
@@ -260,16 +254,6 @@ protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {
260254

261255
}
262256

263-
private Predicate<AdvisedResponse> onFinishReason() {
264-
return advisedResponse -> advisedResponse.response()
265-
.getResults()
266-
.stream()
267-
.filter(result -> result != null && result.getMetadata() != null
268-
&& StringUtils.hasText(result.getMetadata().getFinishReason()))
269-
.findFirst()
270-
.isPresent();
271-
}
272-
273257
public static final class Builder {
274258

275259
private final VectorStore vectorStore;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package org.springframework.ai.chat.client.advisor.api;
2+
3+
import org.springframework.ai.chat.model.ChatResponse;
4+
import org.springframework.util.StringUtils;
5+
6+
import java.util.function.Predicate;
7+
8+
/**
9+
* A stream utility class to provide support methods handling {@link AdvisedResponse}.
10+
*/
11+
public final class AdvisedResponseStreamUtils {
12+
13+
/**
14+
* Returns a predicate that checks whether the provided {@link AdvisedResponse}
15+
* contains a {@link ChatResponse} with at least one result having a non-empty finish
16+
* reason in its metadata.
17+
* @return a {@link Predicate} that evaluates whether the finish reason exists within
18+
* the response metadata.
19+
*/
20+
public static Predicate<AdvisedResponse> onFinishReason() {
21+
return advisedResponse -> {
22+
ChatResponse chatResponse = advisedResponse.response();
23+
return chatResponse != null && chatResponse.getResults() != null
24+
&& chatResponse.getResults()
25+
.stream()
26+
.anyMatch(result -> result != null && result.getMetadata() != null
27+
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
28+
};
29+
}
30+
31+
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@
1616

1717
package org.springframework.ai.chat.client.advisor.api;
1818

19-
import java.util.function.Predicate;
20-
2119
import reactor.core.publisher.Flux;
2220
import reactor.core.publisher.Mono;
2321
import reactor.core.scheduler.Scheduler;
2422
import reactor.core.scheduler.Schedulers;
2523

26-
import org.springframework.ai.chat.model.ChatResponse;
2724
import org.springframework.util.Assert;
28-
import org.springframework.util.StringUtils;
2925

3026
/**
3127
* Base advisor that implements common aspects of the {@link CallAroundAdvisor} and
@@ -65,24 +61,13 @@ default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, Stream
6561
.flatMapMany(chain::nextAroundStream);
6662

6763
return advisedResponses.map(ar -> {
68-
if (onFinishReason().test(ar)) {
64+
if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
6965
ar = after(ar);
7066
}
7167
return ar;
7268
}).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error)));
7369
}
7470

75-
private Predicate<AdvisedResponse> onFinishReason() {
76-
return advisedResponse -> {
77-
ChatResponse chatResponse = advisedResponse.response();
78-
return chatResponse != null && chatResponse.getResults() != null
79-
&& chatResponse.getResults()
80-
.stream()
81-
.anyMatch(result -> result != null && result.getMetadata() != null
82-
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
83-
};
84-
}
85-
8671
@Override
8772
default String getName() {
8873
return this.getClass().getSimpleName();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package org.springframework.ai.chat.client.advisor.api;
2+
3+
import org.junit.jupiter.api.Nested;
4+
import org.junit.jupiter.api.Test;
5+
import org.springframework.ai.chat.messages.AssistantMessage;
6+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
7+
import org.springframework.ai.chat.model.ChatResponse;
8+
import org.springframework.ai.chat.model.Generation;
9+
10+
import java.util.List;
11+
12+
import static org.junit.jupiter.api.Assertions.*;
13+
import static org.mockito.BDDMockito.given;
14+
import static org.mockito.Mockito.mock;
15+
16+
/**
17+
* Unit tests for {@link AdvisedResponseStreamUtils}.
18+
*
19+
* @author ghdcksgml1
20+
*/
21+
class AdvisedResponseStreamUtilsTest {
22+
23+
@Nested
24+
class OnFinishReason {
25+
26+
@Test
27+
void whenChatResponseIsNullThenReturnFalse() {
28+
AdvisedResponse response = mock(AdvisedResponse.class);
29+
given(response.response()).willReturn(null);
30+
31+
boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);
32+
33+
assertFalse(result);
34+
}
35+
36+
@Test
37+
void whenChatResponseResultsIsNullThenReturnFalse() {
38+
AdvisedResponse response = mock(AdvisedResponse.class);
39+
ChatResponse chatResponse = mock(ChatResponse.class);
40+
41+
given(chatResponse.getResults()).willReturn(null);
42+
given(response.response()).willReturn(chatResponse);
43+
44+
boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);
45+
46+
assertFalse(result);
47+
}
48+
49+
@Test
50+
void whenChatIsRunningThenReturnFalse() {
51+
AdvisedResponse response = mock(AdvisedResponse.class);
52+
ChatResponse chatResponse = mock(ChatResponse.class);
53+
54+
Generation generation = new Generation(new AssistantMessage("running.."), ChatGenerationMetadata.NULL);
55+
56+
given(chatResponse.getResults()).willReturn(List.of(generation));
57+
given(response.response()).willReturn(chatResponse);
58+
59+
boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);
60+
61+
assertFalse(result);
62+
}
63+
64+
@Test
65+
void whenChatIsStopThenReturnTrue() {
66+
AdvisedResponse response = mock(AdvisedResponse.class);
67+
ChatResponse chatResponse = mock(ChatResponse.class);
68+
69+
Generation generation = new Generation(new AssistantMessage("finish."),
70+
ChatGenerationMetadata.builder().finishReason("STOP").build());
71+
72+
given(chatResponse.getResults()).willReturn(List.of(generation));
73+
given(response.response()).willReturn(chatResponse);
74+
75+
boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);
76+
77+
assertTrue(result);
78+
}
79+
80+
}
81+
82+
}

0 commit comments

Comments
 (0)