Skip to content

Commit 008a760

Browse files
committed
Add dataSource() method to JdbcChatMemoryRepository.Builder and improve dialect detection
- Added a dataSource(DataSource) method to the builder - Now, if no dialect is specified, the implementation will attempt to detect it from the effective DataSource (either set directly or obtained from the JdbcTemplate). - Updated builder logic to default to the DataSource from the JdbcTemplate if dataSource() is not called, ensuring dialect detection works out-of-the-box. - The builder now prefers a directly provided DataSource, but remains backwards compatible with JdbcTemplate-based configuration. - Added tests Fixes #3148 Signed-off-by: Mark Pollack <mark.pollack@broadcom.com>
1 parent 31da8f3 commit 008a760

File tree

2 files changed

+295
-7
lines changed

2 files changed

+295
-7
lines changed

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import java.util.List;
2323
import java.util.concurrent.atomic.AtomicLong;
2424

25+
import javax.sql.DataSource;
26+
2527
import org.springframework.ai.chat.memory.ChatMemoryRepository;
2628
import org.springframework.ai.chat.messages.AssistantMessage;
2729
import org.springframework.ai.chat.messages.Message;
@@ -35,8 +37,15 @@
3537
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
3638
import org.springframework.lang.Nullable;
3739
import org.springframework.transaction.PlatformTransactionManager;
40+
import org.springframework.transaction.TransactionDefinition;
41+
import org.springframework.transaction.TransactionException;
42+
import org.springframework.transaction.TransactionStatus;
43+
import org.springframework.transaction.support.AbstractPlatformTransactionManager;
44+
import org.springframework.transaction.support.DefaultTransactionStatus;
3845
import org.springframework.transaction.support.TransactionTemplate;
3946
import org.springframework.util.Assert;
47+
import org.slf4j.Logger;
48+
import org.slf4j.LoggerFactory;
4049

4150
/**
4251
* An implementation of {@link ChatMemoryRepository} for JDBC.
@@ -55,14 +64,16 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
5564

5665
private final JdbcChatMemoryRepositoryDialect dialect;
5766

58-
private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect,
67+
private static final Logger logger = LoggerFactory.getLogger(JdbcChatMemoryRepository.class);
68+
69+
private JdbcChatMemoryRepository(DataSource dataSource, JdbcChatMemoryRepositoryDialect dialect,
5970
PlatformTransactionManager txManager) {
60-
Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null");
71+
Assert.notNull(dataSource, "dataSource cannot be null");
6172
Assert.notNull(dialect, "dialect cannot be null");
62-
this.jdbcTemplate = jdbcTemplate;
73+
this.jdbcTemplate = new JdbcTemplate(dataSource);
6374
this.dialect = dialect;
6475
this.transactionTemplate = new TransactionTemplate(
65-
txManager != null ? txManager : new DataSourceTransactionManager(jdbcTemplate.getDataSource()));
76+
txManager != null ? txManager : new DataSourceTransactionManager(dataSource));
6677
}
6778

6879
@Override
@@ -157,8 +168,12 @@ public static class Builder {
157168

158169
private JdbcChatMemoryRepositoryDialect dialect;
159170

171+
private DataSource dataSource;
172+
160173
private PlatformTransactionManager platformTransactionManager;
161174

175+
private static final Logger logger = LoggerFactory.getLogger(Builder.class);
176+
162177
private Builder() {
163178
}
164179

@@ -172,15 +187,62 @@ public Builder dialect(JdbcChatMemoryRepositoryDialect dialect) {
172187
return this;
173188
}
174189

190+
public Builder dataSource(DataSource dataSource) {
191+
this.dataSource = dataSource;
192+
return this;
193+
}
194+
175195
public Builder transactionManager(PlatformTransactionManager txManager) {
176196
this.platformTransactionManager = txManager;
177197
return this;
178198
}
179199

180200
public JdbcChatMemoryRepository build() {
181-
if (this.dialect == null)
182-
throw new IllegalStateException("Dialect must be set");
183-
return new JdbcChatMemoryRepository(this.jdbcTemplate, this.dialect, this.platformTransactionManager);
201+
DataSource effectiveDataSource = resolveDataSource();
202+
JdbcChatMemoryRepositoryDialect effectiveDialect = resolveDialect(effectiveDataSource);
203+
return new JdbcChatMemoryRepository(effectiveDataSource, effectiveDialect, this.platformTransactionManager);
204+
}
205+
206+
private DataSource resolveDataSource() {
207+
if (this.dataSource != null) {
208+
return this.dataSource;
209+
}
210+
if (this.jdbcTemplate != null && this.jdbcTemplate.getDataSource() != null) {
211+
return this.jdbcTemplate.getDataSource();
212+
}
213+
throw new IllegalArgumentException("DataSource must be set (either via dataSource() or jdbcTemplate())");
214+
}
215+
216+
private JdbcChatMemoryRepositoryDialect resolveDialect(DataSource dataSource) {
217+
if (this.dialect == null) {
218+
try {
219+
return JdbcChatMemoryRepositoryDialect.from(dataSource);
220+
}
221+
catch (Exception ex) {
222+
throw new IllegalStateException("Could not detect dialect from datasource", ex);
223+
}
224+
}
225+
else {
226+
warnIfDialectMismatch(dataSource, this.dialect);
227+
return this.dialect;
228+
}
229+
}
230+
231+
/**
232+
* Logs a warning if the explicitly set dialect differs from the dialect detected
233+
* from the DataSource.
234+
*/
235+
private void warnIfDialectMismatch(DataSource dataSource, JdbcChatMemoryRepositoryDialect explicitDialect) {
236+
try {
237+
JdbcChatMemoryRepositoryDialect detected = JdbcChatMemoryRepositoryDialect.from(dataSource);
238+
if (!detected.getClass().equals(explicitDialect.getClass())) {
239+
logger.warn("Explicitly set dialect {} will be used instead of detected dialect {} from datasource",
240+
explicitDialect.getClass().getSimpleName(), detected.getClass().getSimpleName());
241+
}
242+
}
243+
catch (Exception ex) {
244+
logger.debug("Could not detect dialect from datasource", ex);
245+
}
184246
}
185247

