Skip to content

Commit 45ab3c4

Browse files
committed
refactor: use state machine
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
1 parent 6da2c90 commit 45ab3c4

File tree

1 file changed

+115
-69
lines changed

1 file changed

+115
-69
lines changed

src/utils/parse.rs

Lines changed: 115 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use num_traits::Zero;
22
use thiserror::Error;
33

4-
#[derive(Debug, Error)]
4+
#[derive(Debug, Error, PartialEq)]
55
pub enum ParseVectorError {
66
#[error("The input string is empty.")]
77
EmptyString {},
@@ -89,12 +89,15 @@ where
8989

9090
#[derive(PartialEq)]
9191
enum ParseState {
92-
Number,
92+
Start,
93+
Index,
94+
Value,
9395
Comma,
9496
Colon,
95-
Start,
97+
End,
9698
}
9799

100+
// Index -> Colon -> Value -> Comma
98101
#[inline(always)]
99102
pub fn parse_pgvector_svector<T: Zero + Clone, F>(
100103
input: &[u8],
@@ -157,24 +160,59 @@ where
157160
let mut indexes = Vec::<u32>::new();
158161
let mut values = Vec::<T>::new();
159162
let mut index: u32 = u32::MAX;
160-
let mut state = ParseState::Start;
161163

162-
for position in left + 1..right {
163-
let c = input[position];
164-
match c {
165-
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => {
166-
if token.is_empty() {
167-
token.push(b'$');
168-
}
169-
if token.try_push(c).is_err() {
170-
return Err(ParseVectorError::TooLongNumber { position });
164+
let mut state = ParseState::Start;
165+
let mut position = left;
166+
loop {
167+
if position == right {
168+
let end_with_number = state == ParseState::Value && !token.is_empty();
169+
let end_with_comma = state == ParseState::Index && token.is_empty();
170+
if end_with_number || end_with_comma {
171+
state = ParseState::End;
172+
} else {
173+
return Err(ParseVectorError::BadCharacter { position });
174+
}
175+
}
176+
match state {
177+
ParseState::Index => {
178+
let c = input[position];
179+
match c {
180+
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => {
181+
if token.is_empty() {
182+
token.push(b'$');
183+
}
184+
if token.try_push(c).is_err() {
185+
return Err(ParseVectorError::TooLongNumber { position });
186+
}
187+
position += 1;
188+
}
189+
b':' => {
190+
state = ParseState::Colon;
191+
}
192+
b' ' => position += 1,
193+
_ => return Err(ParseVectorError::BadCharacter { position }),
171194
}
172-
state = ParseState::Number;
173195
}
174-
b',' => {
175-
if state != ParseState::Number {
176-
return Err(ParseVectorError::BadCharacter { position });
196+
ParseState::Value => {
197+
let c = input[position];
198+
match c {
199+
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => {
200+
if token.is_empty() {
201+
token.push(b'$');
202+
}
203+
if token.try_push(c).is_err() {
204+
return Err(ParseVectorError::TooLongNumber { position });
205+
}
206+
position += 1;
207+
}
208+
b',' => {
209+
state = ParseState::Comma;
210+
}
211+
b' ' => position += 1,
212+
_ => return Err(ParseVectorError::BadCharacter { position }),
177213
}
214+
}
215+
e @ (ParseState::Comma | ParseState::End) => {
178216
if !token.is_empty() {
179217
// Safety: all bytes in `token` are ascii characters
180218
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
@@ -191,60 +229,44 @@ where
191229
}
192230
index = u32::MAX;
193231
token.clear();
194-
state = ParseState::Comma;
195-
} else {
232+
} else if e != ParseState::End {
196233
return Err(ParseVectorError::TooShortNumber { position });
197234
}
235+
if e == ParseState::End {
236+
break;
237+
} else {
238+
state = ParseState::Index;
239+
position += 1;
240+
}
198241
}
199-
b':' => {
242+
ParseState::Colon => {
200243
if !token.is_empty() {
201244
// Safety: all bytes in `token` are ascii characters
202245
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
203246
index = s
204247
.parse::<u32>()
205248
.map_err(|_| ParseVectorError::BadParsing { position })?;
206249
token.clear();
207-
state = ParseState::Colon;
208250
} else {
209251
return Err(ParseVectorError::TooShortNumber { position });
210252
}
253+
state = ParseState::Value;
254+
position += 1;
255+
}
256+
ParseState::Start => {
257+
state = ParseState::Index;
258+
position += 1;
211259
}
212-
b' ' => (),
213-
_ => return Err(ParseVectorError::BadCharacter { position }),
214-
}
215-
}
216-
// A valid case is either
217-
// - empty string: ""
218-
// - end with number when a index is extracted:"1:2, 3:4"
219-
if state != ParseState::Start && (state != ParseState::Number || index == u32::MAX) {
220-
return Err(ParseVectorError::BadCharacter { position: right });
221-
}
222-
if !token.is_empty() {
223-
let position = right;
224-
// Safety: all bytes in `token` are ascii characters
225-
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
226-
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
227-
if index as usize >= dims {
228-
return Err(ParseVectorError::OutOfBound {
229-
dims,
230-
index: index as usize,
231-
});
232-
}
233-
if !num.is_zero() {
234-
indexes.push(index);
235-
values.push(num);
236260
}
237-
token.clear();
238261
}
239-
// sort values and indexes ascend by indexes
240262
let mut indices = (0..indexes.len()).collect::<Vec<_>>();
241263
indices.sort_by_key(|&i| &indexes[i]);
242-
let sortedValues: Vec<T> = indices
264+
let sorted_values: Vec<T> = indices
243265
.iter()
244266
.map(|i| values.get(*i).unwrap().clone())
245267
.collect();
246268
indexes.sort();
247-
Ok((indexes, sortedValues, dims))
269+
Ok((indexes, sorted_values, dims))
248270
}
249271

250272
#[cfg(test)]
@@ -260,6 +282,10 @@ mod tests {
260282
let exprs: HashMap<&str, (Vec<u32>, Vec<F32>, usize)> = HashMap::from([
261283
("{}/1", (vec![], vec![], 1)),
262284
("{0:1}/1", (vec![0], vec![F32(1.0)], 1)),
285+
(
286+
"{0:1, 1:-2, }/2",
287+
(vec![0, 1], vec![F32(1.0), F32(-2.0)], 2),
288+
),
263289
("{0:1, 1:1.5}/2", (vec![0, 1], vec![F32(1.0), F32(1.5)], 2)),
264290
(
265291
"{0:+3, 2:-4.1}/3",
@@ -280,27 +306,47 @@ mod tests {
280306

281307
#[test]
282308
fn test_svector_parse_reject() {
283-
let exprs: Vec<&str> = vec![
284-
"{",
285-
"}",
286-
"{:",
287-
":}",
288-
"{0:1, 1:1.5}/1",
289-
"{0:0, 1:0, 2:0}/2",
290-
"{0:1, 1:2, 2:3}",
291-
"{0:1, 1:2, 2:3",
292-
"{0:1, 1:2}/",
293-
"{0}/5",
294-
"{0:}/5",
295-
"{:0}/5",
296-
"{0:, 1:2}/5",
297-
"{0:1, 1}/5",
298-
"/2",
299-
"{}/1/2",
300-
];
301-
for e in exprs {
309+
let exprs: HashMap<&str, ParseVectorError> = HashMap::from([
310+
("{", ParseVectorError::BadParentheses { character: '{' }),
311+
("}", ParseVectorError::BadParentheses { character: '{' }),
312+
("{:", ParseVectorError::BadCharacter { position: 1 }),
313+
(":}", ParseVectorError::BadCharacter { position: 0 }),
314+
(
315+
"{0:1, 1:1.5}/1",
316+
ParseVectorError::OutOfBound { dims: 1, index: 1 },
317+
),
318+
(
319+
"{0:0, 1:0, 2:0}/2",
320+
ParseVectorError::OutOfBound { dims: 2, index: 2 },
321+
),
322+
(
323+
"{0:1, 1:2, 2:3}",
324+
ParseVectorError::BadCharacter { position: 15 },
325+
),
326+
(
327+
"{0:1, 1:2, 2:3",
328+
ParseVectorError::BadCharacter { position: 12 },
329+
),
330+
("{0:1, 1:2}/", ParseVectorError::BadParsing { position: 10 }),
331+
("{0}/5", ParseVectorError::BadCharacter { position: 2 }),
332+
("{0:}/5", ParseVectorError::BadCharacter { position: 3 }),
333+
("{:0}/5", ParseVectorError::TooShortNumber { position: 1 }),
334+
(
335+
"{0:, 1:2}/5",
336+
ParseVectorError::TooShortNumber { position: 3 },
337+
),
338+
("{0:1, 1}/5", ParseVectorError::BadCharacter { position: 7 }),
339+
("/2", ParseVectorError::BadCharacter { position: 0 }),
340+
("{}/1/2", ParseVectorError::BadCharacter { position: 2 }),
341+
(
342+
"{1,2,3,4}/5",
343+
ParseVectorError::BadCharacter { position: 2 },
344+
),
345+
]);
346+
for (e, err) in exprs {
302347
let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::<F32>().ok());
303-
assert!(ret.is_err(), "at expr {e}")
348+
assert!(ret.is_err(), "at expr {e}");
349+
assert_eq!(ret.unwrap_err(), err, "at expr {e}");
304350
}
305351
}
306352
}

0 commit comments

Comments
 (0)