14
14
15
15
use std:: collections:: BTreeMap ;
16
16
use std:: collections:: HashMap ;
17
+ use std:: collections:: HashSet ;
17
18
use std:: string:: String ;
18
19
use std:: sync:: Arc ;
19
20
@@ -32,6 +33,7 @@ use databend_common_expression::types::StringType;
32
33
use databend_common_expression:: types:: ValueType ;
33
34
use databend_common_expression:: with_integer_mapped_type;
34
35
use databend_common_expression:: BlockEntry ;
36
+ use databend_common_expression:: Column ;
35
37
use databend_common_expression:: ColumnBuilder ;
36
38
use databend_common_expression:: DataBlock ;
37
39
use databend_common_expression:: Scalar ;
@@ -53,6 +55,92 @@ use crate::sql::plans::DictGetFunctionArgument;
53
55
use crate :: sql:: plans:: DictionarySource ;
54
56
use crate :: sql:: IndexType ;
55
57
58
+ macro_rules! sqlx_fetch_optional {
59
+ ( $pool: expr, $sql: expr, $key_type: ty, $val_type: ty, $format_val_fn: expr) => { {
60
+ let res: Option <( $key_type, $val_type) > =
61
+ sqlx:: query_as( & $sql) . fetch_optional( $pool) . await ?;
62
+ Ok ( res. map( |( _, v) | $format_val_fn( v) ) )
63
+ } } ;
64
+ }
65
+
66
+ macro_rules! fetch_single_row_by_sqlx {
67
+ ( $pool: expr, $sql: expr, $key_scalar: expr, $val_type: ty, $format_val_fn: expr) => { {
68
+ match $key_scalar {
69
+ DataType :: Boolean => {
70
+ sqlx_fetch_optional!( $pool, $sql, bool , $val_type, $format_val_fn)
71
+ }
72
+ DataType :: String => {
73
+ sqlx_fetch_optional!( $pool, $sql, String , $val_type, $format_val_fn)
74
+ }
75
+ DataType :: Number ( num_ty) => with_integer_mapped_type!( |KEY_NUM_TYPE | match num_ty {
76
+ NumberDataType :: KEY_NUM_TYPE => {
77
+ sqlx_fetch_optional!( $pool, $sql, KEY_NUM_TYPE , $val_type, $format_val_fn)
78
+ }
79
+ NumberDataType :: Float32 => {
80
+ sqlx_fetch_optional!( $pool, $sql, f32 , $val_type, $format_val_fn)
81
+ }
82
+ NumberDataType :: Float64 => {
83
+ sqlx_fetch_optional!( $pool, $sql, f64 , $val_type, $format_val_fn)
84
+ }
85
+ } ) ,
86
+ _ => Err ( ErrorCode :: DictionarySourceError ( format!(
87
+ "MySQL dictionary operator currently does not support value type {}" ,
88
+ $key_scalar,
89
+ ) ) ) ,
90
+ }
91
+ } } ;
92
+ }
93
+
94
+ macro_rules! fetch_all_rows_by_sqlx {
95
+ ( $pool: expr, $sql: expr, $key_scalar: expr, $val_type: ty, $format_key_fn: expr) => {
96
+ match $key_scalar {
97
+ DataType :: Boolean => {
98
+ let res: Vec <( bool , $val_type) > = sqlx:: query_as( $sql) . fetch_all( $pool) . await ?;
99
+ res. into_iter( )
100
+ . map( |( k, v) | ( $format_key_fn( ScalarRef :: Boolean ( k) ) , v) )
101
+ . collect( )
102
+ }
103
+ DataType :: String => {
104
+ let res: Vec <( String , $val_type) > = sqlx:: query_as( $sql) . fetch_all( $pool) . await ?;
105
+ res. into_iter( )
106
+ . map( |( k, v) | ( $format_key_fn( ScalarRef :: String ( & k) ) , v) )
107
+ . collect( )
108
+ }
109
+ DataType :: Number ( num_ty) => {
110
+ with_integer_mapped_type!( |NUM_TYPE | match num_ty {
111
+ NumberDataType :: NUM_TYPE => {
112
+ let res: Vec <( NUM_TYPE , $val_type) > =
113
+ sqlx:: query_as( $sql) . fetch_all( $pool) . await ?;
114
+ res. into_iter( )
115
+ . map( |( k, v) | ( format!( "{}" , k) , v) )
116
+ . collect( )
117
+ }
118
+ NumberDataType :: Float32 => {
119
+ let res: Vec <( f32 , $val_type) > =
120
+ sqlx:: query_as( $sql) . fetch_all( $pool) . await ?;
121
+ res. into_iter( )
122
+ . map( |( k, v) | ( format!( "{}" , k) , v) )
123
+ . collect( )
124
+ }
125
+ NumberDataType :: Float64 => {
126
+ let res: Vec <( f64 , $val_type) > =
127
+ sqlx:: query_as( $sql) . fetch_all( $pool) . await ?;
128
+ res. into_iter( )
129
+ . map( |( k, v) | ( format!( "{}" , k) , v) )
130
+ . collect( )
131
+ }
132
+ } )
133
+ }
134
+ _ => {
135
+ return Err ( ErrorCode :: DictionarySourceError ( format!(
136
+ "MySQL dictionary operator currently does not support value type: {}" ,
137
+ $key_scalar
138
+ ) ) ) ;
139
+ }
140
+ }
141
+ } ;
142
+ }
143
+
56
144
pub ( crate ) enum DictionaryOperator {
57
145
Redis ( ConnectionManager ) ,
58
146
Mysql ( ( MySqlPool , String ) ) ,
@@ -95,21 +183,14 @@ impl DictionaryOperator {
95
183
DictionaryOperator :: Mysql ( ( pool, sql) ) => match value {
96
184
Value :: Scalar ( scalar) => {
97
185
let value = self
98
- . get_data_from_mysql ( scalar. as_ref ( ) , data_type, pool, sql)
186
+ . get_scalar_value_from_mysql ( scalar. as_ref ( ) , data_type, pool, sql)
99
187
. await ?
100
188
. unwrap_or ( default_value. clone ( ) ) ;
101
189
Ok ( Value :: Scalar ( value) )
102
190
}
103
191
Value :: Column ( column) => {
104
- let mut builder = ColumnBuilder :: with_capacity ( data_type, column. len ( ) ) ;
105
- for scalar_ref in column. iter ( ) {
106
- let value = self
107
- . get_data_from_mysql ( scalar_ref, data_type, pool, sql)
108
- . await ?
109
- . unwrap_or ( default_value. clone ( ) ) ;
110
- builder. push ( value. as_ref ( ) ) ;
111
- }
112
- Ok ( Value :: Column ( builder. build ( ) ) )
192
+ self . get_column_values_from_mysql ( column, data_type, default_value, pool, sql)
193
+ . await
113
194
}
114
195
} ,
115
196
}
@@ -239,72 +320,174 @@ impl DictionaryOperator {
239
320
}
240
321
}
241
322
242
- async fn get_data_from_mysql (
323
+ async fn get_scalar_value_from_mysql (
243
324
& self ,
244
325
key : ScalarRef < ' _ > ,
245
- data_type : & DataType ,
326
+ value_type : & DataType ,
246
327
pool : & MySqlPool ,
247
328
sql : & String ,
248
329
) -> Result < Option < Scalar > > {
249
330
if key == ScalarRef :: Null {
250
331
return Ok ( None ) ;
251
332
}
252
- match data_type. remove_nullable ( ) {
333
+ let new_sql = format ! ( "{} ({}) LIMIT 1" , sql, self . format_key( key. clone( ) ) ) ;
334
+ let key_type = key. infer_data_type ( ) . remove_nullable ( ) ;
335
+ match value_type. remove_nullable ( ) {
253
336
DataType :: Boolean => {
254
- let value: Option < bool > = sqlx:: query_scalar ( sql)
255
- . bind ( self . format_key ( key) )
256
- . fetch_optional ( pool)
257
- . await ?;
258
- Ok ( value. map ( Scalar :: Boolean ) )
337
+ fetch_single_row_by_sqlx ! ( pool, new_sql, key_type, bool , Scalar :: Boolean )
259
338
}
260
339
DataType :: String => {
261
- let value: Option < String > = sqlx:: query_scalar ( sql)
262
- . bind ( self . format_key ( key) )
263
- . fetch_optional ( pool)
264
- . await ?;
265
- Ok ( value. map ( Scalar :: String ) )
340
+ fetch_single_row_by_sqlx ! ( pool, new_sql, key_type, String , Scalar :: String )
266
341
}
267
342
DataType :: Number ( num_ty) => {
268
343
with_integer_mapped_type ! ( |NUM_TYPE | match num_ty {
269
344
NumberDataType :: NUM_TYPE => {
270
- let value: Option <NUM_TYPE > = sqlx:: query_scalar( & sql)
271
- . bind( self . format_key( key) )
272
- . fetch_optional( pool)
273
- . await ?;
274
- Ok ( value. map( |v| Scalar :: Number ( NUM_TYPE :: upcast_scalar( v) ) ) )
345
+ fetch_single_row_by_sqlx!( pool, new_sql, key_type, NUM_TYPE , |v| {
346
+ Scalar :: Number ( NUM_TYPE :: upcast_scalar( v) )
347
+ } )
275
348
}
276
349
NumberDataType :: Float32 => {
277
- let value: Option <f32 > = sqlx:: query_scalar( sql)
278
- . bind( self . format_key( key) )
279
- . fetch_optional( pool)
280
- . await ?;
281
- Ok ( value. map( |v| Scalar :: Number ( NumberScalar :: Float32 ( v. into( ) ) ) ) )
350
+ fetch_single_row_by_sqlx!( pool, new_sql, key_type, f32 , |v: f32 | {
351
+ Scalar :: Number ( NumberScalar :: Float32 ( v. into( ) ) )
352
+ } )
282
353
}
283
354
NumberDataType :: Float64 => {
284
- let value: Option <f64 > = sqlx:: query_scalar( sql)
285
- . bind( self . format_key( key) )
286
- . fetch_optional( pool)
287
- . await ?;
288
- Ok ( value. map( |v| Scalar :: Number ( NumberScalar :: Float64 ( v. into( ) ) ) ) )
355
+ fetch_single_row_by_sqlx!( pool, new_sql, key_type, f64 , |v: f64 | {
356
+ Scalar :: Number ( NumberScalar :: Float64 ( v. into( ) ) )
357
+ } )
289
358
}
290
359
} )
291
360
}
292
361
_ => Err ( ErrorCode :: DictionarySourceError ( format ! (
293
- "MySQL dictionary operator currently does not support value type {data_type }"
362
+ "MySQL dictionary operator currently does not support value type {value_type }"
294
363
) ) ) ,
295
364
}
296
365
}
297
366
367
+ async fn get_column_values_from_mysql (
368
+ & self ,
369
+ column : & Column ,
370
+ value_type : & DataType ,
371
+ default_value : & Scalar ,
372
+ pool : & MySqlPool ,
373
+ sql : & String ,
374
+ ) -> Result < Value < AnyType > > {
375
+ // todo: The current method formats the key as a string, which causes some performance overhead.
376
+ // The next step is to use the key's native types directly, such as bool, i32, etc.
377
+ let key_cnt = column. len ( ) ;
378
+ let mut all_keys = Vec :: with_capacity ( key_cnt) ;
379
+ let mut key_set = HashSet :: with_capacity ( key_cnt) ;
380
+ for item in column. iter ( ) {
381
+ if item != ScalarRef :: Null {
382
+ key_set. insert ( item. clone ( ) ) ;
383
+ }
384
+ all_keys. push ( self . format_key ( item) ) ;
385
+ }
386
+
387
+ let mut builder = ColumnBuilder :: with_capacity ( value_type, key_cnt) ;
388
+ if key_set. is_empty ( ) {
389
+ for _ in 0 ..key_cnt {
390
+ builder. push ( default_value. as_ref ( ) ) ;
391
+ }
392
+ return Ok ( Value :: Column ( builder. build ( ) ) ) ;
393
+ }
394
+ let new_sql = format ! ( "{} ({})" , sql, self . format_keys( key_set) ) ;
395
+ let key_type = column. data_type ( ) . remove_nullable ( ) ;
396
+ match value_type. remove_nullable ( ) {
397
+ DataType :: Boolean => {
398
+ let kv_pairs: HashMap < String , bool > =
399
+ fetch_all_rows_by_sqlx ! ( pool, & new_sql, key_type, bool , |k| self . format_key( k) ) ;
400
+ for key in all_keys {
401
+ match kv_pairs. get ( & key) {
402
+ Some ( v) => builder. push ( Scalar :: Boolean ( * v) . as_ref ( ) ) ,
403
+ None => builder. push ( default_value. as_ref ( ) ) ,
404
+ }
405
+ }
406
+ }
407
+ DataType :: String => {
408
+ let kv_pairs: HashMap < String , String > =
409
+ fetch_all_rows_by_sqlx ! ( pool, & new_sql, key_type, String , |k| self
410
+ . format_key( k) ) ;
411
+ for key in all_keys {
412
+ match kv_pairs. get ( & key) {
413
+ Some ( v) => builder. push ( Scalar :: String ( v. to_string ( ) ) . as_ref ( ) ) ,
414
+ None => builder. push ( default_value. as_ref ( ) ) ,
415
+ }
416
+ }
417
+ }
418
+ DataType :: Number ( num_ty) => {
419
+ with_integer_mapped_type ! ( |NUM_TYPE | match num_ty {
420
+ NumberDataType :: NUM_TYPE => {
421
+ let kv_pairs: HashMap <String , NUM_TYPE > =
422
+ fetch_all_rows_by_sqlx!( pool, & new_sql, key_type, NUM_TYPE , |k| self
423
+ . format_key( k) ) ;
424
+ for key in all_keys {
425
+ match kv_pairs. get( & key) {
426
+ Some ( v) => builder
427
+ . push( Scalar :: Number ( NUM_TYPE :: upcast_scalar( * v) ) . as_ref( ) ) ,
428
+ None => builder. push( default_value. as_ref( ) ) ,
429
+ }
430
+ }
431
+ }
432
+ NumberDataType :: Float32 => {
433
+ let kv_pairs: HashMap <String , f32 > =
434
+ fetch_all_rows_by_sqlx!( pool, & new_sql, key_type, f32 , |k| self
435
+ . format_key( k) ) ;
436
+ for key in all_keys {
437
+ match kv_pairs. get( & key) {
438
+ Some ( v) => builder. push(
439
+ Scalar :: Number ( NumberScalar :: Float32 ( ( * v) . into( ) ) ) . as_ref( ) ,
440
+ ) ,
441
+ None => builder. push( default_value. as_ref( ) ) ,
442
+ }
443
+ }
444
+ }
445
+ NumberDataType :: Float64 => {
446
+ let kv_pairs: HashMap <String , f64 > =
447
+ fetch_all_rows_by_sqlx!( pool, & new_sql, key_type, f64 , |k| self
448
+ . format_key( k) ) ;
449
+ for key in all_keys {
450
+ match kv_pairs. get( & key) {
451
+ Some ( v) => builder. push(
452
+ Scalar :: Number ( NumberScalar :: Float64 ( ( * v) . into( ) ) ) . as_ref( ) ,
453
+ ) ,
454
+ None => builder. push( default_value. as_ref( ) ) ,
455
+ }
456
+ }
457
+ }
458
+ } )
459
+ }
460
+ _ => {
461
+ return Err ( ErrorCode :: DictionarySourceError ( format ! (
462
+ "MySQL dictionary operator currently does not support value type {value_type}"
463
+ ) ) ) ;
464
+ }
465
+ }
466
+ Ok ( Value :: Column ( builder. build ( ) ) )
467
+ }
468
+
469
+ #[ inline]
298
470
fn format_key ( & self , key : ScalarRef < ' _ > ) -> String {
299
471
match key {
300
- ScalarRef :: String ( s) => s . to_string ( ) ,
472
+ ScalarRef :: String ( s) => format ! ( "'{}'" , s . replace ( "'" , " \\ '" ) ) ,
301
473
ScalarRef :: Date ( d) => format ! ( "{}" , date_to_string( d as i64 , & TimeZone :: UTC ) ) ,
302
474
ScalarRef :: Timestamp ( t) => {
303
475
format ! ( "{}" , timestamp_to_string( t, & TimeZone :: UTC ) )
304
476
}
305
477
_ => format ! ( "{}" , key) ,
306
478
}
307
479
}
480
+
481
+ #[ inline]
482
+ fn format_keys ( & self , keys : HashSet < ScalarRef > ) -> String {
483
+ format ! (
484
+ "{}" ,
485
+ keys. into_iter( )
486
+ . map( |key| self . format_key( key) )
487
+ . collect:: <Vec <String >>( )
488
+ . join( "," )
489
+ )
490
+ }
308
491
}
309
492
310
493
impl TransformAsyncFunction {
@@ -339,8 +522,11 @@ impl TransformAsyncFunction {
339
522
sqlx:: MySqlPool :: connect ( & sql_source. connection_url ) ,
340
523
) ?;
341
524
let sql = format ! (
342
- "SELECT {} FROM {} WHERE {} = ? LIMIT 1" ,
343
- & sql_source. value_field, & sql_source. table, & sql_source. key_field
525
+ "SELECT {}, {} FROM {} WHERE {} in" ,
526
+ & sql_source. key_field,
527
+ & sql_source. value_field,
528
+ & sql_source. table,
529
+ & sql_source. key_field
344
530
) ;
345
531
operators. insert ( i, Arc :: new ( DictionaryOperator :: Mysql ( ( mysql_pool, sql) ) ) ) ;
346
532
}
0 commit comments