Skip to content

Commit 4d5bd93

Browse files
committed
fix: zero-check, sort and tests
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
1 parent 5649ac8 commit 4d5bd93

File tree

1 file changed

+109
-6
lines changed

1 file changed

+109
-6
lines changed

src/utils/parse.rs

Lines changed: 109 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ pub enum ParseVectorError {
1515
TooShortNumber { position: usize },
1616
#[error("Bad parsing at position {position}")]
1717
BadParsing { position: usize },
18+
#[error("Index out of bounds: the dim is {dims} but the index is {index}")]
19+
OutOfBound { dims: usize, index: usize },
1820
}
1921

2022
#[inline(always)]
@@ -85,6 +87,14 @@ where
8587
Ok(vector)
8688
}
8789

90+
/// Kind of the most recently processed token while scanning the body of a
/// sparse-vector literal.
///
/// The parser compares the current state against the byte it is about to
/// handle in order to reject malformed input (for example, a `,` that does
/// not directly follow a number).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParseState {
    /// A byte was just appended to the numeric token buffer.
    Number,
    /// A `,` separator was just consumed, finishing one entry.
    Comma,
    /// The index part of an entry was just parsed (presumably on `:`; the
    /// matching arm is not fully visible here).
    Colon,
    /// Initial state: no byte of the body has been processed yet.
    Start,
}
97+
8898
#[inline(always)]
8999
pub fn parse_pgvector_svector<T: Zero + Clone, F>(
90100
input: &[u8],
@@ -136,7 +146,9 @@ where
136146
};
137147
let mut indexes = Vec::<u32>::new();
138148
let mut values = Vec::<T>::new();
139-
let mut index: u32 = 0;
149+
let mut index: u32 = u32::MAX;
150+
let mut state = ParseState::Start;
151+
140152
for position in left + 1..right {
141153
let c = input[position];
142154
match c {
@@ -147,15 +159,29 @@ where
147159
if token.try_push(c).is_err() {
148160
return Err(ParseVectorError::TooLongNumber { position });
149161
}
162+
state = ParseState::Number;
150163
}
151164
b',' => {
165+
if state != ParseState::Number {
166+
return Err(ParseVectorError::BadCharacter { position });
167+
}
152168
if !token.is_empty() {
153169
// Safety: all bytes in `token` are ascii characters
154170
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
155171
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
156-
indexes.push(index);
157-
values.push(num);
172+
if index as usize >= dims {
173+
return Err(ParseVectorError::OutOfBound {
174+
dims,
175+
index: index as usize,
176+
});
177+
}
178+
if !num.is_zero() {
179+
indexes.push(index);
180+
values.push(num);
181+
}
182+
index = u32::MAX;
158183
token.clear();
184+
state = ParseState::Comma;
159185
} else {
160186
return Err(ParseVectorError::TooShortNumber { position });
161187
}
@@ -168,6 +194,7 @@ where
168194
.parse::<u32>()
169195
.map_err(|_| ParseVectorError::BadParsing { position })?;
170196
token.clear();
197+
state = ParseState::Colon;
171198
} else {
172199
return Err(ParseVectorError::TooShortNumber { position });
173200
}
@@ -176,14 +203,90 @@ where
176203
_ => return Err(ParseVectorError::BadCharacter { position }),
177204
}
178205
}
206+
if state != ParseState::Start && (state != ParseState::Number || index == u32::MAX) {
207+
return Err(ParseVectorError::BadCharacter { position: right });
208+
}
179209
if !token.is_empty() {
180210
let position = right;
181211
// Safety: all bytes in `token` are ascii characters
182212
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
183213
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
184-
indexes.push(index);
185-
values.push(num);
214+
if index as usize >= dims {
215+
return Err(ParseVectorError::OutOfBound {
216+
dims,
217+
index: index as usize,
218+
});
219+
}
220+
if !num.is_zero() {
221+
indexes.push(index);
222+
values.push(num);
223+
}
186224
token.clear();
187225
}
188-
Ok((indexes, values, dims))
226+
// sort indexes and values in ascending order of index
227+
let mut indices = (0..indexes.len()).collect::<Vec<_>>();
228+
indices.sort_by_key(|&i| &indexes[i]);
229+
let sortedValues: Vec<T> = indices
230+
.iter()
231+
.map(|i| values.get(*i).unwrap().clone())
232+
.collect();
233+
indexes.sort();
234+
Ok((indexes, sortedValues, dims))
235+
}
236+
237+
#[cfg(test)]
mod tests {
    use base::scalar::F32;

    use super::*;

    /// Well-formed sparse-vector literals must parse to the expected
    /// `(indexes, values, dims)` triple.
    ///
    /// The cases also pin two normalizations: zero-valued entries are dropped
    /// and entries given out of order come back sorted by index.
    #[test]
    fn test_svector_parse_accept() {
        // An ordered `Vec` (rather than a `HashMap`) keeps the cases in source
        // order, so the first reported failure is the same on every run, and a
        // duplicated literal cannot silently overwrite an earlier case.
        let exprs: Vec<(&str, (Vec<u32>, Vec<F32>, usize))> = vec![
            ("{}/1", (vec![], vec![], 1)),
            ("{0:1}/1", (vec![0], vec![F32(1.0)], 1)),
            ("{0:1, 1:1.5}/2", (vec![0, 1], vec![F32(1.0), F32(1.5)], 2)),
            (
                "{0:+3, 2:-4.1}/3",
                (vec![0, 2], vec![F32(3.0), F32(-4.1)], 3),
            ),
            ("{0:0, 1:0, 2:0}/3", (vec![], vec![], 3)),
            (
                "{3:3, 2:2, 1:1, 0:0}/4",
                (vec![1, 2, 3], vec![F32(1.0), F32(2.0), F32(3.0)], 4),
            ),
        ];
        for (e, ans) in exprs {
            let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::<F32>().ok());
            assert!(ret.is_ok(), "at expr {e}");
            assert_eq!(ret.unwrap(), ans, "at expr {e}");
        }
    }

    /// Malformed literals (unbalanced braces, missing index/value/dims,
    /// indexes not below the declared dims) must be rejected with an error.
    #[test]
    fn test_svector_parse_reject() {
        let exprs: Vec<&str> = vec![
            "{",
            "}",
            "{:",
            ":}",
            "{0:1, 1:1.5}/1",
            "{0:0, 1:0, 2:0}/2",
            "{0:1, 1:2, 2:3}",
            "{0:1, 1:2, 2:3",
            "{0:1, 1:2}/",
            "{0}/5",
            "{0:}/5",
            "{:0}/5",
            "{0:, 1:2}/5",
            "{0:1, 1}/5",
            "/2",
        ];
        for e in exprs {
            let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::<F32>().ok());
            assert!(ret.is_err(), "at expr {e}");
        }
    }
}

0 commit comments

Comments
 (0)