Skip to content

Commit 5649ac8

Browse files
committed
refactor: use sparse struct to parse
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
1 parent 4b2d567 commit 5649ac8

File tree

5 files changed

+70
-85
lines changed

5 files changed

+70
-85
lines changed

src/datatype/text_svecf32.rs

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,22 @@
11
use super::memory_svecf32::SVecf32Output;
22
use crate::datatype::memory_svecf32::SVecf32Input;
3-
use crate::datatype::typmod::Typmod;
43
use crate::error::*;
54
use base::scalar::*;
65
use base::vector::*;
7-
use num_traits::Zero;
86
use pgrx::pg_sys::Oid;
97
use std::ffi::{CStr, CString};
108

119
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
12-
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output {
10+
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output {
1311
use crate::utils::parse::parse_pgvector_svector;
14-
let reserve = Typmod::parse_from_i32(typmod)
15-
.unwrap()
16-
.dims()
17-
.map(|x| x.get())
18-
.unwrap_or(0);
19-
let v = parse_pgvector_svector(input.to_bytes(), reserve as usize, |s| {
20-
s.parse::<F32>().ok()
21-
});
12+
let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::<F32>().ok());
2213
match v {
2314
Err(e) => {
2415
bad_literal(&e.to_string());
2516
}
26-
Ok(vector) => {
27-
check_value_dims_1048575(vector.len());
28-
let mut indexes = Vec::<u32>::new();
29-
let mut values = Vec::<F32>::new();
30-
for (i, &x) in vector.iter().enumerate() {
31-
if !x.is_zero() {
32-
indexes.push(i as u32);
33-
values.push(x);
34-
}
35-
}
36-
SVecf32Output::new(SVecf32Borrowed::new(vector.len() as u32, &indexes, &values))
17+
Ok((indexes, values, dims)) => {
18+
check_value_dims_1048575(dims);
19+
SVecf32Output::new(SVecf32Borrowed::new(dims as u32, &indexes, &values))
3720
}
3821
}
3922
}

src/utils/parse.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,16 @@ where
8888
#[inline(always)]
8989
pub fn parse_pgvector_svector<T: Zero + Clone, F>(
9090
input: &[u8],
91-
reserve: usize,
9291
f: F,
93-
) -> Result<Vec<T>, ParseVectorError>
92+
) -> Result<(Vec<u32>, Vec<T>, usize), ParseVectorError>
9493
where
9594
F: Fn(&str) -> Option<T>,
9695
{
9796
use arrayvec::ArrayVec;
9897
if input.is_empty() {
9998
return Err(ParseVectorError::EmptyString {});
10099
}
100+
let mut dims: usize = 0;
101101
let left = 'a: {
102102
for position in 0..input.len() - 1 {
103103
match input[position] {
@@ -109,7 +109,6 @@ where
109109
return Err(ParseVectorError::BadParentheses { character: '{' });
110110
};
111111
let mut token: ArrayVec<u8, 48> = ArrayVec::new();
112-
let mut capacity = reserve;
113112
let right = 'a: {
114113
for position in (1..input.len()).rev() {
115114
match input[position] {
@@ -121,7 +120,7 @@ where
121120
b'/' => {
122121
token.reverse();
123122
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
124-
capacity = s
123+
dims = s
125124
.parse::<usize>()
126125
.map_err(|_| ParseVectorError::BadParsing { position })?;
127126
}
@@ -135,8 +134,9 @@ where
135134
}
136135
return Err(ParseVectorError::BadParentheses { character: '}' });
137136
};
138-
let mut vector = vec![T::zero(); capacity];
139-
let mut index: usize = 0;
137+
let mut indexes = Vec::<u32>::new();
138+
let mut values = Vec::<T>::new();
139+
let mut index: u32 = 0;
140140
for position in left + 1..right {
141141
let c = input[position];
142142
match c {
@@ -153,7 +153,8 @@ where
153153
// Safety: all bytes in `token` are ascii characters
154154
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
155155
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
156-
vector[index] = num;
156+
indexes.push(index);
157+
values.push(num);
157158
token.clear();
158159
} else {
159160
return Err(ParseVectorError::TooShortNumber { position });
@@ -164,7 +165,7 @@ where
164165
// Safety: all bytes in `token` are ascii characters
165166
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
166167
index = s
167-
.parse::<usize>()
168+
.parse::<u32>()
168169
.map_err(|_| ParseVectorError::BadParsing { position })?;
169170
token.clear();
170171
} else {
@@ -180,8 +181,9 @@ where
180181
// Safety: all bytes in `token` are ascii characters
181182
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
182183
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
183-
vector[index] = num;
184+
indexes.push(index);
185+
values.push(num);
184186
token.clear();
185187
}
186-
Ok(vector)
188+
Ok((indexes, values, dims))
187189
}

tests/sqllogictest/sparse.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ DROP TABLE t;
4040
query I
4141
SELECT to_svector(5, '{1,2}', '{1,2}');
4242
----
43-
{2:1, 3:2}/5
43+
{1:1, 2:2}/5
4444

4545
query I
4646
SELECT to_svector(5, '{1,2}', '{1,1}') * to_svector(5, '{1,3}', '{2,2}');
4747
----
48-
[0, 2, 0, 0, 0]
48+
{1:2}/5
4949

5050
statement error Lengths of index and value are not matched.
5151
SELECT to_svector(5, '{1,2,3}', '{1,2}');

tests/sqllogictest/svector.slt

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ CREATE TABLE t (id bigserial, val svector);
66

77
statement ok
88
INSERT INTO t (val)
9-
VALUES ('[1,2,3]'), ('[4,5,6]');
9+
VALUES ('{0:1, 1:2, 2:3}/3'), ('{0:4, 1:5, 2:6}/3');
1010

1111
query I
1212
SELECT vector_dims(val) FROM t;
@@ -23,12 +23,12 @@ SELECT round(vector_norm(val)::numeric, 5) FROM t;
2323
query ?
2424
SELECT avg(val) FROM t;
2525
----
26-
[2.5, 3.5, 4.5]
26+
{0:2.5, 1:3.5, 2:4.5}/3
2727

2828
query ?
2929
SELECT sum(val) FROM t;
3030
----
31-
[5, 7, 9]
31+
{0:5, 1:7, 2:9}/3
3232

3333
statement ok
3434
CREATE TABLE test_vectors (id serial, data vector(1000));
@@ -46,35 +46,35 @@ SELECT count(*) FROM test_vectors;
4646
5000
4747

4848
query R
49-
SELECT vector_norm('[3,4]'::svector);
49+
SELECT vector_norm('{0:3, 1:4}/2'::svector);
5050
----
5151
5
5252

5353
query I
54-
SELECT vector_dims(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v;
54+
SELECT vector_dims(v) FROM unnest(ARRAY['{0:1, 1:2}/2'::svector, '{0:3}/1'::svector]) v;
5555
----
5656
2
5757
1
5858

5959
query ?
60-
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v;
60+
SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:3, 1:5, 2:7}/3'::svector]) v;
6161
----
62-
[2, 3.5, 5]
62+
{0:2, 1:3.5, 2:5}/3
6363

6464
query ?
65-
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v;
65+
SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:-1, 1:2, 2:-3}/3'::svector]) v;
6666
----
67-
[0, 2, 0]
67+
{1:2}/3
6868

6969
query ?
70-
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v;
70+
SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:3, 1:5, 2:7}/3'::svector, NULL]) v;
7171
----
72-
[2, 3.5, 5]
72+
{0:2, 1:3.5, 2:5}/3
7373

