1
+ /*
2
+ * Copyright 2023-2024 the original author or authors.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * https://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
1
16
package org .springframework .ai .postgresml ;
2
17
3
18
import java .sql .Array ;
7
22
import java .util .List ;
8
23
import java .util .Map ;
9
24
10
- import com .fasterxml .jackson .core .JsonProcessingException ;
11
- import com .fasterxml .jackson .databind .ObjectMapper ;
12
-
13
25
import org .springframework .ai .document .Document ;
14
26
import org .springframework .ai .document .MetadataMode ;
15
27
import org .springframework .ai .embedding .AbstractEmbeddingClient ;
18
30
import org .springframework .ai .embedding .EmbeddingRequest ;
19
31
import org .springframework .ai .embedding .EmbeddingResponse ;
20
32
import org .springframework .ai .embedding .EmbeddingResponseMetadata ;
33
+ import org .springframework .ai .model .ModelOptionsUtils ;
21
34
import org .springframework .beans .factory .InitializingBean ;
22
35
import org .springframework .jdbc .core .JdbcTemplate ;
23
36
import org .springframework .jdbc .core .RowMapper ;
29
42
* <a href="https://postgresml.org">PostgresML</a> EmbeddingClient
30
43
*
31
44
* @author Toshiaki Maki
45
+ * @author Christian Tzolov
32
46
*/
33
47
public class PostgresMlEmbeddingClient extends AbstractEmbeddingClient implements InitializingBean {
34
48
35
- private final JdbcTemplate jdbcTemplate ;
49
+ public static final String DEFAULT_TRANSFORMER_MODEL = "distilbert-base-uncased" ;
36
50
37
- private final String transformer ;
51
+ private final PostgresMlEmbeddingOptions defaultOptions ;
38
52
39
- private final VectorType vectorType ;
40
-
41
- private final String kwargs ;
42
-
43
- private final MetadataMode metadataMode ;
53
+ private final JdbcTemplate jdbcTemplate ;
44
54
45
55
public enum VectorType {
46
56
47
57
PG_ARRAY ("" , null , (rs , i ) -> {
48
58
Array embedding = rs .getArray ("embedding" );
49
59
return Arrays .stream ((Float []) embedding .getArray ()).map (Float ::doubleValue ).toList ();
50
- }), PG_VECTOR ("::vector" , "vector" , (rs , i ) -> {
60
+ }),
61
+
62
+ PG_VECTOR ("::vector" , "vector" , (rs , i ) -> {
51
63
String embedding = rs .getString ("embedding" );
52
64
return Arrays .stream ((embedding .substring (1 , embedding .length () - 1 )
53
65
/* remove leading '[' and trailing ']' */ .split ("," ))).map (Double ::parseDouble ).toList ();
@@ -72,35 +84,57 @@ public enum VectorType {
72
84
* @param jdbcTemplate JdbcTemplate
73
85
*/
74
86
public PostgresMlEmbeddingClient (JdbcTemplate jdbcTemplate ) {
75
- this (jdbcTemplate , "distilbert-base-uncased" );
87
+ this (jdbcTemplate , PostgresMlEmbeddingOptions .builder ().build ());
88
+ }
89
+
90
+ /**
91
+ * a PostgresMlEmbeddingClient constructor
92
+ * @param jdbcTemplate JdbcTemplate to use to interact with the database.
93
+ * @param options PostgresMlEmbeddingOptions to configure the client.
94
+ */
95
+ public PostgresMlEmbeddingClient (JdbcTemplate jdbcTemplate , PostgresMlEmbeddingOptions options ) {
96
+ Assert .notNull (jdbcTemplate , "jdbc template must not be null." );
97
+ Assert .notNull (options , "options must not be null." );
98
+ Assert .notNull (options .getTransformer (), "transformer must not be null." );
99
+ Assert .notNull (options .getVectorType (), "vectorType must not be null." );
100
+ Assert .notNull (options .getKwargs (), "kwargs must not be null." );
101
+ Assert .notNull (options .getMetadataMode (), "metadataMode must not be null." );
102
+
103
+ this .jdbcTemplate = jdbcTemplate ;
104
+ this .defaultOptions = options ;
76
105
}
77
106
78
107
/**
79
108
* a constructor
80
109
* @param jdbcTemplate JdbcTemplate
81
110
* @param transformer huggingface sentence-transformer name
82
111
*/
112
+ @ Deprecated (since = "0.8.0" , forRemoval = true )
83
113
public PostgresMlEmbeddingClient (JdbcTemplate jdbcTemplate , String transformer ) {
84
114
this (jdbcTemplate , transformer , VectorType .PG_ARRAY );
85
115
}
86
116
87
117
/**
88
118
* a constructor
119
+ * @deprecated Use the constructor with {@link PostgresMlEmbeddingOptions} instead.
89
120
* @param jdbcTemplate JdbcTemplate
90
121
* @param transformer huggingface sentence-transformer name
91
122
* @param vectorType vector type in PostgreSQL
92
123
*/
124
+ @ Deprecated (since = "0.8.0" , forRemoval = true )
93
125
public PostgresMlEmbeddingClient (JdbcTemplate jdbcTemplate , String transformer , VectorType vectorType ) {
94
126
this (jdbcTemplate , transformer , vectorType , Map .of (), MetadataMode .EMBED );
95
127
}
96
128
97
129
/**
98
- * a constructor
130
+ * a constructor * @deprecated Use the constructor with
131
+ * {@link PostgresMlEmbeddingOptions} instead.
99
132
* @param jdbcTemplate JdbcTemplate
100
133
* @param transformer huggingface sentence-transformer name
101
134
* @param vectorType vector type in PostgreSQL
102
135
* @param kwargs optional arguments
103
136
*/
137
+ @ Deprecated (since = "0.8.0" , forRemoval = true )
104
138
public PostgresMlEmbeddingClient (JdbcTemplate jdbcTemplate , String transformer , VectorType vectorType ,
105
139
Map <String , Object > kwargs , MetadataMode metadataMode ) {
106
140
Assert .notNull (jdbcTemplate , "jdbc template must not be null." );
@@ -110,73 +144,93 @@ public PostgresMlEmbeddingClient(JdbcTemplate jdbcTemplate, String transformer,
110
144
Assert .notNull (metadataMode , "metadataMode must not be null." );
111
145
112
146
this .jdbcTemplate = jdbcTemplate ;
113
- this .transformer = transformer ;
114
- this .vectorType = vectorType ;
115
- this .metadataMode = metadataMode ;
116
- try {
117
- this .kwargs = new ObjectMapper ().writeValueAsString (kwargs );
118
- }
119
- catch (JsonProcessingException e ) {
120
- throw new IllegalArgumentException (e );
121
- }
147
+
148
+ this .defaultOptions = PostgresMlEmbeddingOptions .builder ()
149
+ .withTransformer (transformer )
150
+ .withVectorType (vectorType )
151
+ .withMetadataMode (metadataMode )
152
+ .withKwargs (ModelOptionsUtils .toJsonString (kwargs ))
153
+ .build ();
122
154
}
123
155
156
+ @ SuppressWarnings ("null" )
124
157
@ Override
125
158
public List <Double > embed (String text ) {
126
159
return this .jdbcTemplate .queryForObject (
127
- "SELECT pgml.embed(?, ?, ?::JSONB)" + this .vectorType .cast + " AS embedding" , this .vectorType .rowMapper ,
128
- this .transformer , text , this .kwargs );
160
+ "SELECT pgml.embed(?, ?, ?::JSONB)" + this .defaultOptions .getVectorType ().cast + " AS embedding" ,
161
+ this .defaultOptions .getVectorType ().rowMapper , this .defaultOptions .getTransformer (), text ,
162
+ this .defaultOptions .getKwargs ());
129
163
}
130
164
131
165
@ Override
132
166
public List <Double > embed (Document document ) {
133
- return this .embed (document .getFormattedContent (this .metadataMode ));
167
+ return this .embed (document .getFormattedContent (this .defaultOptions . getMetadataMode () ));
134
168
}
135
169
170
+ @ SuppressWarnings ("null" )
136
171
@ Override
137
- public List <List <Double >> embed (List <String > texts ) {
138
- if (CollectionUtils .isEmpty (texts )) {
139
- return List .of ();
140
- }
141
- return this .jdbcTemplate .query (connection -> {
142
- PreparedStatement preparedStatement = connection .prepareStatement ("SELECT pgml.embed(?, text, ?::JSONB)"
143
- + vectorType .cast + " AS embedding FROM (SELECT unnest(?) AS text) AS texts" );
144
- preparedStatement .setString (1 , transformer );
145
- preparedStatement .setString (2 , kwargs );
146
- preparedStatement .setArray (3 , connection .createArrayOf ("TEXT" , texts .toArray (Object []::new )));
147
- return preparedStatement ;
148
- }, rs -> {
149
- List <List <Double >> result = new ArrayList <>();
150
- while (rs .next ()) {
151
- result .add (vectorType .rowMapper .mapRow (rs , -1 ));
152
- }
153
- return result ;
154
- });
155
- }
172
+ public EmbeddingResponse call (EmbeddingRequest request ) {
156
173
157
- @ Override
158
- public EmbeddingResponse embedForResponse (List <String > texts ) {
159
- return this .call (new EmbeddingRequest (texts , EmbeddingOptions .EMPTY ));
160
- }
174
+ final PostgresMlEmbeddingOptions optionsToUse = this .mergeOptions (request .getOptions ());
161
175
162
- @ Override
163
- public EmbeddingResponse call (EmbeddingRequest request ) {
164
176
List <Embedding > data = new ArrayList <>();
165
- List <List <Double >> embed = this .embed (request .getInstructions ());
166
- for (int i = 0 ; i < embed .size (); i ++) {
167
- data .add (new Embedding (embed .get (i ), i ));
177
+ List <List <Double >> embed = List .of ();
178
+
179
+ List <String > texts = request .getInstructions ();
180
+ if (!CollectionUtils .isEmpty (texts )) {
181
+ embed = this .jdbcTemplate .query (connection -> {
182
+ PreparedStatement preparedStatement = connection .prepareStatement ("SELECT pgml.embed(?, text, ?::JSONB)"
183
+ + optionsToUse .getVectorType ().cast + " AS embedding FROM (SELECT unnest(?) AS text) AS texts" );
184
+ preparedStatement .setString (1 , optionsToUse .getTransformer ());
185
+ preparedStatement .setString (2 , ModelOptionsUtils .toJsonString (optionsToUse .getKwargs ()));
186
+ preparedStatement .setArray (3 , connection .createArrayOf ("TEXT" , texts .toArray (Object []::new )));
187
+ return preparedStatement ;
188
+ }, rs -> {
189
+ List <List <Double >> result = new ArrayList <>();
190
+ while (rs .next ()) {
191
+ result .add (optionsToUse .getVectorType ().rowMapper .mapRow (rs , -1 ));
192
+ }
193
+ return result ;
194
+ });
168
195
}
196
+
197
+ if (!CollectionUtils .isEmpty (embed )) {
198
+ for (int i = 0 ; i < embed .size (); i ++) {
199
+ data .add (new Embedding (embed .get (i ), i ));
200
+ }
201
+ }
202
+
169
203
var metadata = new EmbeddingResponseMetadata (
170
- Map .of ("transformer" , this .transformer , "vector-type" , this .vectorType .name (), "kwargs" , this .kwargs ));
204
+ Map .of ("transformer" , optionsToUse .getTransformer (), "vector-type" , optionsToUse .getVectorType ().name (),
205
+ "kwargs" , ModelOptionsUtils .toJsonString (optionsToUse .getKwargs ())));
171
206
172
207
return new EmbeddingResponse (data , metadata );
173
208
}
174
209
210
+ /**
211
+ * Merge the default and request options.
212
+ * @param requestOptions request options to merge.
213
+ * @return the merged options.
214
+ */
215
+ PostgresMlEmbeddingOptions mergeOptions (EmbeddingOptions requestOptions ) {
216
+
217
+ PostgresMlEmbeddingOptions options = (this .defaultOptions != null ) ? this .defaultOptions
218
+ : PostgresMlEmbeddingOptions .builder ().build ();
219
+
220
+ if (requestOptions != null && !EmbeddingOptions .EMPTY .equals (requestOptions )) {
221
+ options = ModelOptionsUtils .merge (requestOptions , options , PostgresMlEmbeddingOptions .class );
222
+ }
223
+
224
+ return options ;
225
+ }
226
+
175
227
@ Override
176
228
public void afterPropertiesSet () {
177
229
this .jdbcTemplate .execute ("CREATE EXTENSION IF NOT EXISTS pgml" );
178
- if (StringUtils .hasText (this .vectorType .extensionName )) {
179
- this .jdbcTemplate .execute ("CREATE EXTENSION IF NOT EXISTS " + this .vectorType .extensionName );
230
+ this .jdbcTemplate .execute ("CREATE EXTENSION IF NOT EXISTS hstore" );
231
+ if (StringUtils .hasText (this .defaultOptions .getVectorType ().extensionName )) {
232
+ this .jdbcTemplate
233
+ .execute ("CREATE EXTENSION IF NOT EXISTS " + this .defaultOptions .getVectorType ().extensionName );
180
234
}
181
235
}
182
236
0 commit comments