Skip to content

Commit 03e4d4c

Browse files
committed
feat(query): Implement Vector Index with HNSW Algorithm
1 parent 28aa33b commit 03e4d4c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+3809
-3
lines changed

Cargo.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ bstr = "1"
269269
buf-list = "1.0.3"
270270
bumpalo = "3.12.0"
271271
byte-unit = "5.1.6"
272-
bytemuck = { version = "1", features = ["derive"] }
272+
#bytemuck = { version = "1", features = ["derive"] }
273+
bytemuck = { version = "1", features = ["derive", "extern_crate_alloc", "must_cast", "transparentwrapper_extra"] }
273274
byteorder = "1.4.3"
274275
bytes = "1.5.0"
275276
bytesize = "1.1.0"

src/query/functions/src/scalars/vector.rs

Lines changed: 175 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,33 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::sync::Arc;
16+
1517
use databend_common_expression::types::ArrayType;
1618
use databend_common_expression::types::Buffer;
19+
use databend_common_expression::types::DataType;
1720
use databend_common_expression::types::Float32Type;
1821
use databend_common_expression::types::Float64Type;
22+
use databend_common_expression::types::NumberColumn;
23+
use databend_common_expression::types::NumberDataType;
24+
use databend_common_expression::types::NumberScalar;
1925
use databend_common_expression::types::StringType;
26+
use databend_common_expression::types::VectorDataType;
27+
use databend_common_expression::types::VectorScalarRef;
2028
use databend_common_expression::types::F32;
2129
use databend_common_expression::types::F64;
2230
use databend_common_expression::vectorize_with_builder_1_arg;
2331
use databend_common_expression::vectorize_with_builder_2_arg;
32+
use databend_common_expression::Column;
33+
use databend_common_expression::Function;
2434
use databend_common_expression::FunctionDomain;
35+
use databend_common_expression::FunctionEval;
36+
use databend_common_expression::FunctionFactory;
2537
use databend_common_expression::FunctionRegistry;
38+
use databend_common_expression::FunctionSignature;
39+
use databend_common_expression::Scalar;
40+
use databend_common_expression::ScalarRef;
41+
use databend_common_expression::Value;
2642
use databend_common_openai::OpenAI;
2743
use databend_common_vector::cosine_distance;
2844
use databend_common_vector::cosine_distance_64;
@@ -37,7 +53,7 @@ pub fn register(registry: &mut FunctionRegistry) {
3753
|_, _, _| FunctionDomain::MayThrow,
3854
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
3955
|lhs, rhs, output, ctx| {
40-
let l=
56+
let l =
4157
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
4258
let r =
4359
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };
@@ -63,7 +79,7 @@ pub fn register(registry: &mut FunctionRegistry) {
6379
|_, _, _| FunctionDomain::MayThrow,
6480
vectorize_with_builder_2_arg::<ArrayType<Float32Type>, ArrayType<Float32Type>, Float32Type>(
6581
|lhs, rhs, output, ctx| {
66-
let l=
82+
let l =
6783
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(lhs) };
6884
let r =
6985
unsafe { std::mem::transmute::<Buffer<F32>, Buffer<f32>>(rhs) };
@@ -226,4 +242,161 @@ pub fn register(registry: &mut FunctionRegistry) {
226242
}
227243
}),
228244
);
245+
246+
let cosine_distance_factory =
247+
FunctionFactory::Closure(Box::new(|_, args_type: &[DataType]| {
248+
if args_type.len() != 2 {
249+
return None;
250+
}
251+
let args_type0 = args_type[0].remove_nullable();
252+
let vector_type0 = args_type0.as_vector()?;
253+
let args_type1 = args_type[1].remove_nullable();
254+
let vector_type1 = args_type1.as_vector()?;
255+
match (vector_type0, vector_type1) {
256+
(VectorDataType::Int8(dim0), VectorDataType::Int8(dim1)) => {
257+
if dim0 != dim1 {
258+
return None;
259+
}
260+
}
261+
(VectorDataType::Float32(dim0), VectorDataType::Float32(dim1)) => {
262+
if dim0 != dim1 {
263+
return None;
264+
}
265+
}
266+
(_, _) => {
267+
return None;
268+
}
269+
}
270+
let args_type = args_type.to_vec();
271+
Some(Arc::new(Function {
272+
signature: FunctionSignature {
273+
name: "cosine_distance".to_string(),
274+
args_type: args_type.clone(),
275+
return_type: DataType::Number(NumberDataType::Float32),
276+
},
277+
eval: FunctionEval::Scalar {
278+
calc_domain: Box::new(|_, _| FunctionDomain::Full),
279+
eval: Box::new(move |args, _| {
280+
let len_opt = args.iter().find_map(|arg| match arg {
281+
Value::Column(col) => Some(col.len()),
282+
_ => None,
283+
});
284+
let len = len_opt.unwrap_or(1);
285+
let mut builder = Vec::with_capacity(len);
286+
for i in 0..len {
287+
let lhs = unsafe { args[0].index_unchecked(i) };
288+
let rhs = unsafe { args[1].index_unchecked(i) };
289+
match (lhs, rhs) {
290+
(
291+
ScalarRef::Vector(VectorScalarRef::Int8(lhs)),
292+
ScalarRef::Vector(VectorScalarRef::Int8(rhs)),
293+
) => {
294+
let l: Vec<_> = lhs.iter().map(|v| *v as f32).collect();
295+
let r: Vec<_> = rhs.iter().map(|v| *v as f32).collect();
296+
let dist = cosine_distance(l.as_slice(), r.as_slice()).unwrap();
297+
builder.push(F32::from(dist));
298+
}
299+
(
300+
ScalarRef::Vector(VectorScalarRef::Float32(lhs)),
301+
ScalarRef::Vector(VectorScalarRef::Float32(rhs)),
302+
) => {
303+
let l = unsafe { std::mem::transmute::<&[F32], &[f32]>(lhs) };
304+
let r = unsafe { std::mem::transmute::<&[F32], &[f32]>(rhs) };
305+
let dist = cosine_distance(l, r).unwrap();
306+
builder.push(F32::from(dist));
307+
}
308+
(_, _) => {
309+
builder.push(F32::from(0.0));
310+
}
311+
}
312+
}
313+
if len_opt.is_some() {
314+
Value::Column(Column::Number(NumberColumn::Float32(Buffer::from(
315+
builder,
316+
))))
317+
} else {
318+
Value::Scalar(Scalar::Number(NumberScalar::Float32(builder[0])))
319+
}
320+
}),
321+
},
322+
}))
323+
}));
324+
registry.register_function_factory("cosine_distance", cosine_distance_factory);
325+
326+
let l2_distance_factory = FunctionFactory::Closure(Box::new(|_, args_type: &[DataType]| {
327+
if args_type.len() != 2 {
328+
return None;
329+
}
330+
let args_type0 = args_type[0].remove_nullable();
331+
let vector_type0 = args_type0.as_vector()?;
332+
let args_type1 = args_type[1].remove_nullable();
333+
let vector_type1 = args_type1.as_vector()?;
334+
match (vector_type0, vector_type1) {
335+
(VectorDataType::Int8(dim0), VectorDataType::Int8(dim1)) => {
336+
if dim0 != dim1 {
337+
return None;
338+
}
339+
}
340+
(VectorDataType::Float32(dim0), VectorDataType::Float32(dim1)) => {
341+
if dim0 != dim1 {
342+
return None;
343+
}
344+
}
345+
(_, _) => {
346+
return None;
347+
}
348+
}
349+
let args_type = args_type.to_vec();
350+
Some(Arc::new(Function {
351+
signature: FunctionSignature {
352+
name: "l2_distance".to_string(),
353+
args_type: args_type.clone(),
354+
return_type: DataType::Number(NumberDataType::Float32),
355+
},
356+
eval: FunctionEval::Scalar {
357+
calc_domain: Box::new(|_, _| FunctionDomain::Full),
358+
eval: Box::new(move |args, _| {
359+
let len_opt = args.iter().find_map(|arg| match arg {
360+
Value::Column(col) => Some(col.len()),
361+
_ => None,
362+
});
363+
let len = len_opt.unwrap_or(1);
364+
let mut builder = Vec::with_capacity(len);
365+
for i in 0..len {
366+
let lhs = unsafe { args[0].index_unchecked(i) };
367+
let rhs = unsafe { args[1].index_unchecked(i) };
368+
match (lhs, rhs) {
369+
(
370+
ScalarRef::Vector(VectorScalarRef::Int8(lhs)),
371+
ScalarRef::Vector(VectorScalarRef::Int8(rhs)),
372+
) => {
373+
let l: Vec<_> = lhs.iter().map(|v| *v as f32).collect();
374+
let r: Vec<_> = rhs.iter().map(|v| *v as f32).collect();
375+
let dist = l2_distance(l.as_slice(), r.as_slice()).unwrap();
376+
builder.push(F32::from(dist));
377+
}
378+
(
379+
ScalarRef::Vector(VectorScalarRef::Float32(lhs)),
380+
ScalarRef::Vector(VectorScalarRef::Float32(rhs)),
381+
) => {
382+
let l = unsafe { std::mem::transmute::<&[F32], &[f32]>(lhs) };
383+
let r = unsafe { std::mem::transmute::<&[F32], &[f32]>(rhs) };
384+
let dist = l2_distance(l, r).unwrap();
385+
builder.push(F32::from(dist));
386+
}
387+
(_, _) => {
388+
builder.push(F32::from(0.0));
389+
}
390+
}
391+
}
392+
if len_opt.is_some() {
393+
Value::Column(Column::Number(NumberColumn::Float32(Buffer::from(builder))))
394+
} else {
395+
Value::Scalar(Scalar::Number(NumberScalar::Float32(builder[0])))
396+
}
397+
}),
398+
},
399+
}))
400+
}));
401+
registry.register_function_factory("l2_distance", l2_distance_factory);
229402
}

