1
1
use num_traits:: Zero ;
2
2
use thiserror:: Error ;
3
3
4
- #[ derive( Debug , Error ) ]
4
+ #[ derive( Debug , Error , PartialEq ) ]
5
5
pub enum ParseVectorError {
6
6
#[ error( "The input string is empty." ) ]
7
7
EmptyString { } ,
@@ -89,12 +89,15 @@ where
89
89
90
90
#[ derive( PartialEq ) ]
91
91
enum ParseState {
92
- Number ,
92
+ Start ,
93
+ Index ,
94
+ Value ,
93
95
Comma ,
94
96
Colon ,
95
- Start ,
97
+ End ,
96
98
}
97
99
100
+ // Index -> Colon -> Value -> Comma
98
101
#[ inline( always) ]
99
102
pub fn parse_pgvector_svector < T : Zero + Clone , F > (
100
103
input : & [ u8 ] ,
@@ -157,24 +160,59 @@ where
157
160
let mut indexes = Vec :: < u32 > :: new ( ) ;
158
161
let mut values = Vec :: < T > :: new ( ) ;
159
162
let mut index: u32 = u32:: MAX ;
160
- let mut state = ParseState :: Start ;
161
163
162
- for position in left + 1 ..right {
163
- let c = input[ position] ;
164
- match c {
165
- b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' => {
166
- if token. is_empty ( ) {
167
- token. push ( b'$' ) ;
168
- }
169
- if token. try_push ( c) . is_err ( ) {
170
- return Err ( ParseVectorError :: TooLongNumber { position } ) ;
164
+ let mut state = ParseState :: Start ;
165
+ let mut position = left;
166
+ loop {
167
+ if position == right {
168
+ let end_with_number = state == ParseState :: Value && !token. is_empty ( ) ;
169
+ let end_with_comma = state == ParseState :: Index && token. is_empty ( ) ;
170
+ if end_with_number || end_with_comma {
171
+ state = ParseState :: End ;
172
+ } else {
173
+ return Err ( ParseVectorError :: BadCharacter { position } ) ;
174
+ }
175
+ }
176
+ match state {
177
+ ParseState :: Index => {
178
+ let c = input[ position] ;
179
+ match c {
180
+ b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' => {
181
+ if token. is_empty ( ) {
182
+ token. push ( b'$' ) ;
183
+ }
184
+ if token. try_push ( c) . is_err ( ) {
185
+ return Err ( ParseVectorError :: TooLongNumber { position } ) ;
186
+ }
187
+ position += 1 ;
188
+ }
189
+ b':' => {
190
+ state = ParseState :: Colon ;
191
+ }
192
+ b' ' => position += 1 ,
193
+ _ => return Err ( ParseVectorError :: BadCharacter { position } ) ,
171
194
}
172
- state = ParseState :: Number ;
173
195
}
174
- b',' => {
175
- if state != ParseState :: Number {
176
- return Err ( ParseVectorError :: BadCharacter { position } ) ;
196
+ ParseState :: Value => {
197
+ let c = input[ position] ;
198
+ match c {
199
+ b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' => {
200
+ if token. is_empty ( ) {
201
+ token. push ( b'$' ) ;
202
+ }
203
+ if token. try_push ( c) . is_err ( ) {
204
+ return Err ( ParseVectorError :: TooLongNumber { position } ) ;
205
+ }
206
+ position += 1 ;
207
+ }
208
+ b',' => {
209
+ state = ParseState :: Comma ;
210
+ }
211
+ b' ' => position += 1 ,
212
+ _ => return Err ( ParseVectorError :: BadCharacter { position } ) ,
177
213
}
214
+ }
215
+ e @ ( ParseState :: Comma | ParseState :: End ) => {
178
216
if !token. is_empty ( ) {
179
217
// Safety: all bytes in `token` are ascii characters
180
218
let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
@@ -191,60 +229,44 @@ where
191
229
}
192
230
index = u32:: MAX ;
193
231
token. clear ( ) ;
194
- state = ParseState :: Comma ;
195
- } else {
232
+ } else if e != ParseState :: End {
196
233
return Err ( ParseVectorError :: TooShortNumber { position } ) ;
197
234
}
235
+ if e == ParseState :: End {
236
+ break ;
237
+ } else {
238
+ state = ParseState :: Index ;
239
+ position += 1 ;
240
+ }
198
241
}
199
- b':' => {
242
+ ParseState :: Colon => {
200
243
if !token. is_empty ( ) {
201
244
// Safety: all bytes in `token` are ascii characters
202
245
let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
203
246
index = s
204
247
. parse :: < u32 > ( )
205
248
. map_err ( |_| ParseVectorError :: BadParsing { position } ) ?;
206
249
token. clear ( ) ;
207
- state = ParseState :: Colon ;
208
250
} else {
209
251
return Err ( ParseVectorError :: TooShortNumber { position } ) ;
210
252
}
253
+ state = ParseState :: Value ;
254
+ position += 1 ;
255
+ }
256
+ ParseState :: Start => {
257
+ state = ParseState :: Index ;
258
+ position += 1 ;
211
259
}
212
- b' ' => ( ) ,
213
- _ => return Err ( ParseVectorError :: BadCharacter { position } ) ,
214
- }
215
- }
216
- // A valid case is either
217
- // - empty string: ""
218
- // - end with number when a index is extracted:"1:2, 3:4"
219
- if state != ParseState :: Start && ( state != ParseState :: Number || index == u32:: MAX ) {
220
- return Err ( ParseVectorError :: BadCharacter { position : right } ) ;
221
- }
222
- if !token. is_empty ( ) {
223
- let position = right;
224
- // Safety: all bytes in `token` are ascii characters
225
- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
226
- let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
227
- if index as usize >= dims {
228
- return Err ( ParseVectorError :: OutOfBound {
229
- dims,
230
- index : index as usize ,
231
- } ) ;
232
- }
233
- if !num. is_zero ( ) {
234
- indexes. push ( index) ;
235
- values. push ( num) ;
236
260
}
237
- token. clear ( ) ;
238
261
}
239
- // sort values and indexes ascend by indexes
240
262
let mut indices = ( 0 ..indexes. len ( ) ) . collect :: < Vec < _ > > ( ) ;
241
263
indices. sort_by_key ( |& i| & indexes[ i] ) ;
242
- let sortedValues : Vec < T > = indices
264
+ let sorted_values : Vec < T > = indices
243
265
. iter ( )
244
266
. map ( |i| values. get ( * i) . unwrap ( ) . clone ( ) )
245
267
. collect ( ) ;
246
268
indexes. sort ( ) ;
247
- Ok ( ( indexes, sortedValues , dims) )
269
+ Ok ( ( indexes, sorted_values , dims) )
248
270
}
249
271
250
272
#[ cfg( test) ]
@@ -260,6 +282,10 @@ mod tests {
260
282
let exprs: HashMap < & str , ( Vec < u32 > , Vec < F32 > , usize ) > = HashMap :: from ( [
261
283
( "{}/1" , ( vec ! [ ] , vec ! [ ] , 1 ) ) ,
262
284
( "{0:1}/1" , ( vec ! [ 0 ] , vec ! [ F32 ( 1.0 ) ] , 1 ) ) ,
285
+ (
286
+ "{0:1, 1:-2, }/2" ,
287
+ ( vec ! [ 0 , 1 ] , vec ! [ F32 ( 1.0 ) , F32 ( -2.0 ) ] , 2 ) ,
288
+ ) ,
263
289
( "{0:1, 1:1.5}/2" , ( vec ! [ 0 , 1 ] , vec ! [ F32 ( 1.0 ) , F32 ( 1.5 ) ] , 2 ) ) ,
264
290
(
265
291
"{0:+3, 2:-4.1}/3" ,
@@ -280,27 +306,47 @@ mod tests {
280
306
281
307
#[ test]
282
308
fn test_svector_parse_reject ( ) {
283
- let exprs: Vec < & str > = vec ! [
284
- "{" ,
285
- "}" ,
286
- "{:" ,
287
- ":}" ,
288
- "{0:1, 1:1.5}/1" ,
289
- "{0:0, 1:0, 2:0}/2" ,
290
- "{0:1, 1:2, 2:3}" ,
291
- "{0:1, 1:2, 2:3" ,
292
- "{0:1, 1:2}/" ,
293
- "{0}/5" ,
294
- "{0:}/5" ,
295
- "{:0}/5" ,
296
- "{0:, 1:2}/5" ,
297
- "{0:1, 1}/5" ,
298
- "/2" ,
299
- "{}/1/2" ,
300
- ] ;
301
- for e in exprs {
309
+ let exprs: HashMap < & str , ParseVectorError > = HashMap :: from ( [
310
+ ( "{" , ParseVectorError :: BadParentheses { character : '{' } ) ,
311
+ ( "}" , ParseVectorError :: BadParentheses { character : '{' } ) ,
312
+ ( "{:" , ParseVectorError :: BadCharacter { position : 1 } ) ,
313
+ ( ":}" , ParseVectorError :: BadCharacter { position : 0 } ) ,
314
+ (
315
+ "{0:1, 1:1.5}/1" ,
316
+ ParseVectorError :: OutOfBound { dims : 1 , index : 1 } ,
317
+ ) ,
318
+ (
319
+ "{0:0, 1:0, 2:0}/2" ,
320
+ ParseVectorError :: OutOfBound { dims : 2 , index : 2 } ,
321
+ ) ,
322
+ (
323
+ "{0:1, 1:2, 2:3}" ,
324
+ ParseVectorError :: BadCharacter { position : 15 } ,
325
+ ) ,
326
+ (
327
+ "{0:1, 1:2, 2:3" ,
328
+ ParseVectorError :: BadCharacter { position : 12 } ,
329
+ ) ,
330
+ ( "{0:1, 1:2}/" , ParseVectorError :: BadParsing { position : 10 } ) ,
331
+ ( "{0}/5" , ParseVectorError :: BadCharacter { position : 2 } ) ,
332
+ ( "{0:}/5" , ParseVectorError :: BadCharacter { position : 3 } ) ,
333
+ ( "{:0}/5" , ParseVectorError :: TooShortNumber { position : 1 } ) ,
334
+ (
335
+ "{0:, 1:2}/5" ,
336
+ ParseVectorError :: TooShortNumber { position : 3 } ,
337
+ ) ,
338
+ ( "{0:1, 1}/5" , ParseVectorError :: BadCharacter { position : 7 } ) ,
339
+ ( "/2" , ParseVectorError :: BadCharacter { position : 0 } ) ,
340
+ ( "{}/1/2" , ParseVectorError :: BadCharacter { position : 2 } ) ,
341
+ (
342
+ "{1,2,3,4}/5" ,
343
+ ParseVectorError :: BadCharacter { position : 2 } ,
344
+ ) ,
345
+ ] ) ;
346
+ for ( e, err) in exprs {
302
347
let ret = parse_pgvector_svector ( e. as_bytes ( ) , |s| s. parse :: < F32 > ( ) . ok ( ) ) ;
303
- assert ! ( ret. is_err( ) , "at expr {e}" )
348
+ assert ! ( ret. is_err( ) , "at expr {e}" ) ;
349
+ assert_eq ! ( ret. unwrap_err( ) , err, "at expr {e}" ) ;
304
350
}
305
351
}
306
352
}
0 commit comments