186248
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.memory.repository.jdbc;
18+
19+
import java.sql.Connection;
20+
import java.sql.DatabaseMetaData;
21+
import java.sql.SQLException;
22+
import javax.sql.DataSource;
23+
24+
import org.junit.jupiter.api.Test;
25+
26+
import org.springframework.transaction.PlatformTransactionManager;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
30+
import static org.mockito.Mockito.mock;
31+
import static org.mockito.Mockito.when;
32+
33+
/**
34+
* Tests for {@link JdbcChatMemoryRepository.Builder}.
35+
*
36+
* @author Mark Pollack
37+
*/
38+
public class JdbcChatMemoryRepositoryBuilderTests {
39+
40+
@Test
41+
void testBuilderWithExplicitDialect() {
42+
DataSource dataSource = mock(DataSource.class);
43+
JdbcChatMemoryRepositoryDialect dialect = mock(JdbcChatMemoryRepositoryDialect.class);
44+
45+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder()
46+
.dataSource(dataSource)
47+
.dialect(dialect)
48+
.build();
49+
50+
assertThat(repository).isNotNull();
51+
}
52+
53+
@Test
54+
void testBuilderWithExplicitDialectAndTransactionManager() {
55+
DataSource dataSource = mock(DataSource.class);
56+
JdbcChatMemoryRepositoryDialect dialect = mock(JdbcChatMemoryRepositoryDialect.class);
57+
PlatformTransactionManager txManager = mock(PlatformTransactionManager.class);
58+
59+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder()
60+
.dataSource(dataSource)
61+
.dialect(dialect)
62+
.transactionManager(txManager)
63+
.build();
64+
65+
assertThat(repository).isNotNull();
66+
}
67+
68+
@Test
69+
void testBuilderWithDialectFromDataSource() throws SQLException {
70+
// Setup mocks
71+
DataSource dataSource = mock(DataSource.class);
72+
Connection connection = mock(Connection.class);
73+
DatabaseMetaData metaData = mock(DatabaseMetaData.class);
74+
75+
when(dataSource.getConnection()).thenReturn(connection);
76+
when(connection.getMetaData()).thenReturn(metaData);
77+
when(metaData.getURL()).thenReturn("jdbc:postgresql://localhost:5432/testdb");
78+
79+
// Test with dialect from datasource
80+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().dataSource(dataSource).build();
81+
82+
assertThat(repository).isNotNull();
83+
}
84+
85+
@Test
86+
void testBuilderWithMysqlDialectFromDataSource() throws SQLException {
87+
// Setup mocks for MySQL
88+
DataSource dataSource = mock(DataSource.class);
89+
Connection connection = mock(Connection.class);
90+
DatabaseMetaData metaData = mock(DatabaseMetaData.class);
91+
92+
when(dataSource.getConnection()).thenReturn(connection);
93+
when(connection.getMetaData()).thenReturn(metaData);
94+
when(metaData.getURL()).thenReturn("jdbc:mysql://localhost:3306/testdb");
95+
96+
// Test with dialect from datasource
97+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().dataSource(dataSource).build();
98+
99+
assertThat(repository).isNotNull();
100+
}
101+
102+
@Test
103+
void testBuilderWithSqlServerDialectFromDataSource() throws SQLException {
104+
// Setup mocks for SQL Server
105+
DataSource dataSource = mock(DataSource.class);
106+
Connection connection = mock(Connection.class);
107+
DatabaseMetaData metaData = mock(DatabaseMetaData.class);
108+
109+
when(dataSource.getConnection()).thenReturn(connection);
110+
when(connection.getMetaData()).thenReturn(metaData);
111+
when(metaData.getURL()).thenReturn("jdbc:sqlserver://localhost:1433;databaseName=testdb");
112+
113+
// Test with dialect from datasource
114+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().dataSource(dataSource).build();
115+
116+
assertThat(repository).isNotNull();
117+
}
118+
119+
@Test
120+
void testBuilderWithHsqldbDialectFromDataSource() throws SQLException {
121+
// Setup mocks for HSQLDB
122+
DataSource dataSource = mock(DataSource.class);
123+
Connection connection = mock(Connection.class);
124+
DatabaseMetaData metaData = mock(DatabaseMetaData.class);
125+
126+
when(dataSource.getConnection()).thenReturn(connection);
127+
when(connection.getMetaData()).thenReturn(metaData);
128+
when(metaData.getURL()).thenReturn("jdbc:hsqldb:mem:testdb");
129+
130+
// Test with dialect from datasource
131+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().dataSource(dataSource).build();
132+
133+
assertThat(repository).isNotNull();
134+
}
135+
136+
@Test
137+
void testBuilderWithUnknownDialectFromDataSource() throws SQLException {
138+
// Setup mocks for unknown database
139+
DataSource dataSource = mock(DataSource.class);
140+
Connection connection = mock(Connection.class);
141+
DatabaseMetaData metaData = mock(DatabaseMetaData.class);
142+
143+
when(dataSource.getConnection()).thenReturn(connection);
144+
when(connection.getMetaData()).thenReturn(metaData);
145+
when(metaData.getURL()).thenReturn("jdbc:unknown://localhost:1234/testdb");
146+
147+
// Test with dialect from datasource - should default to PostgreSQL
148+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().dataSource(dataSource).build();
149+
150+
assertThat(repository).isNotNull();
151+
}
152+
153+
@Test
154+
void testBuilderWithExceptionInDataSourceConnection() throws SQLException {
155+
// Setup mocks with exception
156+
DataSource dataSource = mock(DataSource.class);
157+
when(dataSource.getConnection()).thenThrow(new SQLException("Connection failed"));
158+
159+
// Test with dialect from datasource - should default to PostgreSQL
160+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().dataSource(dataSource).build();
161+
162+
assertThat(repository).isNotNull();
163+
}
164+
165+
@Test
166+
void testBuilderWithNullDataSource() {
167+
assertThatThrownBy(() -> JdbcChatMemoryRepository.builder().build())
168+
.isInstanceOf(IllegalArgumentException.class)
169+
.hasMessage("DataSource must be set (either via dataSource() or jdbcTemplate())");
170+
}
171+
172+
@Test
173+
void testBuilderWithNullDataSourceButExplicitDialect() {
174+
DataSource dataSource = mock(DataSource.class);
175+
JdbcChatMemoryRepositoryDialect dialect = mock(JdbcChatMemoryRepositoryDialect.class);
176+
177+
// Should work because dialect is explicitly set
178+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder()
179+
.dataSource(dataSource)
180+
.dialect(dialect)
181+
.build();
182+
183+
assertThat(repository).isNotNull();
184+
}
185+
186+
@Test
187+
void testBuilderWithNullDataSourceAndDialect() {
188+
assertThatThrownBy(() -> JdbcChatMemoryRepository.builder().build())
189+
.isInstanceOf(IllegalArgumentException.class)
190+
.hasMessage("DataSource must be set (either via dataSource() or jdbcTemplate())");
191+
}
192+
193+
/**
194+
* Verifies that when an explicit dialect is provided to the builder, it takes
195+
* precedence over any dialect detected from the DataSource. If the explicit dialect
196+
* differs from the detected one, the explicit dialect is used and a warning is
197+
* logged. This ensures that user intent (explicit configuration) always overrides
198+
* automatic detection.
199+
*/
200+
@Test
201+
void testBuilderPreferenceForExplicitDialect() throws SQLException {
202+
// Setup mocks for PostgreSQL
203+
DataSource dataSource = mock(DataSource.class);
204+
Connection connection = mock(Connection.class);
205+
DatabaseMetaData metaData = mock(DatabaseMetaData.class);
206+
207+
when(dataSource.getConnection()).thenReturn(connection);
208+
when(connection.getMetaData()).thenReturn(metaData);
209+
when(metaData.getURL()).thenReturn("jdbc:postgresql://localhost:5432/testdb");
210+
211+
// Create an explicit MySQL dialect
212+
JdbcChatMemoryRepositoryDialect mysqlDialect = new MysqlChatMemoryRepositoryDialect();
213+
214+
// Test with explicit dialect - should use MySQL dialect even though PostgreSQL is
215+
// detected
216+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder()
217+
.dataSource(dataSource)
218+
.dialect(mysqlDialect)
219+
.build();
220+
221+
assertThat(repository).isNotNull();
222+
// Verify warning was logged (would need to use a logging framework test utility
223+
// for this)
224+
}
225+
226+
}

0 commit comments

Comments
 (0)