@@ -19,7 +19,7 @@ use databend_common_expression::types::NumberScalar;
 use databend_common_expression::BlockEntry;
 use databend_common_expression::DataBlock;
 use databend_common_expression::Scalar;
-use databend_common_pipeline_transforms::processors::Transform;
+use databend_common_pipeline_transforms::AccumulatingTransform;
 
 pub struct TransformExpandGroupingSets {
     group_bys: Vec<usize>,
@@ -35,10 +35,10 @@ impl TransformExpandGroupingSets {
     }
 }
 
-impl Transform for TransformExpandGroupingSets {
+impl AccumulatingTransform for TransformExpandGroupingSets {
     const NAME: &'static str = "TransformExpandGroupingSets";
 
-    fn transform(&mut self, data: DataBlock) -> Result<DataBlock> {
+    fn transform(&mut self, data: DataBlock) -> Result<Vec<DataBlock>> {
         let num_rows = data.num_rows();
         let num_group_bys = self.group_bys.len();
         let mut output_blocks = Vec::with_capacity(self.grouping_ids.len());
@@ -48,45 +48,53 @@ impl Transform for TransformExpandGroupingSets {
             .map(|i| data.get_by_offset(*i).clone())
             .collect::<Vec<_>>();
 
+        let mut entries = data
+            .columns()
+            .iter()
+            .cloned()
+            .chain(dup_group_by_cols.clone())
+            .collect::<Vec<_>>();
+
+        // All group-by columns should be nullable.
+        for i in 0..num_group_bys {
+            let entry = unsafe {
+                let offset = self.group_bys.get_unchecked(i);
+                entries.get_unchecked_mut(*offset)
+            };
+            match entry {
+                BlockEntry::Const(_, data_type, _) => {
+                    *data_type = data_type.wrap_nullable();
+                }
+                BlockEntry::Column(column) => *column = column.clone().wrap_nullable(None),
+            };
+        }
+
         for &id in &self.grouping_ids {
             // Repeat data for each grouping set.
             let grouping_id_column = BlockEntry::new_const_column(
                 DataType::Number(NumberDataType::UInt32),
                 Scalar::Number(NumberScalar::UInt32(id as u32)),
                 num_rows,
             );
-            let mut entries = data
-                .columns()
-                .iter()
-                .cloned()
-                .chain(dup_group_by_cols.clone())
-                .chain(Some(grouping_id_column))
-                .collect::<Vec<_>>();
+            // Cloning the entries only clones the column buffers, not the row data,
+            // so this per-grouping-set copy is memory efficient.
+            let mut current_group_entries = entries.clone();
+            current_group_entries.push(grouping_id_column);
+
             let bits = !id;
             for i in 0..num_group_bys {
                 let entry = unsafe {
                     let offset = self.group_bys.get_unchecked(i);
-                    entries.get_unchecked_mut(*offset)
+                    current_group_entries.get_unchecked_mut(*offset)
                 };
+                // Group-by columns not in this grouping set are set to all NULLs.
                 if bits & (1 << i) == 0 {
-                    // This column should be set to NULLs.
-                    *entry = BlockEntry::new_const_column(
-                        entry.data_type().wrap_nullable(),
-                        Scalar::Null,
-                        num_rows,
-                    )
-                } else {
-                    match entry {
-                        BlockEntry::Const(_, data_type, _) => {
-                            *data_type = data_type.wrap_nullable();
-                        }
-                        BlockEntry::Column(column) => *column = column.clone().wrap_nullable(None),
-                    };
+                    *entry = BlockEntry::new_const_column(entry.data_type(), Scalar::Null, num_rows)
                 }
             }
-            output_blocks.push(DataBlock::new(entries, num_rows));
+            output_blocks.push(DataBlock::new(current_group_entries, num_rows));
         }
 
-        DataBlock::concat(&output_blocks)
+        Ok(output_blocks)
     }
 }
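To make the `bits = !id` check above concrete: a set bit in a grouping id marks a group-by column that is absent from that grouping set, so the transform overwrites that column with an all-NULL constant. Below is a minimal standalone sketch of just this bit logic; `nulled_columns`, the example ids, and the `(a, b)` column layout are illustrative only and not databend APIs or databend's actual id encoding.

```rust
// Plain-Rust sketch of the bit test used in TransformExpandGroupingSets:
// bit `i` of a grouping id set  =>  group-by column `i` is filled with NULLs.
fn nulled_columns(id: usize, num_group_bys: usize) -> Vec<bool> {
    let bits = !id;
    (0..num_group_bys)
        .map(|i| bits & (1 << i) == 0) // true => column i becomes an all-NULL constant
        .collect()
}

fn main() {
    // Hypothetical encoding of GROUPING SETS ((a, b), (a), ()) over columns [a, b].
    for id in [0b00usize, 0b10, 0b11] {
        println!("id={:02b} -> nulled={:?}", id, nulled_columns(id, 2));
    }
}
```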
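The companion change is the trait switch: the old `Transform::transform` returned a single `DataBlock`, while `AccumulatingTransform::transform` returns `Result<Vec<DataBlock>>`, which is what lets the expansion emit one block per grouping set instead of merging them with `DataBlock::concat`. The sketch below is only an assumed, reduced shape of that contract based on the two members visible in this diff (`NAME` and `transform`); `AccumulatingTransformLike` and `ExpandSets` are made-up names, not the real databend trait or processor.

```rust
// Assumed, simplified contract: a named transform that may emit
// zero or more output blocks for every input block it receives.
trait AccumulatingTransformLike {
    type Block;
    type Error;
    const NAME: &'static str;

    fn transform(&mut self, data: Self::Block) -> Result<Vec<Self::Block>, Self::Error>;
}

// Toy stand-in for TransformExpandGroupingSets: every input "block"
// is duplicated once per configured grouping set.
struct ExpandSets {
    num_sets: usize,
}

impl AccumulatingTransformLike for ExpandSets {
    type Block = Vec<i32>;
    type Error = String;
    const NAME: &'static str = "ExpandSets";

    fn transform(&mut self, data: Vec<i32>) -> Result<Vec<Vec<i32>>, String> {
        // One copy of the input per grouping set, mirroring the loop over grouping_ids.
        Ok((0..self.num_sets).map(|_| data.clone()).collect())
    }
}

fn main() {
    let mut t = ExpandSets { num_sets: 3 };
    let blocks = t.transform(vec![1, 2, 3]).unwrap();
    println!("{} emitted {} blocks", ExpandSets::NAME, blocks.len());
}
```

Handing the pipeline the per-grouping-set blocks individually avoids the final `DataBlock::concat(&output_blocks)` that the old implementation needed to squeeze everything into one return value.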