src/query/functions/tests/it/scalars/testdata/function_list.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,7 @@ Functions overloads:
12661266
1 cosine_distance(Array(Float32) NULL, Array(Float32) NULL) :: Float32 NULL
12671267
2 cosine_distance(Array(Float64), Array(Float64)) :: Float64
12681268
3 cosine_distance(Array(Float64) NULL, Array(Float64) NULL) :: Float64 NULL
1269+
4 cosine_distance FACTORY
12691270
0 cot(Float64) :: Float64
12701271
1 cot(Float64 NULL) :: Float64 NULL
12711272
0 crc32(String) :: UInt32
@@ -2231,6 +2232,7 @@ Functions overloads:
22312232
1 l2_distance(Array(Float32) NULL, Array(Float32) NULL) :: Float32 NULL
22322233
2 l2_distance(Array(Float64), Array(Float64)) :: Float64
22332234
3 l2_distance(Array(Float64) NULL, Array(Float64) NULL) :: Float64 NULL
2235+
4 l2_distance FACTORY
22342236
0 left(String, UInt64) :: String
22352237
1 left(String NULL, UInt64 NULL) :: String NULL
22362238
0 length(Variant NULL) :: UInt32 NULL

src/query/service/src/test_kits/block_writer.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ impl<'a> BlockWriter<'a> {
110110
None,
111111
None,
112112
None,
113+
None,
114+
None,
113115
Compression::Lz4Raw,
114116
Some(Utc::now()),
115117
);

