Skip to content

Commit 31da8f3

Browse files
sunyuhan1998markpollack
authored andcommitted
fix: Fixed the incorrect SQL in getSelectMessagesSql of MysqlChatMemoryRepositoryDialect
- Added tests Signed-off-by: Sun Yuhan <1085481446@qq.com>
1 parent 0e1bb52 commit 31da8f3

File tree

6 files changed

+273
-125
lines changed

6 files changed

+273
-125
lines changed

memory/repository/spring-ai-model-chat-memory-repository-jdbc/pom.xml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@
6969
<optional>true</optional>
7070
</dependency>
7171

72+
<dependency>
73+
<groupId>com.microsoft.sqlserver</groupId>
74+
<artifactId>mssql-jdbc</artifactId>
75+
<scope>test</scope>
76+
<optional>true</optional>
77+
</dependency>
78+
7279
<!-- TESTING -->
7380
<dependency>
7481
<groupId>org.springframework.boot</groupId>
@@ -94,10 +101,21 @@
94101
<scope>test</scope>
95102
</dependency>
96103

104+
<dependency>
105+
<groupId>org.testcontainers</groupId>
106+
<artifactId>mssqlserver</artifactId>
107+
<scope>test</scope>
108+
</dependency>
109+
97110
<dependency>
98111
<groupId>org.testcontainers</groupId>
99112
<artifactId>junit-jupiter</artifactId>
100113
<scope>test</scope>
101114
</dependency>
115+
<dependency>
116+
<groupId>org.testcontainers</groupId>
117+
<artifactId>mssqlserver</artifactId>
118+
<scope>test</scope>
119+
</dependency>
102120
</dependencies>
103121
</project>

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/MysqlChatMemoryRepositoryDialect.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package org.springframework.ai.chat.memory.repository.jdbc;
1818

