Skip to content

Commit ed6a464

Browse files
committed
Add options support to PostgresMlEmbeddingClient
- Add postgremaddmbedding adoc page. - Auto-configuration: - add missing boot-starter. - refactor autoconf class and properties to accomodate the PostgresMlEmbeddingOptions. - PostgesMlEmbeddingClient - Add the (default) options field and remove old fields. - Implement default and request options merging. - Add tests for options and merging. - Remove redundant code. - Code style fixes.
1 parent 7b58f42 commit ed6a464

File tree

16 files changed

+757
-253
lines changed

16 files changed

+757
-253
lines changed

models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java

Lines changed: 110 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
116
package org.springframework.ai.postgresml;
217

318
import java.sql.Array;
@@ -7,9 +22,6 @@
722
import java.util.List;
823
import java.util.Map;
924

10-
import com.fasterxml.jackson.core.JsonProcessingException;
11-
import com.fasterxml.jackson.databind.ObjectMapper;
12-
1325
import org.springframework.ai.document.Document;
1426
import org.springframework.ai.document.MetadataMode;
1527
import org.springframework.ai.embedding.AbstractEmbeddingClient;
@@ -18,6 +30,7 @@
1830
import org.springframework.ai.embedding.EmbeddingRequest;
1931
import org.springframework.ai.embedding.EmbeddingResponse;
2032
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
33+
import org.springframework.ai.model.ModelOptionsUtils;
2134
import org.springframework.beans.factory.InitializingBean;
2235
import org.springframework.jdbc.core.JdbcTemplate;
2336
import org.springframework.jdbc.core.RowMapper;
@@ -29,25 +42,24 @@
2942
* <a href="https://postgresml.org">PostgresML</a> EmbeddingClient
3043
*
3144
* @author Toshiaki Maki
45+
* @author Christian Tzolov
3246
*/
3347
public class PostgresMlEmbeddingClient extends AbstractEmbeddingClient implements InitializingBean {
3448

35-
private final JdbcTemplate jdbcTemplate;
49+
public static final String DEFAULT_TRANSFORMER_MODEL = "distilbert-base-uncased";
3650

37-
private final String transformer;
51+
private final PostgresMlEmbeddingOptions defaultOptions;
3852

39-
private final VectorType vectorType;
40-
41-
private final String kwargs;
42-
43-
private final MetadataMode metadataMode;
53+
private final JdbcTemplate jdbcTemplate;
4454

4555
public enum VectorType {
4656

4757
PG_ARRAY("", null, (rs, i) -> {
4858
Array embedding = rs.getArray("embedding");
4959
return Arrays.stream((Float[]) embedding.getArray()).map(Float::doubleValue).toList();
50-
}), PG_VECTOR("::vector", "vector", (rs, i) -> {
60+
}),
61+
62+
PG_VECTOR("::vector", "vector", (rs, i) -> {
5163
String embedding = rs.getString("embedding");
5264
return Arrays.stream((embedding.substring(1, embedding.length() - 1)
5365
/* remove leading '[' and trailing ']' */.split(","))).map(Double::parseDouble).toList();
@@ -72,35 +84,57 @@ public enum VectorType {
7284
* @param jdbcTemplate JdbcTemplate
7385
*/
7486
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate) {
75-
this(jdbcTemplate, "distilbert-base-uncased");
87+
this(jdbcTemplate, PostgresMlEmbeddingOptions.builder().build());
88+
}
89+
90+
/**
91+
* a PostgresMlEmbeddingClient constructor
92+
* @param jdbcTemplate JdbcTemplate to use to interact with the database.
93+
* @param options PostgresMlEmbeddingOptions to configure the client.
94+
*/
95+
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, PostgresMlEmbeddingOptions options) {
96+
Assert.notNull(jdbcTemplate, "jdbc template must not be null.");
97+
Assert.notNull(options, "options must not be null.");
98+
Assert.notNull(options.getTransformer(), "transformer must not be null.");
99+
Assert.notNull(options.getVectorType(), "vectorType must not be null.");
100+
Assert.notNull(options.getKwargs(), "kwargs must not be null.");
101+
Assert.notNull(options.getMetadataMode(), "metadataMode must not be null.");
102+
103+
this.jdbcTemplate = jdbcTemplate;
104+
this.defaultOptions = options;
76105
}
77106

78107
/**
79108
* a constructor
80109
* @param jdbcTemplate JdbcTemplate
81110
* @param transformer huggingface sentence-transformer name
82111
*/
112+
@Deprecated(since = "0.8.0", forRemoval = true)
83113
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer) {
84114
this(jdbcTemplate, transformer, VectorType.PG_ARRAY);
85115
}
86116

87117
/**
88118
* a constructor
119+
* @deprecated Use the constructor with {@link PostgresMlEmbeddingOptions} instead.
89120
* @param jdbcTemplate JdbcTemplate
90121
* @param transformer huggingface sentence-transformer name
91122
* @param vectorType vector type in PostgreSQL
92123
*/
124+
@Deprecated(since = "0.8.0", forRemoval = true)
93125
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType) {
94126
this(jdbcTemplate, transformer, vectorType, Map.of(), MetadataMode.EMBED);
95127
}
96128

97129
/**
98-
* a constructor
130+
* a constructor * @deprecated Use the constructor with
131+
* {@link PostgresMlEmbeddingOptions} instead.
99132
* @param jdbcTemplate JdbcTemplate
100133
* @param transformer huggingface sentence-transformer name
101134
* @param vectorType vector type in PostgreSQL
102135
* @param kwargs optional arguments
103136
*/
137+
@Deprecated(since = "0.8.0", forRemoval = true)
104138
public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer, VectorType vectorType,
105139
Map<String, Object> kwargs, MetadataMode metadataMode) {
106140
Assert.notNull(jdbcTemplate, "jdbc template must not be null.");
@@ -110,73 +144,93 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer,
110144
Assert.notNull(metadataMode, "metadataMode must not be null.");
111145

112146
this.jdbcTemplate = jdbcTemplate;
113-
this.transformer = transformer;
114-
this.vectorType = vectorType;
115-
this.metadataMode = metadataMode;
116-
try {
117-
this.kwargs = new ObjectMapper().writeValueAsString(kwargs);
118-
}
119-
catch (JsonProcessingException e) {
120-
throw new IllegalArgumentException(e);
121-
}
147+
148+
this.defaultOptions = PostgresMlEmbeddingOptions.builder()
149+
.withTransformer(transformer)
150+
.withVectorType(vectorType)
151+
.withMetadataMode(metadataMode)
152+
.withKwargs(ModelOptionsUtils.toJsonString(kwargs))
153+
.build();
122154
}
123155

