@@ -18,9 +18,11 @@ use std::sync::Arc;
18
18
19
19
use bytes:: BytesMut ;
20
20
use common_arrow:: arrow:: bitmap:: Bitmap ;
21
- use common_expression:: ColumnBuilder ;
22
- use common_expression:: Result ;
21
+ use common_exception:: Result ;
22
+ use common_expression:: types:: DataType ;
23
+ use common_expression:: util:: column_merge_validity;
23
24
use common_expression:: Column ;
25
+ use common_expression:: ColumnBuilder ;
24
26
use common_io:: prelude:: BinaryWriteBuf ;
25
27
26
28
use crate :: aggregates:: AggregateFunction ;
@@ -79,10 +81,11 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
79
81
"AggregateNullUnaryAdaptor"
80
82
}
81
83
82
- fn return_type ( & self ) -> Result < DataTypeImpl > {
84
+ fn return_type ( & self ) -> Result < DataType > {
85
+ let nested = self . nested . return_type ( ) ?;
83
86
match NULLABLE_RESULT {
84
- true => Ok ( wrap_nullable ( & self . nested . return_type ( ) ? ) ) ,
85
- false => Ok ( self . nested . return_type ( ) ? ) ,
87
+ true => Ok ( nested. wrap_nullable ( ) ) ,
88
+ false => Ok ( nested) ,
86
89
}
87
90
}
88
91
@@ -107,24 +110,22 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
107
110
validity : Option < & Bitmap > ,
108
111
input_rows : usize ,
109
112
) -> Result < ( ) > {
110
- let mut validity = validity. cloned ( ) ;
111
113
let col = & columns[ 0 ] ;
112
- let ( all_null, v) = col. validity ( ) ;
113
- validity = combine_validities ( validity. as_ref ( ) , v) ;
114
- let not_null_columns = Series :: remove_nullable ( col) ;
115
-
116
- self . nested
117
- . accumulate ( place, & [ not_null_columns] , validity. as_ref ( ) , input_rows) ?;
118
-
119
- if !all_null {
120
- match validity {
121
- Some ( v) => {
122
- if v. unset_bits ( ) != input_rows {
123
- self . set_flag ( place, 1 ) ;
124
- }
125
- }
126
- None => self . set_flag ( place, 1 ) ,
114
+ let validity = column_merge_validity ( col, validity. cloned ( ) ) ;
115
+ let not_null_column = col. remove_nullable ( ) ;
116
+
117
+ self . nested . accumulate (
118
+ place,
119
+ & [ not_null_column. clone ( ) ] ,
120
+ validity. as_ref ( ) ,
121
+ input_rows,
122
+ ) ?;
123
+
124
+ match validity {
125
+ Some ( v) if v. unset_bits ( ) != input_rows => {
126
+ self . set_flag ( place, 1 ) ;
127
127
}
128
+ _ => self . set_flag ( place, 1 ) ,
128
129
}
129
130
Ok ( ( ) )
130
131
}
@@ -138,33 +139,29 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
138
139
input_rows : usize ,
139
140
) -> Result < ( ) > {
140
141
let col = & columns[ 0 ] ;
141
- let ( all_null, validity) = col. validity ( ) ;
142
- let not_null_column = Series :: remove_nullable ( col) ;
143
- let not_null_columns = & [ not_null_column] ;
144
-
145
- if !all_null {
146
- match validity {
147
- Some ( v) if v. unset_bits ( ) > 0 => {
148
- for ( valid, ( row, place) ) in v. iter ( ) . zip ( places. iter ( ) . enumerate ( ) ) {
149
- if valid {
150
- self . set_flag ( place. next ( offset) , 1 ) ;
151
- self . nested . accumulate_row (
152
- place. next ( offset) ,
153
- not_null_columns,
154
- row,
155
- ) ?;
156
- }
142
+ let validity = column_merge_validity ( col, None ) ;
143
+ let not_null_columns = vec ! [ col. remove_nullable( ) ] ;
144
+ let not_null_columns = & not_null_columns;
145
+
146
+ match validity {
147
+ Some ( v) if v. unset_bits ( ) > 0 => {
148
+ for ( valid, ( row, place) ) in v. iter ( ) . zip ( places. iter ( ) . enumerate ( ) ) {
149
+ if valid {
150
+ self . set_flag ( place. next ( offset) , 1 ) ;
151
+ self . nested
152
+ . accumulate_row ( place. next ( offset) , not_null_columns, row) ?;
157
153
}
158
154
}
159
- _ => {
160
- self . nested
161
- . accumulate_keys ( places , offset , not_null_columns , input_rows ) ? ;
162
- places
163
- . iter ( )
164
- . for_each ( |place| self . set_flag ( place . next ( offset ) , 1 ) ) ;
165
- }
155
+ }
156
+ _ => {
157
+ self . nested
158
+ . accumulate_keys ( places, offset , not_null_columns , input_rows ) ? ;
159
+ places
160
+ . iter ( )
161
+ . for_each ( |place| self . set_flag ( place . next ( offset ) , 1 ) ) ;
166
162
}
167
163
}
164
+
168
165
Ok ( ( ) )
169
166
}
170
167
@@ -208,21 +205,22 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
208
205
self . nested . merge ( place, rhs)
209
206
}
210
207
211
- fn merge_result ( & self , place : StateAddr , column : & mut ColumnBuilder ) -> Result < ( ) > {
208
+ fn merge_result ( & self , place : StateAddr , builder : & mut ColumnBuilder ) -> Result < ( ) > {
212
209
if NULLABLE_RESULT {
213
- let builder: & mut MutableNullableColumn = Series :: check_get_mutable_column ( column) ?;
214
210
if self . get_flag ( place) == 1 {
215
- let inner = builder. inner_mut ( ) ;
216
- self . nested . merge_result ( place, inner. as_mut ( ) ) ?;
217
- let validity = builder. validity_mut ( ) ;
218
-
219
- validity. push ( true ) ;
211
+ match builder {
212
+ ColumnBuilder :: Nullable ( ref mut inner) => {
213
+ self . nested . merge_result ( place, & mut inner. builder ) ?;
214
+ inner. validity . push ( true ) ;
215
+ }
216
+ _ => unreachable ! ( ) ,
217
+ }
220
218
} else {
221
- builder. append_default ( ) ;
219
+ builder. push_default ( ) ;
222
220
}
223
221
Ok ( ( ) )
224
222
} else {
225
- self . nested . merge_result ( place, column )
223
+ self . nested . merge_result ( place, builder )
226
224
}
227
225
}
228
226
0 commit comments