12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
+ use std:: sync:: Arc ;
16
+
15
17
use databend_common_expression:: types:: ArrayType ;
16
18
use databend_common_expression:: types:: Buffer ;
19
+ use databend_common_expression:: types:: DataType ;
17
20
use databend_common_expression:: types:: Float32Type ;
18
21
use databend_common_expression:: types:: Float64Type ;
22
+ use databend_common_expression:: types:: NumberColumn ;
23
+ use databend_common_expression:: types:: NumberDataType ;
24
+ use databend_common_expression:: types:: NumberScalar ;
19
25
use databend_common_expression:: types:: StringType ;
26
+ use databend_common_expression:: types:: VectorDataType ;
27
+ use databend_common_expression:: types:: VectorScalarRef ;
20
28
use databend_common_expression:: types:: F32 ;
21
29
use databend_common_expression:: types:: F64 ;
22
30
use databend_common_expression:: vectorize_with_builder_1_arg;
23
31
use databend_common_expression:: vectorize_with_builder_2_arg;
32
+ use databend_common_expression:: Column ;
33
+ use databend_common_expression:: Function ;
24
34
use databend_common_expression:: FunctionDomain ;
35
+ use databend_common_expression:: FunctionEval ;
36
+ use databend_common_expression:: FunctionFactory ;
25
37
use databend_common_expression:: FunctionRegistry ;
38
+ use databend_common_expression:: FunctionSignature ;
39
+ use databend_common_expression:: Scalar ;
40
+ use databend_common_expression:: ScalarRef ;
41
+ use databend_common_expression:: Value ;
26
42
use databend_common_openai:: OpenAI ;
27
43
use databend_common_vector:: cosine_distance;
28
44
use databend_common_vector:: cosine_distance_64;
@@ -37,7 +53,7 @@ pub fn register(registry: &mut FunctionRegistry) {
37
53
|_, _, _| FunctionDomain :: MayThrow ,
38
54
vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
39
55
|lhs, rhs, output, ctx| {
40
- let l=
56
+ let l =
41
57
unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
42
58
let r =
43
59
unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
@@ -63,7 +79,7 @@ pub fn register(registry: &mut FunctionRegistry) {
63
79
|_, _, _| FunctionDomain :: MayThrow ,
64
80
vectorize_with_builder_2_arg :: < ArrayType < Float32Type > , ArrayType < Float32Type > , Float32Type > (
65
81
|lhs, rhs, output, ctx| {
66
- let l=
82
+ let l =
67
83
unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( lhs) } ;
68
84
let r =
69
85
unsafe { std:: mem:: transmute :: < Buffer < F32 > , Buffer < f32 > > ( rhs) } ;
@@ -226,4 +242,161 @@ pub fn register(registry: &mut FunctionRegistry) {
226
242
}
227
243
} ) ,
228
244
) ;
245
+
246
+ let cosine_distance_factory =
247
+ FunctionFactory :: Closure ( Box :: new ( |_, args_type : & [ DataType ] | {
248
+ if args_type. len ( ) != 2 {
249
+ return None ;
250
+ }
251
+ let args_type0 = args_type[ 0 ] . remove_nullable ( ) ;
252
+ let vector_type0 = args_type0. as_vector ( ) ?;
253
+ let args_type1 = args_type[ 1 ] . remove_nullable ( ) ;
254
+ let vector_type1 = args_type1. as_vector ( ) ?;
255
+ match ( vector_type0, vector_type1) {
256
+ ( VectorDataType :: Int8 ( dim0) , VectorDataType :: Int8 ( dim1) ) => {
257
+ if dim0 != dim1 {
258
+ return None ;
259
+ }
260
+ }
261
+ ( VectorDataType :: Float32 ( dim0) , VectorDataType :: Float32 ( dim1) ) => {
262
+ if dim0 != dim1 {
263
+ return None ;
264
+ }
265
+ }
266
+ ( _, _) => {
267
+ return None ;
268
+ }
269
+ }
270
+ let args_type = args_type. to_vec ( ) ;
271
+ Some ( Arc :: new ( Function {
272
+ signature : FunctionSignature {
273
+ name : "cosine_distance" . to_string ( ) ,
274
+ args_type : args_type. clone ( ) ,
275
+ return_type : DataType :: Number ( NumberDataType :: Float32 ) ,
276
+ } ,
277
+ eval : FunctionEval :: Scalar {
278
+ calc_domain : Box :: new ( |_, _| FunctionDomain :: Full ) ,
279
+ eval : Box :: new ( move |args, _| {
280
+ let len_opt = args. iter ( ) . find_map ( |arg| match arg {
281
+ Value :: Column ( col) => Some ( col. len ( ) ) ,
282
+ _ => None ,
283
+ } ) ;
284
+ let len = len_opt. unwrap_or ( 1 ) ;
285
+ let mut builder = Vec :: with_capacity ( len) ;
286
+ for i in 0 ..len {
287
+ let lhs = unsafe { args[ 0 ] . index_unchecked ( i) } ;
288
+ let rhs = unsafe { args[ 1 ] . index_unchecked ( i) } ;
289
+ match ( lhs, rhs) {
290
+ (
291
+ ScalarRef :: Vector ( VectorScalarRef :: Int8 ( lhs) ) ,
292
+ ScalarRef :: Vector ( VectorScalarRef :: Int8 ( rhs) ) ,
293
+ ) => {
294
+ let l: Vec < _ > = lhs. iter ( ) . map ( |v| * v as f32 ) . collect ( ) ;
295
+ let r: Vec < _ > = rhs. iter ( ) . map ( |v| * v as f32 ) . collect ( ) ;
296
+ let dist = cosine_distance ( l. as_slice ( ) , r. as_slice ( ) ) . unwrap ( ) ;
297
+ builder. push ( F32 :: from ( dist) ) ;
298
+ }
299
+ (
300
+ ScalarRef :: Vector ( VectorScalarRef :: Float32 ( lhs) ) ,
301
+ ScalarRef :: Vector ( VectorScalarRef :: Float32 ( rhs) ) ,
302
+ ) => {
303
+ let l = unsafe { std:: mem:: transmute :: < & [ F32 ] , & [ f32 ] > ( lhs) } ;
304
+ let r = unsafe { std:: mem:: transmute :: < & [ F32 ] , & [ f32 ] > ( rhs) } ;
305
+ let dist = cosine_distance ( l, r) . unwrap ( ) ;
306
+ builder. push ( F32 :: from ( dist) ) ;
307
+ }
308
+ ( _, _) => {
309
+ builder. push ( F32 :: from ( 0.0 ) ) ;
310
+ }
311
+ }
312
+ }
313
+ if len_opt. is_some ( ) {
314
+ Value :: Column ( Column :: Number ( NumberColumn :: Float32 ( Buffer :: from (
315
+ builder,
316
+ ) ) ) )
317
+ } else {
318
+ Value :: Scalar ( Scalar :: Number ( NumberScalar :: Float32 ( builder[ 0 ] ) ) )
319
+ }
320
+ } ) ,
321
+ } ,
322
+ } ) )
323
+ } ) ) ;
324
+ registry. register_function_factory ( "cosine_distance" , cosine_distance_factory) ;
325
+
326
+ let l2_distance_factory = FunctionFactory :: Closure ( Box :: new ( |_, args_type : & [ DataType ] | {
327
+ if args_type. len ( ) != 2 {
328
+ return None ;
329
+ }
330
+ let args_type0 = args_type[ 0 ] . remove_nullable ( ) ;
331
+ let vector_type0 = args_type0. as_vector ( ) ?;
332
+ let args_type1 = args_type[ 1 ] . remove_nullable ( ) ;
333
+ let vector_type1 = args_type1. as_vector ( ) ?;
334
+ match ( vector_type0, vector_type1) {
335
+ ( VectorDataType :: Int8 ( dim0) , VectorDataType :: Int8 ( dim1) ) => {
336
+ if dim0 != dim1 {
337
+ return None ;
338
+ }
339
+ }
340
+ ( VectorDataType :: Float32 ( dim0) , VectorDataType :: Float32 ( dim1) ) => {
341
+ if dim0 != dim1 {
342
+ return None ;
343
+ }
344
+ }
345
+ ( _, _) => {
346
+ return None ;
347
+ }
348
+ }
349
+ let args_type = args_type. to_vec ( ) ;
350
+ Some ( Arc :: new ( Function {
351
+ signature : FunctionSignature {
352
+ name : "l2_distance" . to_string ( ) ,
353
+ args_type : args_type. clone ( ) ,
354
+ return_type : DataType :: Number ( NumberDataType :: Float32 ) ,
355
+ } ,
356
+ eval : FunctionEval :: Scalar {
357
+ calc_domain : Box :: new ( |_, _| FunctionDomain :: Full ) ,
358
+ eval : Box :: new ( move |args, _| {
359
+ let len_opt = args. iter ( ) . find_map ( |arg| match arg {
360
+ Value :: Column ( col) => Some ( col. len ( ) ) ,
361
+ _ => None ,
362
+ } ) ;
363
+ let len = len_opt. unwrap_or ( 1 ) ;
364
+ let mut builder = Vec :: with_capacity ( len) ;
365
+ for i in 0 ..len {
366
+ let lhs = unsafe { args[ 0 ] . index_unchecked ( i) } ;
367
+ let rhs = unsafe { args[ 1 ] . index_unchecked ( i) } ;
368
+ match ( lhs, rhs) {
369
+ (
370
+ ScalarRef :: Vector ( VectorScalarRef :: Int8 ( lhs) ) ,
371
+ ScalarRef :: Vector ( VectorScalarRef :: Int8 ( rhs) ) ,
372
+ ) => {
373
+ let l: Vec < _ > = lhs. iter ( ) . map ( |v| * v as f32 ) . collect ( ) ;
374
+ let r: Vec < _ > = rhs. iter ( ) . map ( |v| * v as f32 ) . collect ( ) ;
375
+ let dist = l2_distance ( l. as_slice ( ) , r. as_slice ( ) ) . unwrap ( ) ;
376
+ builder. push ( F32 :: from ( dist) ) ;
377
+ }
378
+ (
379
+ ScalarRef :: Vector ( VectorScalarRef :: Float32 ( lhs) ) ,
380
+ ScalarRef :: Vector ( VectorScalarRef :: Float32 ( rhs) ) ,
381
+ ) => {
382
+ let l = unsafe { std:: mem:: transmute :: < & [ F32 ] , & [ f32 ] > ( lhs) } ;
383
+ let r = unsafe { std:: mem:: transmute :: < & [ F32 ] , & [ f32 ] > ( rhs) } ;
384
+ let dist = l2_distance ( l, r) . unwrap ( ) ;
385
+ builder. push ( F32 :: from ( dist) ) ;
386
+ }
387
+ ( _, _) => {
388
+ builder. push ( F32 :: from ( 0.0 ) ) ;
389
+ }
390
+ }
391
+ }
392
+ if len_opt. is_some ( ) {
393
+ Value :: Column ( Column :: Number ( NumberColumn :: Float32 ( Buffer :: from ( builder) ) ) )
394
+ } else {
395
+ Value :: Scalar ( Scalar :: Number ( NumberScalar :: Float32 ( builder[ 0 ] ) ) )
396
+ }
397
+ } ) ,
398
+ } ,
399
+ } ) )
400
+ } ) ) ;
401
+ registry. register_function_factory ( "l2_distance" , l2_distance_factory) ;
229
402
}
0 commit comments