156+
@SuppressWarnings("null")
124157
@Override
125158
public List<Double> embed(String text) {
126159
return this.jdbcTemplate.queryForObject(
127-
"SELECT pgml.embed(?, ?, ?::JSONB)" + this.vectorType.cast + " AS embedding", this.vectorType.rowMapper,
128-
this.transformer, text, this.kwargs);
160+
"SELECT pgml.embed(?, ?, ?::JSONB)" + this.defaultOptions.getVectorType().cast + " AS embedding",
161+
this.defaultOptions.getVectorType().rowMapper, this.defaultOptions.getTransformer(), text,
162+
this.defaultOptions.getKwargs());
129163
}
130164

131165
@Override
132166
public List<Double> embed(Document document) {
133-
return this.embed(document.getFormattedContent(this.metadataMode));
167+
return this.embed(document.getFormattedContent(this.defaultOptions.getMetadataMode()));
134168
}
135169

170+
@SuppressWarnings("null")
136171
@Override
137-
public List<List<Double>> embed(List<String> texts) {
138-
if (CollectionUtils.isEmpty(texts)) {
139-
return List.of();
140-
}
141-
return this.jdbcTemplate.query(connection -> {
142-
PreparedStatement preparedStatement = connection.prepareStatement("SELECT pgml.embed(?, text, ?::JSONB)"
143-
+ vectorType.cast + " AS embedding FROM (SELECT unnest(?) AS text) AS texts");
144-
preparedStatement.setString(1, transformer);
145-
preparedStatement.setString(2, kwargs);
146-
preparedStatement.setArray(3, connection.createArrayOf("TEXT", texts.toArray(Object[]::new)));
147-
return preparedStatement;
148-
}, rs -> {
149-
List<List<Double>> result = new ArrayList<>();
150-
while (rs.next()) {
151-
result.add(vectorType.rowMapper.mapRow(rs, -1));
152-
}
153-
return result;
154-
});
155-
}
172+
public EmbeddingResponse call(EmbeddingRequest request) {
156173

157-
@Override
158-
public EmbeddingResponse embedForResponse(List<String> texts) {
159-
return this.call(new EmbeddingRequest(texts, EmbeddingOptions.EMPTY));
160-
}
174+
final PostgresMlEmbeddingOptions optionsToUse = this.mergeOptions(request.getOptions());
161175

162-
@Override
163-
public EmbeddingResponse call(EmbeddingRequest request) {
164176
List<Embedding> data = new ArrayList<>();
165-
List<List<Double>> embed = this.embed(request.getInstructions());
166-
for (int i = 0; i < embed.size(); i++) {
167-
data.add(new Embedding(embed.get(i), i));
177+
List<List<Double>> embed = List.of();
178+
179+
List<String> texts = request.getInstructions();
180+
if (!CollectionUtils.isEmpty(texts)) {
181+
embed = this.jdbcTemplate.query(connection -> {
182+
PreparedStatement preparedStatement = connection.prepareStatement("SELECT pgml.embed(?, text, ?::JSONB)"
183+
+ optionsToUse.getVectorType().cast + " AS embedding FROM (SELECT unnest(?) AS text) AS texts");
184+
preparedStatement.setString(1, optionsToUse.getTransformer());
185+
preparedStatement.setString(2, ModelOptionsUtils.toJsonString(optionsToUse.getKwargs()));
186+
preparedStatement.setArray(3, connection.createArrayOf("TEXT", texts.toArray(Object[]::new)));
187+
return preparedStatement;
188+
}, rs -> {
189+
List<List<Double>> result = new ArrayList<>();
190+
while (rs.next()) {
191+
result.add(optionsToUse.getVectorType().rowMapper.mapRow(rs, -1));
192+
}
193+
return result;
194+
});
168195
}
196+
197+
if (!CollectionUtils.isEmpty(embed)) {
198+
for (int i = 0; i < embed.size(); i++) {
199+
data.add(new Embedding(embed.get(i), i));
200+
}
201+
}
202+
169203
var metadata = new EmbeddingResponseMetadata(
170-
Map.of("transformer", this.transformer, "vector-type", this.vectorType.name(), "kwargs", this.kwargs));
204+
Map.of("transformer", optionsToUse.getTransformer(), "vector-type", optionsToUse.getVectorType().name(),
205+
"kwargs", ModelOptionsUtils.toJsonString(optionsToUse.getKwargs())));
171206

172207
return new EmbeddingResponse(data, metadata);
173208
}
174209

210+
/**
211+
* Merge the default and request options.
212+
* @param requestOptions request options to merge.
213+
* @return the merged options.
214+
*/
215+
PostgresMlEmbeddingOptions mergeOptions(EmbeddingOptions requestOptions) {
216+
217+
PostgresMlEmbeddingOptions options = (this.defaultOptions != null) ? this.defaultOptions
218+
: PostgresMlEmbeddingOptions.builder().build();
219+
220+
if (requestOptions != null && !EmbeddingOptions.EMPTY.equals(requestOptions)) {
221+
options = ModelOptionsUtils.merge(requestOptions, options, PostgresMlEmbeddingOptions.class);
222+
}
223+
224+
return options;
225+
}
226+
175227
@Override
176228
public void afterPropertiesSet() {
177229
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS pgml");
178-
if (StringUtils.hasText(this.vectorType.extensionName)) {
179-
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS " + this.vectorType.extensionName);
230+
this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
231+
if (StringUtils.hasText(this.defaultOptions.getVectorType().extensionName)) {
232+
this.jdbcTemplate
233+
.execute("CREATE EXTENSION IF NOT EXISTS " + this.defaultOptions.getVectorType().extensionName);
180234
}
181235
}
182236

0 commit comments

Comments
 (0)