|
| 1 | +use num_traits::Zero; |
1 | 2 | use thiserror::Error;
|
2 | 3 |
|
3 |
| -#[derive(Debug, Error)] |
| 4 | +#[derive(Debug, Error, PartialEq)] |
4 | 5 | pub enum ParseVectorError {
|
5 | 6 | #[error("The input string is empty.")]
|
6 | 7 | EmptyString {},
|
@@ -83,3 +84,195 @@ where
|
83 | 84 | }
|
84 | 85 | Ok(vector)
|
85 | 86 | }
|
| 87 | + |
/// Position of the sparse-vector parser inside the `{index:value, ...}/dims` text.
///
/// Each variant names the most recent syntactic element that was consumed.
/// The enum is a plain tag, so it is `Copy` and `Eq` in addition to the
/// comparison/debug traits the parser relies on.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
enum ParseState {
    /// Nothing but spaces has been consumed so far.
    Start,
    /// The opening `{` has been consumed.
    LeftBracket,
    /// At least one byte of an element index has been consumed.
    Index,
    /// The `:` separating an index from its value has been consumed.
    Colon,
    /// At least one byte of an element value has been consumed.
    Value,
    /// The `,` that ends an `index:value` pair has been consumed.
    Comma,
    /// The closing `}` has been consumed.
    RightBracket,
    /// The `/` separating the element list from the dimension count has been consumed.
    Splitter,
    /// At least one digit of the dimension count has been consumed.
    Dims,
}
| 100 | + |
| 101 | +#[inline(always)] |
| 102 | +pub fn parse_pgvector_svector<T: Zero + Clone, F>( |
| 103 | + input: &[u8], |
| 104 | + f: F, |
| 105 | +) -> Result<(Vec<u32>, Vec<T>, usize), ParseVectorError> |
| 106 | +where |
| 107 | + F: Fn(&str) -> Option<T>, |
| 108 | +{ |
| 109 | + use arrayvec::ArrayVec; |
| 110 | + if input.is_empty() { |
| 111 | + return Err(ParseVectorError::EmptyString {}); |
| 112 | + } |
| 113 | + let mut token: ArrayVec<u8, 48> = ArrayVec::new(); |
| 114 | + let mut indexes = Vec::<u32>::new(); |
| 115 | + let mut values = Vec::<T>::new(); |
| 116 | + |
| 117 | + let mut state = ParseState::Start; |
| 118 | + for (position, c) in input.iter().copied().enumerate() { |
| 119 | + state = match (&state, c) { |
| 120 | + (_, b' ') => state, |
| 121 | + (ParseState::Start, b'{') => ParseState::LeftBracket, |
| 122 | + ( |
| 123 | + ParseState::LeftBracket | ParseState::Index | ParseState::Comma, |
| 124 | + b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-', |
| 125 | + ) => { |
| 126 | + if token.try_push(c).is_err() { |
| 127 | + return Err(ParseVectorError::TooLongNumber { position }); |
| 128 | + } |
| 129 | + ParseState::Index |
| 130 | + } |
| 131 | + (ParseState::Colon, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => { |
| 132 | + if token.try_push(c).is_err() { |
| 133 | + return Err(ParseVectorError::TooLongNumber { position }); |
| 134 | + } |
| 135 | + ParseState::Value |
| 136 | + } |
| 137 | + (ParseState::LeftBracket | ParseState::Comma, b'}') => ParseState::RightBracket, |
| 138 | + (ParseState::Index, b':') => { |
| 139 | + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; |
| 140 | + let index = s |
| 141 | + .parse::<u32>() |
| 142 | + .map_err(|_| ParseVectorError::BadParsing { position })?; |
| 143 | + indexes.push(index); |
| 144 | + token.clear(); |
| 145 | + ParseState::Colon |
| 146 | + } |
| 147 | + (ParseState::Value, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => { |
| 148 | + if token.try_push(c).is_err() { |
| 149 | + return Err(ParseVectorError::TooLongNumber { position }); |
| 150 | + } |
| 151 | + ParseState::Value |
| 152 | + } |
| 153 | + (ParseState::Value, b',') => { |
| 154 | + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; |
| 155 | + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; |
| 156 | + values.push(num); |
| 157 | + token.clear(); |
| 158 | + ParseState::Comma |
| 159 | + } |
| 160 | + (ParseState::Value, b'}') => { |
| 161 | + if token.is_empty() { |
| 162 | + return Err(ParseVectorError::TooShortNumber { position }); |
| 163 | + } |
| 164 | + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; |
| 165 | + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; |
| 166 | + values.push(num); |
| 167 | + token.clear(); |
| 168 | + ParseState::RightBracket |
| 169 | + } |
| 170 | + (ParseState::RightBracket, b'/') => ParseState::Splitter, |
| 171 | + (ParseState::Dims | ParseState::Splitter, b'0'..=b'9') => { |
| 172 | + if token.try_push(c).is_err() { |
| 173 | + return Err(ParseVectorError::TooLongNumber { position }); |
| 174 | + } |
| 175 | + ParseState::Dims |
| 176 | + } |
| 177 | + (_, _) => { |
| 178 | + return Err(ParseVectorError::BadCharacter { position }); |
| 179 | + } |
| 180 | + } |
| 181 | + } |
| 182 | + if state != ParseState::Dims { |
| 183 | + return Err(ParseVectorError::BadParsing { |
| 184 | + position: input.len(), |
| 185 | + }); |
| 186 | + } |
| 187 | + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; |
| 188 | + let dims = s |
| 189 | + .parse::<usize>() |
| 190 | + .map_err(|_| ParseVectorError::BadParsing { |
| 191 | + position: input.len(), |
| 192 | + })?; |
| 193 | + Ok((indexes, values, dims)) |
| 194 | +} |
| 195 | + |
#[cfg(test)]
mod tests {
    use base::scalar::F32;

    use super::*;

    /// Result type produced by the parser when values are read as `F32`.
    type Parsed = Result<(Vec<u32>, Vec<F32>, usize), ParseVectorError>;

    /// Runs the sparse-vector parser over `text`, reading values as `F32`.
    fn parse(text: &str) -> Parsed {
        parse_pgvector_svector(text.as_bytes(), |s| s.parse::<F32>().ok())
    }

    /// Wraps plain floats into the scalar type used by the expectations.
    fn scalars(raw: &[f32]) -> Vec<F32> {
        raw.iter().copied().map(F32).collect()
    }

    /// Shorthand for the "unexpected byte" error at `position`.
    fn bad_character(position: usize) -> ParseVectorError {
        ParseVectorError::BadCharacter { position }
    }

    /// Shorthand for the "could not parse" error at `position`.
    fn bad_parsing(position: usize) -> ParseVectorError {
        ParseVectorError::BadParsing { position }
    }

    /// Well-formed inputs must yield the listed indexes, values and dimension count.
    #[test]
    fn test_svector_parse_accept() {
        let cases: [(&str, Vec<u32>, Vec<F32>, usize); 7] = [
            ("{}/1", vec![], scalars(&[]), 1),
            ("{0:1}/1", vec![0], scalars(&[1.0]), 1),
            ("{0:1, 1:-2, }/2", vec![0, 1], scalars(&[1.0, -2.0]), 2),
            ("{0:1, 1:1.5}/2", vec![0, 1], scalars(&[1.0, 1.5]), 2),
            ("{0:+3, 2:-4.1}/3", vec![0, 2], scalars(&[3.0, -4.1]), 3),
            (
                "{0:0, 1:0, 2:0}/3",
                vec![0, 1, 2],
                scalars(&[0.0, 0.0, 0.0]),
                3,
            ),
            (
                "{3:3, 2:2, 1:1, 0:0}/4",
                vec![3, 2, 1, 0],
                scalars(&[3.0, 2.0, 1.0, 0.0]),
                4,
            ),
        ];
        for (text, indexes, values, dims) in cases {
            let outcome = parse(text);
            assert!(outcome.is_ok(), "at expr {:?}: {:?}", text, outcome);
            assert_eq!(
                outcome.unwrap(),
                (indexes, values, dims),
                "parsed at expr {:?}",
                text
            );
        }
    }

    /// Malformed inputs must fail with the listed error at the listed position.
    #[test]
    fn test_svector_parse_reject() {
        let cases: [(&str, ParseVectorError); 17] = [
            ("{", bad_parsing(1)),
            ("}", bad_character(0)),
            ("{:", bad_character(1)),
            (":}", bad_character(0)),
            ("{0:1, 1:2, 2:3}", bad_parsing(15)),
            ("{0:1, 1:2, 2:3", bad_parsing(14)),
            ("{0:1, 1:2}/", bad_parsing(11)),
            ("{0}/5", bad_character(2)),
            ("{0:}/5", bad_character(3)),
            ("{:0}/5", bad_character(1)),
            ("{0:, 1:2}/5", bad_character(3)),
            ("{0:1, 1}/5", bad_character(7)),
            ("/2", bad_character(0)),
            ("{}/1/2", bad_character(4)),
            ("{0:1, 1:2}/4/2", bad_character(12)),
            ("{}/-4", bad_character(3)),
            ("{1,2,3,4}/5", bad_character(2)),
        ];
        for (text, expected) in cases {
            let outcome = parse(text);
            assert!(outcome.is_err(), "at expr {:?}: {:?}", text, outcome);
            assert_eq!(outcome.unwrap_err(), expected, "parsed at expr {:?}", text);
        }
    }
}
0 commit comments