Skip to content

Commit 13e4245

Browse files
feat: temporary support for sparse aggregate. (#489)
* Implement aggregate functions for sparse vector. Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com> * fix null. Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com> * tests: add more e2e test. Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com> --------- Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com>
1 parent 8f99933 commit 13e4245

File tree

4 files changed

+256
-0
lines changed

4 files changed

+256
-0
lines changed

src/datatype/functions_svecf32.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use super::memory_svecf32::*;
22
use crate::error::*;
33
use base::scalar::*;
44
use base::vector::*;
5+
use num_traits::Zero;
56

67
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
78
fn _vectors_svecf32_dims(vector: SVecf32Input<'_>) -> i32 {
@@ -58,3 +59,26 @@ fn _vectors_to_svector(
5859
}
5960
SVecf32Output::new(SVecf32Borrowed::new(dims.get(), &indexes, &values))
6061
}
62+
63+
/// divide a sparse vector by a scalar.
64+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
65+
fn _vectors_svecf32_div(vector: SVecf32Input<'_>, scalar: f32) -> SVecf32Output {
66+
let scalar = F32(scalar);
67+
let vector = vector.for_borrow();
68+
let indexes = vector.indexes();
69+
let values = vector.values();
70+
let mut new_indexes = Vec::<u32>::with_capacity(indexes.len());
71+
let mut new_values = Vec::<F32>::with_capacity(values.len());
72+
for (value, index) in values.iter().zip(indexes.iter()) {
73+
let v = *value / scalar;
74+
if !v.is_zero() {
75+
new_values.push(v);
76+
new_indexes.push(*index);
77+
}
78+
}
79+
SVecf32Output::new(SVecf32Borrowed::new(
80+
vector.dims(),
81+
&new_indexes,
82+
&new_values,
83+
))
84+
}

src/sql/bootstrap.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ CREATE TYPE veci8;
1111
CREATE TYPE vector_index_stat;
1212

1313
CREATE TYPE _vectors_vecf32_aggregate_avg_stype;
14+
CREATE TYPE svector_accumulate_state;
1415

1516
-- bootstrap end

src/sql/finalize.sql

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ CREATE TYPE _vectors_vecf32_aggregate_avg_stype (
8686
ALIGNMENT = double
8787
);
8888

89+
-- Transition state used by avg(svector): the number of inputs folded in so
-- far, together with their running element-wise sum.
CREATE TYPE svector_accumulate_state AS (
    count INT,
    sum svector
);
93+
8994
-- List of operators
9095

9196
CREATE OPERATOR + (
@@ -659,6 +664,51 @@ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_binari
659664
CREATE FUNCTION to_veci8("len" INT, "alpha" real, "offset" real, "values" INT[]) RETURNS veci8
660665
IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_to_veci8_wrapper';
661666

667+
-- Transition function of avg(svector): folds one input vector into the
-- running state and returns the new state.
--   state: count of vectors seen so far and their element-wise sum.
--   value: the next vector to accumulate.
-- Declared STRICT, so the body is never entered with a NULL argument.
-- Declared IMMUTABLE like the other vector functions in this file: the result
-- depends only on the arguments.
CREATE FUNCTION _vectors_svector_accum("state" svector_accumulate_state, "value" svector) RETURNS svector_accumulate_state AS $$
DECLARE
    result svector_accumulate_state;
BEGIN
    IF state.count = 0 THEN
        -- A zero count means nothing has been accumulated yet, so the stored
        -- sum is only a placeholder (the aggregate starts from '(0, [0])')
        -- and must be replaced rather than added to.
        result.count := 1;
        result.sum := value;
    ELSE
        result.count := state.count + 1;
        result.sum := state.sum + value;
    END IF;
    RETURN result;
END;
$$ LANGUAGE plpgsql IMMUTABLE STRICT PARALLEL SAFE;
681+
682+
-- Combine function of avg(svector): merges two partial transition states
-- into one.
--   state1, state2: partial states, each holding a count and an element-wise sum.
-- Declared IMMUTABLE like the other vector functions in this file: the result
-- depends only on the arguments.
CREATE FUNCTION _vectors_svector_combine("state1" svector_accumulate_state, "state2" svector_accumulate_state) RETURNS svector_accumulate_state AS $$
DECLARE
    result svector_accumulate_state;
BEGIN
    -- A state with a zero count has accumulated nothing and its sum is only a
    -- placeholder, so the other state is returned untouched instead of adding
    -- the two sums.
    IF state1.count = 0 THEN
        RETURN state2;
    END IF;
    IF state2.count = 0 THEN
        RETURN state1;
    END IF;
    result.count := state1.count + state2.count;
    result.sum := state1.sum + state2.sum;
    RETURN result;
END;
$$ LANGUAGE plpgsql IMMUTABLE STRICT PARALLEL SAFE;
697+
698+
-- Final function of avg(svector): turns the accumulated state into the mean
-- vector by dividing the element-wise sum by the number of inputs.
-- Returns NULL when no vector was accumulated (count = 0).
-- Declared IMMUTABLE like the other vector functions in this file: the result
-- depends only on the argument.
CREATE FUNCTION _vectors_svector_final("state" svector_accumulate_state) RETURNS svector AS $$
BEGIN
    IF state.count = 0 THEN
        RETURN NULL;
    END IF;
    -- The divisor is cast to real because _vectors_svecf32_div takes an f32 scalar.
    RETURN _vectors_svecf32_div(state.sum, state.count::real);
END;
$$ LANGUAGE plpgsql IMMUTABLE STRICT PARALLEL SAFE;
711+
662712
-- List of aggregates
663713

664714
CREATE AGGREGATE avg(vector) (
@@ -677,6 +727,22 @@ CREATE AGGREGATE sum(vector) (
677727
PARALLEL = SAFE
678728
);
679729

730+
-- avg(svector): element-wise mean of the input vectors.
-- The state starts as count 0 with a placeholder sum; the transition and
-- combine functions treat a zero count as "nothing accumulated yet", and the
-- final function returns NULL for it.
CREATE AGGREGATE avg(svector) (
    SFUNC = _vectors_svector_accum,
    STYPE = svector_accumulate_state,
    COMBINEFUNC = _vectors_svector_combine,
    FINALFUNC = _vectors_svector_final,
    INITCOND = '(0, [0])',
    PARALLEL = SAFE
);
738+
739+
-- sum(svector): element-wise sum of the input vectors.
-- The svector addition function serves as both the transition and the combine
-- step, and the running sum itself is the state; no initial condition is set.
CREATE AGGREGATE sum(svector) (
    SFUNC = _vectors_svecf32_operator_add,
    STYPE = svector,
    COMBINEFUNC = _vectors_svecf32_operator_add,
    PARALLEL = SAFE
);
745+
680746
-- List of casts
681747

682748
CREATE CAST (real[] AS vector)

tests/sqllogictest/svector.slt

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
statement ok
2+
SET search_path TO pg_temp, vectors;
3+
4+
statement ok
5+
CREATE TABLE t (id bigserial, val svector);
6+
7+
statement ok
8+
INSERT INTO t (val)
9+
VALUES ('[1,2,3]'), ('[4,5,6]');
10+
11+
query I
12+
SELECT vector_dims(val) FROM t;
13+
----
14+
3
15+
3
16+
17+
query R
18+
SELECT round(vector_norm(val)::numeric, 5) FROM t;
19+
----
20+
3.74166
21+
8.77496
22+
23+
query ?
24+
SELECT avg(val) FROM t;
25+
----
26+
[2.5, 3.5, 4.5]
27+
28+
query ?
29+
SELECT sum(val) FROM t;
30+
----
31+
[5, 7, 9]
32+
33+
statement ok
34+
CREATE TABLE test_vectors (id serial, data vector(1000));
35+
36+
statement ok
37+
INSERT INTO test_vectors (data)
38+
SELECT
39+
ARRAY_AGG(CASE WHEN random() < 0.95 THEN 0 ELSE (random() * 99 + 1)::real END)::real[]::vector AS v
40+
FROM generate_series(1, 1000 * 5000) i
41+
GROUP BY i % 5000;
42+
43+
query ?
44+
SELECT count(*) FROM test_vectors;
45+
----
46+
5000
47+
48+
query R
49+
SELECT vector_norm('[3,4]'::svector);
50+
----
51+
5
52+
53+
query I
54+
SELECT vector_dims(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v;
55+
----
56+
2
57+
1
58+
59+
query ?
60+
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v;
61+
----
62+
[2, 3.5, 5]
63+
64+
query ?
65+
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v;
66+
----
67+
[0, 2, 0]
68+
69+
query ?
70+
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v;
71+
----
72+
[2, 3.5, 5]
73+
74+
query ?
75+
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector,NULL]) v;
76+
----
77+
[1, 2, 3]
78+
79+
query ?
80+
SELECT avg(v) FROM unnest(ARRAY[]::svector[]) v;
81+
----
82+
NULL
83+
84+
query ?
85+
SELECT avg(v) FROM unnest(ARRAY[NULL]::svector[]) v;
86+
----
87+
NULL
88+
89+
query ?
90+
SELECT avg(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v;
91+
----
92+
[inf]
93+
94+
statement error differs in dimensions
95+
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v;
96+
97+
query ?
98+
SELECT avg(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{2,3}'), to_svector(5, '{0,2}', '{1,3}'), to_svector(5, '{3,4}', '{3,3}')]) v;
99+
----
100+
[1, 1, 1, 1, 1]
101+
102+
query ?
103+
SELECT avg(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v;
104+
----
105+
[0.33333334, 0.6666667, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.6666667, 0.33333334, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
106+
107+
# test that avg(svector) gives the same result as avg(vector)
108+
query ?
109+
SELECT avg(data) = avg(data::svector)::vector FROM test_vectors;
110+
----
111+
t
112+
113+
query ?
114+
SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v;
115+
----
116+
[4, 7, 10]
117+
118+
# test zero element
119+
query ?
120+
SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v;
121+
----
122+
[0, 4, 0]
123+
124+
query ?
125+
SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v;
126+
----
127+
[4, 7, 10]
128+
129+
query ?
130+
SELECT sum(v) FROM unnest(ARRAY[]::svector[]) v;
131+
----
132+
NULL
133+
134+
query ?
135+
SELECT sum(v) FROM unnest(ARRAY[NULL]::svector[]) v;
136+
----
137+
NULL
138+
139+
statement error differs in dimensions
140+
SELECT sum(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v;
141+
142+
# the sum overflows float32 and yields inf; should this return an error instead?
143+
query ?
144+
SELECT sum(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v;
145+
----
146+
[inf]
147+
148+
query ?
149+
SELECT sum(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{1,2}'), to_svector(5, '{0,2}', '{1,2}'), to_svector(5, '{3,4}', '{3,3}')]) v;
150+
----
151+
[2, 2, 2, 3, 3]
152+
153+
query ?
154+
SELECT sum(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v;
155+
----
156+
[1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
157+
158+
# test that sum(svector) gives the same result as sum(vector)
159+
query ?
160+
SELECT sum(data) = sum(data::svector)::vector FROM test_vectors;
161+
----
162+
t
163+
164+
statement ok
165+
DROP TABLE t, test_vectors;

0 commit comments

Comments
 (0)