1919
/**
20-
* Dialect for MySQL.
20+
* MySQL dialect for chat memory repository.
2121
*
2222
* @author Mark Pollack
2323
* @since 1.0.0
@@ -26,7 +26,7 @@ public class MysqlChatMemoryRepositoryDialect implements JdbcChatMemoryRepositor
2626

2727
@Override
2828
public String getSelectMessagesSql() {
29-
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp` DESC LIMIT ?";
29+
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp`";
3030
}
3131

3232
@Override

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/SqlServerChatMemoryRepositoryDialect.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class SqlServerChatMemoryRepositoryDialect implements JdbcChatMemoryRepos
2626

2727
@Override
2828
public String getSelectMessagesSql() {
29-
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp] ASC";
29+
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp]";
3030
}
3131

3232
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.memory.repository.jdbc;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.params.ParameterizedTest;
21+
import org.junit.jupiter.params.provider.CsvSource;
22+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
23+
import org.springframework.ai.chat.messages.Message;
24+
import org.springframework.ai.chat.messages.MessageType;
25+
import org.springframework.ai.chat.messages.AssistantMessage;
26+
import org.springframework.ai.chat.messages.SystemMessage;
27+
import org.springframework.ai.chat.messages.UserMessage;
28+
import org.springframework.beans.factory.annotation.Autowired;
29+
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
30+
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
31+
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
32+
import org.springframework.boot.test.context.SpringBootTest;
33+
import org.springframework.context.annotation.Bean;
34+
import org.springframework.jdbc.core.JdbcTemplate;
35+
36+
import java.sql.Timestamp;
37+
import java.util.List;
38+
import java.util.Map;
39+
import java.util.UUID;
40+
import java.util.stream.Collectors;
41+
42+
import javax.sql.DataSource;
43+
44+
import static org.assertj.core.api.Assertions.assertThat;
45+
46+
/**
47+
* Base class for integration tests for {@link JdbcChatMemoryRepository}.
48+
*
49+
* @author Mark Pollack
50+
*/
51+
public abstract class AbstractJdbcChatMemoryRepositoryIT {
52+
53+
@Autowired
54+
protected ChatMemoryRepository chatMemoryRepository;
55+
56+
@Autowired
57+
protected JdbcTemplate jdbcTemplate;
58+
59+
@Test
60+
void correctChatMemoryRepositoryInstance() {
61+
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
62+
}
63+
64+
@ParameterizedTest
65+
@CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" })
66+
void saveMessagesSingleMessage(String content, MessageType messageType) {
67+
String conversationId = UUID.randomUUID().toString();
68+
var message = switch (messageType) {
69+
case ASSISTANT -> new AssistantMessage(content + " - " + conversationId);
70+
case USER -> new UserMessage(content + " - " + conversationId);
71+
case SYSTEM -> new SystemMessage(content + " - " + conversationId);
72+
case TOOL -> throw new IllegalArgumentException("TOOL message type not supported in this test");
73+
};
74+
75+
chatMemoryRepository.saveAll(conversationId, List.of(message));
76+
77+
// Use dialect to get the appropriate SQL query
78+
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
79+
String selectSql = dialect.getSelectMessagesSql()
80+
.replace("content, type", "conversation_id, content, type, timestamp");
81+
var result = jdbcTemplate.queryForMap(selectSql, conversationId);
82+
83+
assertThat(result.size()).isEqualTo(4);
84+
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
85+
assertThat(result.get("content")).isEqualTo(message.getText());
86+
assertThat(result.get("type")).isEqualTo(messageType.name());
87+
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
88+
}
89+
90+
@Test
91+
void saveMessagesMultipleMessages() {
92+
String conversationId = UUID.randomUUID().toString();
93+
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId),
94+
new UserMessage("Message from user - " + conversationId),
95+
new SystemMessage("Message from system - " + conversationId));
96+
97+
chatMemoryRepository.saveAll(conversationId, messages);
98+
99+
// Use dialect to get the appropriate SQL query
100+
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
101+
String selectSql = dialect.getSelectMessagesSql()
102+
.replace("content, type", "conversation_id, content, type, timestamp");
103+
var results = jdbcTemplate.queryForList(selectSql, conversationId);
104+
105+
assertThat(results).hasSize(messages.size());
106+
107+
for (int i = 0; i < messages.size(); i++) {
108+
var message = messages.get(i);
109+
var result = results.get(i);
110+
111+
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
112+
assertThat(result.get("content")).isEqualTo(message.getText());
113+
assertThat(result.get("type")).isEqualTo(message.getMessageType().name());
114+
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
115+
}
116+
117+
var count = chatMemoryRepository.findByConversationId(conversationId).size();
118+
assertThat(count).isEqualTo(messages.size());
119+
120+
chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello")));
121+
122+
count = chatMemoryRepository.findByConversationId(conversationId).size();
123+
assertThat(count).isEqualTo(1);
124+
}
125+
126+
@Test
127+
void findMessagesByConversationId() {
128+
var conversationId = UUID.randomUUID().toString();
129+
var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + conversationId),
130+
new AssistantMessage("Message from assistant 2 - " + conversationId),
131+
new UserMessage("Message from user - " + conversationId),
132+
new SystemMessage("Message from system - " + conversationId));
133+
134+
chatMemoryRepository.saveAll(conversationId, messages);
135+
136+
var results = chatMemoryRepository.findByConversationId(conversationId);
137+
138+
assertThat(results.size()).isEqualTo(messages.size());
139+
assertThat(results).isEqualTo(messages);
140+
}
141+
142+
@Test
143+
void deleteMessagesByConversationId() {
144+
var conversationId = UUID.randomUUID().toString();
145+
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId),
146+
new UserMessage("Message from user - " + conversationId),
147+
new SystemMessage("Message from system - " + conversationId));
148+
149+
chatMemoryRepository.saveAll(conversationId, messages);
150+
151+
chatMemoryRepository.deleteByConversationId(conversationId);
152+
153+
var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?",
154+
Integer.class, conversationId);
155+
156+
assertThat(count).isZero();
157+
}
158+
159+
@Test
160+
void testMessageOrder() {
161+
// Create a repository using the from method to detect the dialect
162+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder()
163+
.jdbcTemplate(jdbcTemplate)
164+
.dialect(JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource()))
165+
.build();
166+
167+
var conversationId = UUID.randomUUID().toString();
168+
169+
// Create messages with very distinct content to make order obvious
170+
var firstMessage = new UserMessage("1-First message");
171+
var secondMessage = new AssistantMessage("2-Second message");
172+
var thirdMessage = new UserMessage("3-Third message");
173+
var fourthMessage = new SystemMessage("4-Fourth message");
174+
175+
// Save messages in the expected order
176+
List<Message> orderedMessages = List.of(firstMessage, secondMessage, thirdMessage, fourthMessage);
177+
repository.saveAll(conversationId, orderedMessages);
178+
179+
// Retrieve messages using the repository
180+
List<Message> retrievedMessages = repository.findByConversationId(conversationId);
181+
assertThat(retrievedMessages).hasSize(4);
182+
183+
// Get the actual order from the retrieved messages
184+
List<String> retrievedContents = retrievedMessages.stream().map(Message::getText).collect(Collectors.toList());
185+
186+
// Messages should be in the original order (ASC)
187+
assertThat(retrievedContents).containsExactly("1-First message", "2-Second message", "3-Third message",
188+
"4-Fourth message");
189+
}
190+
191+
/**
192+
* Base configuration for all integration tests.
193+
*/
194+
@ImportAutoConfiguration({ DataSourceAutoConfiguration.class, JdbcTemplateAutoConfiguration.class })
195+
static abstract class BaseTestConfiguration {
196+
197+
@Bean
198+
ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate, DataSource dataSource) {
199+
return JdbcChatMemoryRepository.builder()
200+
.jdbcTemplate(jdbcTemplate)
201+
.dialect(JdbcChatMemoryRepositoryDialect.from(dataSource))
202+
.build();
203+
}
204+
205+
}
206+
207+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.memory.repository.jdbc;
18+
19+
import org.springframework.boot.SpringBootConfiguration;
20+
import org.springframework.boot.test.context.SpringBootTest;
21+
import org.springframework.test.context.TestPropertySource;
22+
import org.springframework.test.context.jdbc.Sql;
23+
24+
/**
25+
* Integration tests for {@link JdbcChatMemoryRepository} with MySQL.
26+
*
27+
* @author Jonathan Leijendekker
28+
* @author Thomas Vitale
29+
* @author Mark Pollack
30+
*/
31+
@SpringBootTest(classes = JdbcChatMemoryRepositoryMysqlIT.TestConfiguration.class)
32+
@TestPropertySource(properties = { "spring.datasource.url=jdbc:tc:mariadb:10.3.39:///" })
33+
@Sql(scripts = "classpath:org/springframework/ai/chat/memory/repository/jdbc/schema-mariadb.sql")
34+
class JdbcChatMemoryRepositoryMysqlIT extends AbstractJdbcChatMemoryRepositoryIT {
35+
36+
@SpringBootConfiguration
37+
static class TestConfiguration extends BaseTestConfiguration {
38+
39+
}
40+
41+
}

0 commit comments

Comments
 (0)