@@ -15,6 +15,8 @@ pub enum ParseVectorError {
15
15
TooShortNumber { position : usize } ,
16
16
#[ error( "Bad parsing at position {position}" ) ]
17
17
BadParsing { position : usize } ,
18
+ #[ error( "Index out of bounds: the dim is {dims} but the index is {index}" ) ]
19
+ OutOfBound { dims : usize , index : usize } ,
18
20
}
19
21
20
22
#[ inline( always) ]
85
87
Ok ( vector)
86
88
}
87
89
90
/// Records which kind of token the sparse-vector parser consumed most recently.
///
/// The parser compares the current state against the expected one to reject
/// malformed input: a `,` is only accepted right after a number, and the
/// entry list must end either untouched (`Start`) or right after a number.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParseState {
    /// A byte was just appended to the pending numeric token.
    Number,
    /// A `,` separator was just processed, closing one entry.
    Comma,
    /// A `:` separator was just processed; the token before it was parsed as a `u32`.
    Colon,
    /// Initial state: no token has been consumed yet.
    Start,
}
97
+
88
98
#[ inline( always) ]
89
99
pub fn parse_pgvector_svector < T : Zero + Clone , F > (
90
100
input : & [ u8 ] ,
@@ -136,7 +146,9 @@ where
136
146
} ;
137
147
let mut indexes = Vec :: < u32 > :: new ( ) ;
138
148
let mut values = Vec :: < T > :: new ( ) ;
139
- let mut index: u32 = 0 ;
149
+ let mut index: u32 = u32:: MAX ;
150
+ let mut state = ParseState :: Start ;
151
+
140
152
for position in left + 1 ..right {
141
153
let c = input[ position] ;
142
154
match c {
@@ -147,15 +159,29 @@ where
147
159
if token. try_push ( c) . is_err ( ) {
148
160
return Err ( ParseVectorError :: TooLongNumber { position } ) ;
149
161
}
162
+ state = ParseState :: Number ;
150
163
}
151
164
b',' => {
165
+ if state != ParseState :: Number {
166
+ return Err ( ParseVectorError :: BadCharacter { position } ) ;
167
+ }
152
168
if !token. is_empty ( ) {
153
169
// Safety: all bytes in `token` are ascii characters
154
170
let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
155
171
let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
156
- indexes. push ( index) ;
157
- values. push ( num) ;
172
+ if index as usize >= dims {
173
+ return Err ( ParseVectorError :: OutOfBound {
174
+ dims,
175
+ index : index as usize ,
176
+ } ) ;
177
+ }
178
+ if !num. is_zero ( ) {
179
+ indexes. push ( index) ;
180
+ values. push ( num) ;
181
+ }
182
+ index = u32:: MAX ;
158
183
token. clear ( ) ;
184
+ state = ParseState :: Comma ;
159
185
} else {
160
186
return Err ( ParseVectorError :: TooShortNumber { position } ) ;
161
187
}
@@ -168,6 +194,7 @@ where
168
194
. parse :: < u32 > ( )
169
195
. map_err ( |_| ParseVectorError :: BadParsing { position } ) ?;
170
196
token. clear ( ) ;
197
+ state = ParseState :: Colon ;
171
198
} else {
172
199
return Err ( ParseVectorError :: TooShortNumber { position } ) ;
173
200
}
@@ -176,14 +203,90 @@ where
176
203
_ => return Err ( ParseVectorError :: BadCharacter { position } ) ,
177
204
}
178
205
}
206
+ if state != ParseState :: Start && ( state != ParseState :: Number || index == u32:: MAX ) {
207
+ return Err ( ParseVectorError :: BadCharacter { position : right } ) ;
208
+ }
179
209
if !token. is_empty ( ) {
180
210
let position = right;
181
211
// Safety: all bytes in `token` are ascii characters
182
212
let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
183
213
let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
184
- indexes. push ( index) ;
185
- values. push ( num) ;
214
+ if index as usize >= dims {
215
+ return Err ( ParseVectorError :: OutOfBound {
216
+ dims,
217
+ index : index as usize ,
218
+ } ) ;
219
+ }
220
+ if !num. is_zero ( ) {
221
+ indexes. push ( index) ;
222
+ values. push ( num) ;
223
+ }
186
224
token. clear ( ) ;
187
225
}
188
- Ok ( ( indexes, values, dims) )
226
+ // sort values and indexes ascend by indexes
227
+ let mut indices = ( 0 ..indexes. len ( ) ) . collect :: < Vec < _ > > ( ) ;
228
+ indices. sort_by_key ( |& i| & indexes[ i] ) ;
229
+ let sortedValues: Vec < T > = indices
230
+ . iter ( )
231
+ . map ( |i| values. get ( * i) . unwrap ( ) . clone ( ) )
232
+ . collect ( ) ;
233
+ indexes. sort ( ) ;
234
+ Ok ( ( indexes, sortedValues, dims) )
235
+ }
236
+
237
#[cfg(test)]
mod tests {
    use base::scalar::F32;

    use super::*;

    /// Well-formed sparse vector literals must parse into the expected
    /// `(indexes, values, dims)` triple.
    #[test]
    fn test_svector_parse_accept() {
        // An ordered list rather than a `HashMap`: cases run in the order they
        // are written, and two cases sharing the same input string cannot
        // silently collapse into a single entry.
        let exprs: Vec<(&str, (Vec<u32>, Vec<F32>, usize))> = vec![
            ("{}/1", (vec![], vec![], 1)),
            ("{0:1}/1", (vec![0], vec![F32(1.0)], 1)),
            ("{0:1, 1:1.5}/2", (vec![0, 1], vec![F32(1.0), F32(1.5)], 2)),
            (
                "{0:+3, 2:-4.1}/3",
                (vec![0, 2], vec![F32(3.0), F32(-4.1)], 3),
            ),
            // Zero values are not stored.
            ("{0:0, 1:0, 2:0}/3", (vec![], vec![], 3)),
            // Entries given out of order come back sorted by ascending index.
            (
                "{3:3, 2:2, 1:1, 0:0}/4",
                (vec![1, 2, 3], vec![F32(1.0), F32(2.0), F32(3.0)], 4),
            ),
        ];
        for (e, ans) in exprs {
            let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::<F32>().ok());
            assert!(ret.is_ok(), "at expr {e}");
            assert_eq!(ret.unwrap(), ans, "at expr {e}");
        }
    }

    /// Malformed sparse vector literals must be rejected with an error.
    #[test]
    fn test_svector_parse_reject() {
        let exprs: Vec<&str> = vec![
            "{",
            "}",
            "{:",
            ":}",
            // Index 1 is not below the declared dimension 1.
            "{0:1, 1:1.5}/1",
            // Index 2 is not below the declared dimension 2, even though its value is zero.
            "{0:0, 1:0, 2:0}/2",
            "{0:1, 1:2, 2:3}",
            "{0:1, 1:2, 2:3",
            "{0:1, 1:2}/",
            "{0}/5",
            "{0:}/5",
            "{:0}/5",
            "{0:, 1:2}/5",
            "{0:1, 1}/5",
            "/2",
        ];
        for e in exprs {
            let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::<F32>().ok());
            assert!(ret.is_err(), "at expr {e}");
        }
    }
}
0 commit comments