Skip to content

Commit dfb8bf6

Browse files
committed
Add a new abstraction to simplify implementation of common ChatBot use cases
* Add ChatBot and basic DefaultChatBot * Add streaming ChatBot support. * Add Evaluator interface and RelevancyEvaluator implementation * Add Content data type abstraction for Document and Message * Renaming and package refactoring * update .gitignore to allow node package name * Add List<Media> to node and move ai.transformer package to ai.prompt.transformer * Add Short/Long term memory support. * Add mixing transformers support Docs TBD
1 parent 012a2ad commit dfb8bf6

File tree

54 files changed

+3375
-98
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+3375
-98
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,6 @@ package.json
3636
.vscode
3737
.antlr
3838

39-
shell.log
39+
shell.log
40+
41+
.profiler

.mvn/extensions.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<extensions>
3+
<extension>
4+
<groupId>fr.jcgay.maven</groupId>
5+
<artifactId>maven-profiler</artifactId>
6+
<version>3.2</version>
7+
</extension>
8+
</extensions>

models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/ClientIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ void helloWorldCompletion() {
5757
}
5858
```""";
5959
assertThat(chatResponse.getResult().getOutput().getContent()).isEqualTo(expectedResponse);
60-
assertThat(chatResponse.getResult().getOutput().getProperties()).containsKey("generated_tokens");
61-
assertThat(chatResponse.getResult().getOutput().getProperties()).containsEntry("generated_tokens", 39);
60+
assertThat(chatResponse.getResult().getOutput().getMetadata()).containsKey("generated_tokens");
61+
assertThat(chatResponse.getResult().getOutput().getMetadata()).containsEntry("generated_tokens", 39);
6262

6363
}
6464

models/spring-ai-openai/pom.xml

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
<?xml version="1.0" encoding="UTF-8"?>
2-
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
34
<modelVersion>4.0.0</modelVersion>
45
<parent>
56
<groupId>org.springframework.ai</groupId>
@@ -74,6 +75,38 @@
7475
<scope>test</scope>
7576
</dependency>
7677

78+
<dependency>
79+
<groupId>org.springframework.ai</groupId>
80+
<artifactId>spring-ai-qdrant</artifactId>
81+
<version>${project.version}</version>
82+
<exclusions>
83+
<exclusion>
84+
<groupId>org.springframework.ai</groupId>
85+
<artifactId>spring-ai-openai</artifactId>
86+
</exclusion>
87+
</exclusions>
88+
<scope>test</scope>
89+
</dependency>
90+
91+
<dependency>
92+
<groupId>org.testcontainers</groupId>
93+
<artifactId>qdrant</artifactId>
94+
<scope>test</scope>
95+
</dependency>
96+
97+
<dependency>
98+
<groupId>org.testcontainers</groupId>
99+
<artifactId>testcontainers</artifactId>
100+
<scope>test</scope>
101+
</dependency>
102+
103+
<dependency>
104+
<groupId>org.testcontainers</groupId>
105+
<artifactId>junit-jupiter</artifactId>
106+
<scope>test</scope>
107+
</dependency>
108+
109+
77110
</dependencies>
78111

79112
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.chatbot;
18+
19+
import java.util.List;
20+
21+
import io.qdrant.client.QdrantClient;
22+
import io.qdrant.client.QdrantGrpcClient;
23+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
24+
import org.springframework.ai.chat.chatbot.ChatBot;
25+
import org.testcontainers.junit.jupiter.Container;
26+
import org.testcontainers.junit.jupiter.Testcontainers;
27+
import org.testcontainers.qdrant.QdrantContainer;
28+
29+
import org.springframework.ai.chat.chatbot.DefaultChatBot;
30+
import org.springframework.ai.chat.chatbot.DefaultStreamingChatBot;
31+
import org.springframework.ai.chat.chatbot.StreamingChatBot;
32+
import org.springframework.ai.chat.history.VectorStoreChatMemoryAgentListener;
33+
import org.springframework.ai.chat.history.VectorStoreChatMemoryRetriever;
34+
import org.springframework.ai.chat.history.LastMaxTokenSizeContentTransformer;
35+
import org.springframework.ai.chat.history.SystemPromptChatMemoryAugmentor;
36+
import org.springframework.ai.embedding.EmbeddingClient;
37+
import org.springframework.ai.evaluation.BaseMemoryTest;
38+
import org.springframework.ai.evaluation.RelevancyEvaluator;
39+
import org.springframework.ai.openai.OpenAiChatClient;
40+
import org.springframework.ai.openai.OpenAiEmbeddingClient;
41+
import org.springframework.ai.openai.api.OpenAiApi;
42+
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
43+
import org.springframework.ai.tokenizer.TokenCountEstimator;
44+
import org.springframework.ai.vectorstore.VectorStore;
45+
import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
46+
import org.springframework.beans.factory.annotation.Autowired;
47+
import org.springframework.boot.SpringBootConfiguration;
48+
import org.springframework.boot.test.context.SpringBootTest;
49+
import org.springframework.context.annotation.Bean;
50+
51+
@Testcontainers
52+
@SpringBootTest(classes = ChatMemoryLongTermSystemPromptIT.Config.class)
53+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
54+
public class ChatMemoryLongTermSystemPromptIT extends BaseMemoryTest {
55+
56+
private static final String COLLECTION_NAME = "test_collection";
57+
58+
private static final int QDRANT_GRPC_PORT = 6334;
59+
60+
@Container
61+
static QdrantContainer qdrantContainer = new QdrantContainer("qdrant/qdrant:v1.7.4");
62+
63+
@Autowired
64+
public ChatMemoryLongTermSystemPromptIT(RelevancyEvaluator relevancyEvaluator, ChatBot chatBot,
65+
StreamingChatBot streamingChatBot) {
66+
super(relevancyEvaluator, chatBot, streamingChatBot);
67+
}
68+
69+
@SpringBootConfiguration
70+
static class Config {
71+
72+
@Bean
73+
public OpenAiApi chatCompletionApi() {
74+
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
75+
}
76+
77+
@Bean
78+
public OpenAiChatClient openAiClient(OpenAiApi openAiApi) {
79+
return new OpenAiChatClient(openAiApi);
80+
}
81+
82+
@Bean
83+
public EmbeddingClient embeddingClient(OpenAiApi openAiApi) {
84+
return new OpenAiEmbeddingClient(openAiApi);
85+
}
86+
87+
@Bean
88+
public VectorStore qdrantVectorStore(EmbeddingClient embeddingClient) {
89+
QdrantClient qdrantClient = new QdrantClient(QdrantGrpcClient
90+
.newBuilder(qdrantContainer.getHost(), qdrantContainer.getMappedPort(QDRANT_GRPC_PORT), false)
91+
.build());
92+
return new QdrantVectorStore(qdrantClient, COLLECTION_NAME, embeddingClient);
93+
}
94+
95+
@Bean
96+
public TokenCountEstimator tokenCountEstimator() {
97+
return new JTokkitTokenCountEstimator();
98+
}
99+
100+
@Bean
101+
public ChatBot memoryChatAgent(OpenAiChatClient chatClient, VectorStore vectorStore,
102+
TokenCountEstimator tokenCountEstimator) {
103+
104+
return DefaultChatBot.builder(chatClient)
105+
.withRetrievers(List.of(new VectorStoreChatMemoryRetriever(vectorStore, 10)))
106+
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
107+
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
108+
.withChatAgentListeners(List.of(new VectorStoreChatMemoryAgentListener(vectorStore)))
109+
.build();
110+
}
111+
112+
@Bean
113+
public StreamingChatBot memoryStreamingChatAgent(OpenAiChatClient streamingChatClient, VectorStore vectorStore,
114+
TokenCountEstimator tokenCountEstimator) {
115+
116+
return DefaultStreamingChatBot.builder(streamingChatClient)
117+
.withRetrievers(List.of(new VectorStoreChatMemoryRetriever(vectorStore, 10)))
118+
.withDocumentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
119+
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
120+
.withChatAgentListeners(List.of(new VectorStoreChatMemoryAgentListener(vectorStore)))
121+
.build();
122+
}
123+
124+
@Bean
125+
public RelevancyEvaluator relevancyEvaluator(OpenAiChatClient chatClient) {
126+
return new RelevancyEvaluator(chatClient);
127+
}
128+
129+
}
130+
131+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai.chat.chatbot;
17+
18+
import java.util.List;
19+
20+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
21+
22+
import org.springframework.ai.chat.chatbot.ChatBot;
23+
import org.springframework.ai.chat.chatbot.DefaultChatBot;
24+
import org.springframework.ai.chat.chatbot.DefaultStreamingChatBot;
25+
import org.springframework.ai.chat.chatbot.StreamingChatBot;
26+
import org.springframework.ai.chat.history.ChatMemory;
27+
import org.springframework.ai.chat.history.ChatMemoryAgentListener;
28+
import org.springframework.ai.chat.history.ChatMemoryRetriever;
29+
import org.springframework.ai.chat.history.InMemoryChatMemory;
30+
import org.springframework.ai.chat.history.LastMaxTokenSizeContentTransformer;
31+
import org.springframework.ai.chat.history.MessageChatMemoryAugmentor;
32+
import org.springframework.ai.evaluation.BaseMemoryTest;
33+
import org.springframework.ai.evaluation.RelevancyEvaluator;
34+
import org.springframework.ai.openai.OpenAiChatClient;
35+
import org.springframework.ai.openai.api.OpenAiApi;
36+
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
37+
import org.springframework.ai.tokenizer.TokenCountEstimator;
38+
import org.springframework.beans.factory.annotation.Autowired;
39+
import org.springframework.boot.SpringBootConfiguration;
40+
import org.springframework.boot.test.context.SpringBootTest;
41+
import org.springframework.context.annotation.Bean;
42+
43+
@SpringBootTest(classes = ChatMemoryShortTermMessageListIT.Config.class)
44+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
45+
public class ChatMemoryShortTermMessageListIT extends BaseMemoryTest {
46+
47+
@Autowired
48+
public ChatMemoryShortTermMessageListIT(RelevancyEvaluator relevancyEvaluator, ChatBot chatBot,
49+
StreamingChatBot streamingChatBot) {
50+
super(relevancyEvaluator, chatBot, streamingChatBot);
51+
}
52+
53+
@SpringBootConfiguration
54+
static class Config {
55+
56+
@Bean
57+
public OpenAiApi chatCompletionApi() {
58+
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
59+
}
60+
61+
@Bean
62+
public OpenAiChatClient openAiClient(OpenAiApi openAiApi) {
63+
return new OpenAiChatClient(openAiApi);
64+
}
65+
66+
@Bean
67+
public ChatMemory chatHistory() {
68+
return new InMemoryChatMemory();
69+
}
70+
71+
@Bean
72+
public TokenCountEstimator tokenCountEstimator() {
73+
return new JTokkitTokenCountEstimator();
74+
}
75+
76+
@Bean
77+
public ChatBot memoryChatAgent(OpenAiChatClient chatClient, ChatMemory chatHistory,
78+
TokenCountEstimator tokenCountEstimator) {
79+
80+
return DefaultChatBot.builder(chatClient)
81+
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
82+
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
83+
.withAugmentors(List.of(new MessageChatMemoryAugmentor()))
84+
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
85+
.build();
86+
}
87+
88+
@Bean
89+
public StreamingChatBot memoryStreamingChatAgent(OpenAiChatClient streamingChatClient, ChatMemory chatHistory,
90+
TokenCountEstimator tokenCountEstimator) {
91+
92+
return DefaultStreamingChatBot.builder(streamingChatClient)
93+
.withRetrievers(List.of(new ChatMemoryRetriever(chatHistory)))
94+
.withDocumentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
95+
.withAugmentors(List.of(new MessageChatMemoryAugmentor()))
96+
.withChatAgentListeners(List.of(new ChatMemoryAgentListener(chatHistory)))
97+
.build();
98+
}
99+
100+
@Bean
101+
public RelevancyEvaluator relevancyEvaluator(OpenAiChatClient chatClient) {
102+
return new RelevancyEvaluator(chatClient);
103+
}
104+
105+
}
106+
107+
}

0 commit comments

Comments
 (0)