Skip to content

Commit 0e1bb52

Browse files
sunyuhan1998markpollack
authored andcommitted
fix: Added transaction support for saveAll in JdbcChatMemoryRepository
- Add optional PlatformTransactionManager - Add tests Signed-off-by: Sun Yuhan <1085481446@qq.com>
1 parent 1a395e6 commit 0e1bb52

File tree

2 files changed

+76
-16
lines changed

2 files changed

+76
-16
lines changed

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@
1616

1717
package org.springframework.ai.chat.memory.repository.jdbc;
1818

19-
import java.sql.PreparedStatement;
20-
import java.sql.ResultSet;
21-
import java.sql.SQLException;
22-
import java.sql.Timestamp;
19+
import java.sql.*;
2320
import java.time.Instant;
2421
import java.util.ArrayList;
2522
import java.util.List;
@@ -35,7 +32,10 @@
3532
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
3633
import org.springframework.jdbc.core.JdbcTemplate;
3734
import org.springframework.jdbc.core.RowMapper;
35+
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
3836
import org.springframework.lang.Nullable;
37+
import org.springframework.transaction.PlatformTransactionManager;
38+
import org.springframework.transaction.support.TransactionTemplate;
3939
import org.springframework.util.Assert;
4040

4141
/**
@@ -51,13 +51,18 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
5151

5252
private final JdbcTemplate jdbcTemplate;
5353

54+
private final TransactionTemplate transactionTemplate;
55+
5456
private final JdbcChatMemoryRepositoryDialect dialect;
5557

56-
private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect) {
58+
private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect,
59+
PlatformTransactionManager txManager) {
5760
Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null");
5861
Assert.notNull(dialect, "dialect cannot be null");
5962
this.jdbcTemplate = jdbcTemplate;
6063
this.dialect = dialect;
64+
this.transactionTemplate = new TransactionTemplate(
65+
txManager != null ? txManager : new DataSourceTransactionManager(jdbcTemplate.getDataSource()));
6166
}
6267

6368
@Override
@@ -83,9 +88,13 @@ public void saveAll(String conversationId, List<Message> messages) {
8388
Assert.hasText(conversationId, "conversationId cannot be null or empty");
8489
Assert.notNull(messages, "messages cannot be null");
8590
Assert.noNullElements(messages, "messages cannot contain null elements");
86-
this.deleteByConversationId(conversationId);
87-
this.jdbcTemplate.batchUpdate(dialect.getInsertMessageSql(),
88-
new AddBatchPreparedStatement(conversationId, messages));
91+
92+
transactionTemplate.execute(status -> {
93+
deleteByConversationId(conversationId);
94+
jdbcTemplate.batchUpdate(dialect.getInsertMessageSql(),
95+
new AddBatchPreparedStatement(conversationId, messages));
96+
return null;
97+
});
8998
}
9099

91100
@Override
@@ -148,6 +157,8 @@ public static class Builder {
148157

149158
private JdbcChatMemoryRepositoryDialect dialect;
150159

160+
private PlatformTransactionManager platformTransactionManager;
161+
151162
private Builder() {
152163
}
153164

@@ -161,10 +172,15 @@ public Builder dialect(JdbcChatMemoryRepositoryDialect dialect) {
161172
return this;
162173
}
163174

175+
public Builder transactionManager(PlatformTransactionManager txManager) {
176+
this.platformTransactionManager = txManager;
177+
return this;
178+
}
179+
164180
public JdbcChatMemoryRepository build() {
165181
if (this.dialect == null)
166182
throw new IllegalStateException("Dialect must be set");
167-
return new JdbcChatMemoryRepository(this.jdbcTemplate, this.dialect);
183+
return new JdbcChatMemoryRepository(this.jdbcTemplate, this.dialect, this.platformTransactionManager);
168184
}
169185

170186
}

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@
1616

1717
package org.springframework.ai.chat.memory.repository.jdbc;
1818

19+
import java.sql.Timestamp;
20+
import java.util.List;
21+
import java.util.UUID;
22+
import javax.sql.DataSource;
23+
1924
import org.junit.jupiter.api.Test;
2025
import org.junit.jupiter.params.ParameterizedTest;
2126
import org.junit.jupiter.params.provider.CsvSource;
27+
2228
import org.springframework.ai.chat.memory.ChatMemoryRepository;
29+
import org.springframework.ai.chat.messages.AssistantMessage;
2330
import org.springframework.ai.chat.messages.Message;
2431
import org.springframework.ai.chat.messages.MessageType;
25-
import org.springframework.ai.chat.messages.AssistantMessage;
2632
import org.springframework.ai.chat.messages.SystemMessage;
2733
import org.springframework.ai.chat.messages.UserMessage;
2834
import org.springframework.beans.factory.annotation.Autowired;
@@ -33,14 +39,10 @@
3339
import org.springframework.boot.test.context.SpringBootTest;
3440
import org.springframework.context.annotation.Bean;
3541
import org.springframework.jdbc.core.JdbcTemplate;
42+
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
3643
import org.springframework.test.context.TestPropertySource;
3744
import org.springframework.test.context.jdbc.Sql;
38-
39-
import java.sql.Timestamp;
40-
import java.util.List;
41-
import java.util.UUID;
42-
43-
import javax.sql.DataSource;
45+
import org.springframework.transaction.support.TransactionTemplate;
4446

4547
import static org.assertj.core.api.Assertions.assertThat;
4648

@@ -156,6 +158,34 @@ void deleteMessagesByConversationId() {
156158
assertThat(count).isZero();
157159
}
158160

161+
@Test
162+
void repositoryWithExplicitTransactionManager() {
163+
// Get the repository with explicit transaction manager
164+
ChatMemoryRepository repositoryWithTxManager = TestConfiguration
165+
.chatMemoryRepositoryWithTransactionManager(jdbcTemplate, jdbcTemplate.getDataSource());
166+
167+
var conversationId = UUID.randomUUID().toString();
168+
var messages = List.<Message>of(new AssistantMessage("Message with transaction manager - " + conversationId),
169+
new UserMessage("User message with transaction manager - " + conversationId));
170+
171+
// Save messages using the repository with explicit transaction manager
172+
repositoryWithTxManager.saveAll(conversationId, messages);
173+
174+
// Verify messages were saved correctly
175+
var savedMessages = repositoryWithTxManager.findByConversationId(conversationId);
176+
assertThat(savedMessages).hasSize(2);
177+
assertThat(savedMessages).isEqualTo(messages);
178+
179+
// Verify transaction works by updating and checking atomicity
180+
var newMessages = List.<Message>of(new SystemMessage("New system message - " + conversationId));
181+
repositoryWithTxManager.saveAll(conversationId, newMessages);
182+
183+
// The old messages should be deleted and only the new one should exist
184+
var updatedMessages = repositoryWithTxManager.findByConversationId(conversationId);
185+
assertThat(updatedMessages).hasSize(1);
186+
assertThat(updatedMessages).isEqualTo(newMessages);
187+
}
188+
159189
@SpringBootConfiguration
160190
@ImportAutoConfiguration({ DataSourceAutoConfiguration.class, JdbcTemplateAutoConfiguration.class })
161191
static class TestConfiguration {
@@ -168,6 +198,20 @@ ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate, DataSource
168198
.build();
169199
}
170200

201+
@Bean
202+
ChatMemoryRepository chatMemoryRepositoryWithTxManager(JdbcTemplate jdbcTemplate, DataSource dataSource) {
203+
return chatMemoryRepositoryWithTransactionManager(jdbcTemplate, dataSource);
204+
}
205+
206+
static ChatMemoryRepository chatMemoryRepositoryWithTransactionManager(JdbcTemplate jdbcTemplate,
207+
DataSource dataSource) {
208+
return JdbcChatMemoryRepository.builder()
209+
.jdbcTemplate(jdbcTemplate)
210+
.dialect(JdbcChatMemoryRepositoryDialect.from(dataSource))
211+
.transactionManager(new DataSourceTransactionManager(dataSource))
212+
.build();
213+
}
214+
171215
}
172216

173217
}

0 commit comments

Comments
 (0)