Skip to content

Commit 887a851

Browse files
feat: support element-wise multiplication. (#480)
* feat: support element-wise multiplication. Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com> * fix tap. Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com> * Implement mul for vecf16 & veci8. Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com> * fix zero. Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com> --------- Signed-off-by: my-vegetable-has-exploded <wy1109468038@gmail.com>
1 parent a299b54 commit 887a851

File tree

9 files changed

+127
-0
lines changed

9 files changed

+127
-0
lines changed

src/datatype/operators_svecf32.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,49 @@ fn _vectors_svecf32_operator_minus(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>)
8888
SVecf32Output::new(SVecf32Borrowed::new(lhs.dims(), &indexes, &values))
8989
}
9090

91+
/// Calculate the element-wise multiplication of two sparse vectors.
92+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
93+
fn _vectors_svecf32_operator_mul(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> SVecf32Output {
94+
check_matched_dims(lhs.dims() as _, rhs.dims() as _);
95+
96+
let size1 = lhs.len();
97+
let size2 = rhs.len();
98+
let mut pos1 = 0;
99+
let mut pos2 = 0;
100+
let mut pos = 0;
101+
let mut indexes = vec![0; std::cmp::min(size1, size2)];
102+
let mut values = vec![F32::zero(); std::cmp::min(size1, size2)];
103+
let lhs = lhs.for_borrow();
104+
let rhs = rhs.for_borrow();
105+
while pos1 < size1 && pos2 < size2 {
106+
let lhs_index = lhs.indexes()[pos1];
107+
let rhs_index = rhs.indexes()[pos2];
108+
match lhs_index.cmp(&rhs_index) {
109+
std::cmp::Ordering::Less => {
110+
pos1 += 1;
111+
}
112+
std::cmp::Ordering::Equal => {
113+
// only both indexes are not zero, values are multiplied
114+
let lhs_value = lhs.values()[pos1];
115+
let rhs_value = rhs.values()[pos2];
116+
indexes[pos] = lhs_index;
117+
values[pos] = lhs_value * rhs_value;
118+
pos1 += 1;
119+
pos2 += 1;
120+
// only increment pos if the value is not zero
121+
pos += (!values[pos].is_zero()) as usize;
122+
}
123+
std::cmp::Ordering::Greater => {
124+
pos2 += 1;
125+
}
126+
}
127+
}
128+
indexes.truncate(pos);
129+
values.truncate(pos);
130+
131+
SVecf32Output::new(SVecf32Borrowed::new(lhs.dims(), &indexes, &values))
132+
}
133+
91134
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
92135
fn _vectors_svecf32_operator_lt(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool {
93136
check_matched_dims(lhs.dims() as _, rhs.dims() as _);

src/datatype/operators_vecf16.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ fn _vectors_vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) ->
2626
Vecf16Output::new(Vecf16Borrowed::new(&v))
2727
}
2828

29+
/// Calculate the element-wise multiplication of two f16 vectors.
30+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
31+
fn _vectors_vecf16_operator_mul(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output {
32+
let n = check_matched_dims(lhs.dims(), rhs.dims());
33+
let mut v = vec![F16::zero(); n];
34+
for i in 0..n {
35+
v[i] = lhs[i] * rhs[i];
36+
}
37+
Vecf16Output::new(Vecf16Borrowed::new(&v))
38+
}
39+
2940
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
3041
fn _vectors_vecf16_operator_lt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool {
3142
check_matched_dims(lhs.dims(), rhs.dims());

src/datatype/operators_vecf32.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ fn _vectors_vecf32_operator_minus(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) ->
2626
Vecf32Output::new(Vecf32Borrowed::new(&v))
2727
}
2828

29+
/// Calculate the element-wise multiplication of two vectors.
30+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
31+
fn _vectors_vecf32_operator_mul(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output {
32+
let n = check_matched_dims(lhs.dims(), rhs.dims());
33+
let mut v = vec![F32::zero(); n];
34+
for i in 0..n {
35+
v[i] = lhs[i] * rhs[i];
36+
}
37+
Vecf32Output::new(Vecf32Borrowed::new(&v))
38+
}
39+
2940
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
3041
fn _vectors_vecf32_operator_lt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool {
3142
check_matched_dims(lhs.dims(), rhs.dims());

src/datatype/operators_veci8.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ fn _vectors_veci8_operator_minus(lhs: Veci8Input<'_>, rhs: Veci8Input<'_>) -> Ve
3131
)
3232
}
3333

34+
/// Calculate the element-wise multiplication of two i8 vectors.
35+
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
36+
fn _vectors_veci8_operator_mul(lhs: Veci8Input<'_>, rhs: Veci8Input<'_>) -> Veci8Output {
37+
check_matched_dims(lhs.len(), rhs.len());
38+
let data = (0..lhs.len())
39+
.map(|i| lhs.index(i) * rhs.index(i))
40+
.collect::<Vec<_>>();
41+
let (vector, alpha, offset) = veci8::i8_quantization(&data);
42+
let (sum, l2_norm) = veci8::i8_precompute(&vector, alpha, offset);
43+
Veci8Output::new(
44+
Veci8Borrowed::new_checked(lhs.len() as u32, &vector, alpha, offset, sum, l2_norm).unwrap(),
45+
)
46+
}
47+
3448
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
3549
fn _vectors_veci8_operator_lt(lhs: Veci8Input<'_>, rhs: Veci8Input<'_>) -> bool {
3650
check_matched_dims(lhs.len(), rhs.len());

src/sql/finalize.sql

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,34 @@ CREATE OPERATOR - (
140140
RIGHTARG = veci8
141141
);
142142

143+
CREATE OPERATOR * (
144+
PROCEDURE = _vectors_vecf32_operator_mul,
145+
LEFTARG = vector,
146+
RIGHTARG = vector,
147+
COMMUTATOR = *
148+
);
149+
150+
CREATE OPERATOR * (
151+
PROCEDURE = _vectors_vecf16_operator_mul,
152+
LEFTARG = vecf16,
153+
RIGHTARG = vecf16,
154+
COMMUTATOR = *
155+
);
156+
157+
CREATE OPERATOR * (
158+
PROCEDURE = _vectors_svecf32_operator_mul,
159+
LEFTARG = svector,
160+
RIGHTARG = svector,
161+
COMMUTATOR = *
162+
);
163+
164+
CREATE OPERATOR * (
165+
PROCEDURE = _vectors_veci8_operator_mul,
166+
LEFTARG = veci8,
167+
RIGHTARG = veci8,
168+
COMMUTATOR = *
169+
);
170+
143171
CREATE OPERATOR & (
144172
PROCEDURE = _vectors_bvecf32_operator_and,
145173
LEFTARG = bvector,

tests/sqllogictest/fp16.slt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,10 @@ SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5]'::vecf16 l
3535
----
3636
10
3737

38+
query I
39+
SELECT '[1,2,3]'::vecf16 * '[4,5,6]'::vecf16;
40+
----
41+
[4, 10, 18]
42+
3843
statement ok
3944
DROP TABLE t;

tests/sqllogictest/int8.slt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ SELECT to_veci8(5, 1, 0, '{0,1,2,0,0}');
4343
----
4444
[0, 1, 2, 0, 0]
4545

46+
query I
47+
SELECT '[2,2,2]'::veci8 * '[2,2,2]'::veci8;
48+
----
49+
[4, 4, 4]
50+
4651
statement error Lengths of values and len are not matched.
4752
SELECT to_veci8(5, 1, 0, '{0,1,2,0,0,0}');
4853

tests/sqllogictest/sparse.slt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ SELECT to_svector(5, '{1,2}', '{1,2}');
4242
----
4343
[0, 1, 2, 0, 0]
4444

45+
query I
46+
SELECT to_svector(5, '{1,2}', '{1,1}') * to_svector(5, '{1,3}', '{2,2}');
47+
----
48+
[0, 2, 0, 0, 0]
49+
4550
statement error Lengths of index and value are not matched.
4651
SELECT to_svector(5, '{1,2,3}', '{1,2}');
4752

tests/sqllogictest/vector.slt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ SELECT vector_dims(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;
4141
2
4242
1
4343

44+
query I
45+
SELECT '[1,2,3]'::vector * '[4,5,6]'::vector;
46+
----
47+
[4, 10, 18]
48+
4449
query ?
4550
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
4651
----

0 commit comments

Comments
 (0)