src/query/service/tests/it/storages/fuse/bloom_index_meta_size.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,8 @@ fn build_test_segment_info(
335335
bloom_filter_index_size: 0,
336336
inverted_index_size: None,
337337
ngram_filter_index_size: None,
338+
vector_index_size: None,
339+
vector_index_location: None,
338340
virtual_block_meta: None,
339341
compression: Compression::Lz4,
340342
create_on: Some(Utc::now()),

src/query/service/tests/it/storages/fuse/operations/mutation/recluster_mutator.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ async fn test_recluster_mutator_block_select() -> Result<()> {
7878
None,
7979
None,
8080
None,
81+
None,
82+
None,
8183
meta::Compression::Lz4Raw,
8284
Some(Utc::now()),
8385
));

src/query/service/tests/it/storages/fuse/operations/mutation/segments_compact_mutator.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,8 @@ impl CompactSegmentTestFixture {
774774
None,
775775
None,
776776
None,
777+
None,
778+
None,
777779
Compression::Lz4Raw,
778780
Some(Utc::now()),
779781
);

src/query/service/tests/it/storages/fuse/operations/read_plan.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ fn test_to_partitions() -> Result<()> {
105105
None,
106106
None,
107107
None,
108+
None,
109+
None,
108110
meta::Compression::Lz4Raw,
109111
Some(Utc::now()),
110112
));

src/query/service/tests/it/storages/fuse/statistics.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,8 @@ fn test_reduce_block_meta() -> databend_common_exception::Result<()> {
629629
None,
630630
None,
631631
None,
632+
None,
633+
None,
632634
Compression::Lz4Raw,
633635
Some(Utc::now()),
634636
);

0 commit comments

Comments
 (0)