7474
query ?
75-
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector,NULL]) v;
75+
SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector,NULL]) v;
7676
----
77-
[1, 2, 3]
77+
{0:1, 1:2, 2:3}/3
7878

7979
query ?
8080
SELECT avg(v) FROM unnest(ARRAY[]::svector[]) v;
@@ -87,22 +87,22 @@ SELECT avg(v) FROM unnest(ARRAY[NULL]::svector[]) v;
8787
NULL
8888

8989
query ?
90-
SELECT avg(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v;
90+
SELECT avg(v) FROM unnest(ARRAY['{0:3e38}/1'::svector, '{0:3e38}/1'::svector]) v;
9191
----
92-
[inf]
92+
{0:inf}/1
9393

9494
statement error differs in dimensions
95-
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v;
95+
SELECT avg(v) FROM unnest(ARRAY['{0:1, 1:2}/2'::svector, '{0:3}/1'::svector]) v;
9696

9797
query ?
9898
SELECT avg(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{2,3}'), to_svector(5, '{0,2}', '{1,3}'), to_svector(5, '{3,4}', '{3,3}')]) v;
9999
----
100-
[1, 1, 1, 1, 1]
100+
{0:1, 1:1, 2:1, 3:1, 4:1}/5
101101

102102
query ?
103103
SELECT avg(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v;
104104
----
105-
[0.33333334, 0.6666667, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.6666667, 0.33333334, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
105+
{0:0.33333334, 1:0.6666667, 2:1, 3:1, 4:1, 5:1, 6:1, 7:1, 8:1, 9:1, 10:1, 11:1, 12:1, 13:1, 14:1, 15:1, 16:0.6666667, 17:0.33333334}/32
106106

107107
# test avg(svector) get the same result as avg(vector)
108108
query ?
@@ -111,20 +111,20 @@ SELECT avg(data) = avg(data::svector)::vector FROM test_vectors;
111111
t
112112

113113
query ?
114-
SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v;
114+
SELECT sum(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:3, 1:5, 2:7}/3'::svector]) v;
115115
----
116-
[4, 7, 10]
116+
{0:4, 1:7, 2:10}/3
117117

118118
# test zero element
119119
query ?
120-
SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v;
120+
SELECT sum(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:-1, 1:2, 2:-3}/3'::svector]) v;
121121
----
122-
[0, 4, 0]
122+
{1:4}/3
123123

124124
query ?
125-
SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v;
125+
SELECT sum(v) FROM unnest(ARRAY['{0:1, 1:2, 2:3}/3'::svector, '{0:3, 1:5, 2:7}/3'::svector, NULL]) v;
126126
----
127-
[4, 7, 10]
127+
{0:4, 1:7, 2:10}/3
128128

129129
query ?
130130
SELECT sum(v) FROM unnest(ARRAY[]::svector[]) v;
@@ -137,23 +137,23 @@ SELECT sum(v) FROM unnest(ARRAY[NULL]::svector[]) v;
137137
NULL
138138

139139
statement error differs in dimensions
140-
SELECT sum(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v;
140+
SELECT sum(v) FROM unnest(ARRAY['{0:1, 1:2}/2'::svector, '{0:3}/1'::svector]) v;
141141

142142
# should this return an error ?
143143
query ?
144-
SELECT sum(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v;
144+
SELECT sum(v) FROM unnest(ARRAY['{0:3e38}/1'::svector, '{0:3e38}/1'::svector]) v;
145145
----
146-
[inf]
146+
{0:inf}/1
147147

148148
query ?
149149
SELECT sum(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{1,2}'), to_svector(5, '{0,2}', '{1,2}'), to_svector(5, '{3,4}', '{3,3}')]) v;
150150
----
151-
[2, 2, 2, 3, 3]
151+
{0:2, 1:2, 2:2, 3:3, 4:3}/5
152152

153153
query ?
154154
SELECT sum(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v;
155155
----
156-
[1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
156+
{0:1, 1:2, 2:3, 3:3, 4:3, 5:3, 6:3, 7:3, 8:3, 9:3, 10:3, 11:3, 12:3, 13:3, 14:3, 15:3, 16:2, 17:1}/32
157157

158158
# test sum(svector) get the same result as sum(vector)
159159
query ?

tests/sqllogictest/svector_subscript.slt

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,83 +2,83 @@ statement ok
22
SET search_path TO pg_temp, vectors;
33

44
query I
5-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:6];
5+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[3:6];
66
----
7-
{1:3, 2:4, 3:5}/3
7+
{0:3, 1:4, 2:5}/3
88

99
query I
10-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:4];
10+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:4];
1111
----
12-
{2:1, 3:2, 4:3}/4
12+
{1:1, 2:2, 3:3}/4
1313

1414
query I
15-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:];
15+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[5:];
1616
----
17-
{1:5, 2:6, 3:7}/3
17+
{0:5, 1:6, 2:7}/3
1818

1919
query I
20-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:8];
20+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[1:8];
2121
----
22-
{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/7
22+
{0:1, 1:2, 2:3, 3:4, 4:5, 5:6, 6:7}/7
2323

2424
statement error type svector does only support one subscript
25-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:3][1:1];
25+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[3:3][1:1];
2626

2727
statement error type svector does only support slice fetch
28-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3];
28+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[3];
2929

3030
query I
31-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:4];
31+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[5:4];
3232
----
3333
NULL
3434

3535
query I
36-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[9:];
36+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[9:];
3737
----
3838
NULL
3939

4040
query I
41-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:0];
41+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:0];
4242
----
4343
NULL
4444

4545
query I
46-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:-1];
46+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:-1];
4747
----
4848
NULL
4949

5050
query I
51-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:NULL];
51+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[NULL:NULL];
5252
----
5353
NULL
5454

5555
query I
56-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:8];
56+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[NULL:8];
5757
----
5858
NULL
5959

6060
query I
61-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:NULL];
61+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[1:NULL];
6262
----
6363
NULL
6464

6565
query I
66-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:];
66+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[NULL:];
6767
----
6868
NULL
6969

7070
query I
71-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:NULL];
71+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:NULL];
7272
----
7373
NULL
7474

7575
query I
76-
SELECT ('{3:2, 5:4, 8:7}/8'::svector)[3:7];
76+
SELECT ('{2:2, 4:4, 7:7}/8'::svector)[3:7];
7777
----
78-
{2:4}/4
78+
{1:4}/4
7979

8080
query I
81-
SELECT ('{3:2, 5:4, 8:7}/8'::svector)[5:7];
81+
SELECT ('{2:2, 4:4, 7:7}/8'::svector)[5:7];
8282
----
8383
{}/2
8484

0 commit comments

Comments
 (0)