Skip to content

Commit 7171319

Browse files
Freejwwjw
andauthored
fix: fix incorrect agg spill in new agg hashtable (#14995)
* fix agg spill * fix agg spill --------- Co-authored-by: jw <freejw@gmail.com>
1 parent 98652d1 commit 7171319

File tree

8 files changed

+190
-39
lines changed

8 files changed

+190
-39
lines changed

src/common/metrics/src/metrics/transform.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,19 @@ pub fn metrics_inc_aggregate_partial_hashtable_allocated_bytes(c: u64) {
7474
AGGREGATE_PARTIAL_HASHTABLE_ALLOCATED_BYTES.inc_by(c);
7575
}
7676

77+
pub fn metrics_inc_group_by_partial_spill_count() {
78+
let labels = &vec![("spill", "group_by_partial_spill".to_string())];
79+
SPILL_COUNT.get_or_create(labels).inc();
80+
}
81+
82+
pub fn metrics_inc_group_by_partial_spill_cell_count(c: u64) {
83+
AGGREGATE_PARTIAL_SPILL_CELL_COUNT.inc_by(c);
84+
}
85+
86+
pub fn metrics_inc_group_by_partial_hashtable_allocated_bytes(c: u64) {
87+
AGGREGATE_PARTIAL_HASHTABLE_ALLOCATED_BYTES.inc_by(c);
88+
}
89+
7790
pub fn metrics_inc_group_by_spill_write_count() {
7891
let labels = &vec![("spill", "group_by_spill".to_string())];
7992
SPILL_WRITE_COUNT.get_or_create(labels).inc();

src/query/expression/src/aggregate/mod.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod payload_row;
2626
mod probe_state;
2727

2828
use std::sync::atomic::AtomicU64;
29+
use std::sync::atomic::Ordering;
2930
use std::sync::Arc;
3031

3132
pub use aggregate_function::*;
@@ -108,4 +109,24 @@ impl HashTableConfig {
108109

109110
self
110111
}
112+
113+
pub fn update_current_max_radix_bits(&self) {
114+
loop {
115+
let current_max_radix_bits = self.current_max_radix_bits.load(Ordering::SeqCst);
116+
if current_max_radix_bits < self.max_radix_bits
117+
&& self
118+
.current_max_radix_bits
119+
.compare_exchange(
120+
current_max_radix_bits,
121+
self.max_radix_bits,
122+
Ordering::SeqCst,
123+
Ordering::SeqCst,
124+
)
125+
.is_err()
126+
{
127+
continue;
128+
}
129+
break;
130+
}
131+
}
111132
}

src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,20 +57,16 @@ impl SerializedPayload {
5757
&self,
5858
group_types: Vec<DataType>,
5959
aggrs: Vec<Arc<dyn AggregateFunction>>,
60+
radix_bits: u64,
61+
arena: Arc<Bump>,
6062
) -> Result<PartitionedPayload> {
6163
let rows_num = self.data_block.num_rows();
62-
let radix_bits = self.max_partition_count.trailing_zeros() as u64;
6364
let config = HashTableConfig::default().with_initial_radix_bits(radix_bits);
6465
let mut state = ProbeState::default();
6566
let agg_len = aggrs.len();
6667
let group_len = group_types.len();
67-
let mut hashtable = AggregateHashTable::new_directly(
68-
group_types,
69-
aggrs,
70-
config,
71-
rows_num,
72-
Arc::new(Bump::new()),
73-
);
68+
let mut hashtable =
69+
AggregateHashTable::new_directly(group_types, aggrs, config, rows_num, arena);
7470

7571
let agg_states = (0..agg_len)
7672
.map(|i| {
@@ -96,6 +92,8 @@ impl SerializedPayload {
9692
let _ =
9793
hashtable.add_groups(&mut state, &group_columns, &[vec![]], &agg_states, rows_num)?;
9894

95+
hashtable.payload.mark_min_cardinality();
96+
9997
Ok(hashtable.payload)
10098
}
10199
}

src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_final.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl<Method: HashMethodBounds> TransformFinalAggregate<Method> {
6868

6969
fn transform_agg_hashtable(&mut self, meta: AggregateMeta<Method, usize>) -> Result<DataBlock> {
7070
let mut agg_hashtable: Option<AggregateHashTable> = None;
71-
if let AggregateMeta::Partitioned { bucket: _, data } = meta {
71+
if let AggregateMeta::Partitioned { bucket, data } = meta {
7272
for bucket_data in data {
7373
match bucket_data {
7474
AggregateMeta::AggregateHashTable(payload) => match agg_hashtable.as_mut() {
@@ -92,16 +92,24 @@ impl<Method: HashMethodBounds> TransformFinalAggregate<Method> {
9292
},
9393
AggregateMeta::Serialized(payload) => match agg_hashtable.as_mut() {
9494
Some(ht) => {
95+
debug_assert!(bucket == payload.bucket);
96+
let arena = Arc::new(Bump::new());
9597
let payload = payload.convert_to_partitioned_payload(
9698
self.params.group_data_types.clone(),
9799
self.params.aggregate_functions.clone(),
100+
0,
101+
arena,
98102
)?;
99103
ht.combine_payloads(&payload, &mut self.flush_state)?;
100104
}
101105
None => {
106+
debug_assert!(bucket == payload.bucket);
107+
let arena = Arc::new(Bump::new());
102108
let payload = payload.convert_to_partitioned_payload(
103109
self.params.group_data_types.clone(),
104110
self.params.aggregate_functions.clone(),
111+
0,
112+
arena,
105113
)?;
106114
let capacity =
107115
AggregateHashTable::get_capacity_for_count(payload.len());

src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use databend_common_expression::BlockMetaInfoDowncast;
2828
use databend_common_expression::Column;
2929
use databend_common_expression::DataBlock;
3030
use databend_common_expression::HashTableConfig;
31+
use databend_common_expression::PayloadFlushState;
3132
use databend_common_expression::ProbeState;
3233
use databend_common_functions::aggregates::StateAddr;
3334
use databend_common_functions::aggregates::StateAddrs;
@@ -50,7 +51,6 @@ use crate::pipelines::processors::transforms::group_by::HashMethodBounds;
5051
use crate::pipelines::processors::transforms::group_by::PartitionedHashMethod;
5152
use crate::pipelines::processors::transforms::group_by::PolymorphicKeysHelper;
5253
use crate::sessions::QueryContext;
53-
5454
#[allow(clippy::enum_variant_names)]
5555
enum HashTable<Method: HashMethodBounds> {
5656
MovedOut,
@@ -401,9 +401,21 @@ impl<Method: HashMethodBounds> AccumulatingTransform for TransformPartialAggrega
401401

402402
let group_types = v.payload.group_types.clone();
403403
let aggrs = v.payload.aggrs.clone();
404-
let config = v.config.clone();
404+
v.config.update_current_max_radix_bits();
405+
let config = v
406+
.config
407+
.clone()
408+
.with_initial_radix_bits(v.config.max_radix_bits);
409+
410+
let mut state = PayloadFlushState::default();
411+
412+
// repartition to max for normalization
413+
let partitioned_payload = v
414+
.payload
415+
.repartition(1 << config.max_radix_bits, &mut state);
416+
405417
let blocks = vec![DataBlock::empty_with_meta(
406-
AggregateMeta::<Method, usize>::create_agg_spilling(v.payload),
418+
AggregateMeta::<Method, usize>::create_agg_spilling(partitioned_payload),
407419
)];
408420

409421
let arena = Arc::new(Bump::new());

src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_final.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ impl<Method: HashMethodBounds> TransformFinalGroupBy<Method> {
6262

6363
fn transform_agg_hashtable(&mut self, meta: AggregateMeta<Method, ()>) -> Result<DataBlock> {
6464
let mut agg_hashtable: Option<AggregateHashTable> = None;
65-
if let AggregateMeta::Partitioned { bucket: _, data } = meta {
65+
if let AggregateMeta::Partitioned { bucket, data } = meta {
6666
for bucket_data in data {
6767
match bucket_data {
6868
AggregateMeta::AggregateHashTable(payload) => match agg_hashtable.as_mut() {
@@ -85,16 +85,24 @@ impl<Method: HashMethodBounds> TransformFinalGroupBy<Method> {
8585
},
8686
AggregateMeta::Serialized(payload) => match agg_hashtable.as_mut() {
8787
Some(ht) => {
88+
debug_assert!(bucket == payload.bucket);
89+
let arena = Arc::new(Bump::new());
8890
let payload = payload.convert_to_partitioned_payload(
8991
self.params.group_data_types.clone(),
9092
self.params.aggregate_functions.clone(),
93+
0,
94+
arena,
9195
)?;
9296
ht.combine_payloads(&payload, &mut self.flush_state)?;
9397
}
9498
None => {
99+
debug_assert!(bucket == payload.bucket);
100+
let arena = Arc::new(Bump::new());
95101
let payload = payload.convert_to_partitioned_payload(
96102
self.params.group_data_types.clone(),
97103
self.params.aggregate_functions.clone(),
104+
0,
105+
arena,
98106
)?;
99107
let capacity =
100108
AggregateHashTable::get_capacity_for_count(payload.len());

src/query/service/src/pipelines/processors/transforms/aggregator/transform_group_by_partial.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@ use databend_common_expression::AggregateHashTable;
2626
use databend_common_expression::Column;
2727
use databend_common_expression::DataBlock;
2828
use databend_common_expression::HashTableConfig;
29+
use databend_common_expression::PayloadFlushState;
2930
use databend_common_expression::ProbeState;
3031
use databend_common_hashtable::HashtableLike;
32+
use databend_common_metrics::transform::*;
3133
use databend_common_pipeline_core::processors::InputPort;
3234
use databend_common_pipeline_core::processors::OutputPort;
3335
use databend_common_pipeline_core::processors::Processor;
@@ -212,6 +214,15 @@ impl<Method: HashMethodBounds> AccumulatingTransform for TransformPartialGroupBy
212214
{
213215
if let HashTable::PartitionedHashTable(v) = std::mem::take(&mut self.hash_table)
214216
{
217+
// perf
218+
{
219+
metrics_inc_group_by_partial_spill_count();
220+
metrics_inc_group_by_partial_spill_cell_count(1);
221+
metrics_inc_group_by_partial_hashtable_allocated_bytes(
222+
v.allocated_bytes() as u64,
223+
);
224+
}
225+
215226
let _dropper = v._dropper.clone();
216227
let blocks = vec![DataBlock::empty_with_meta(
217228
AggregateMeta::<Method, ()>::create_spilling(v),
@@ -234,11 +245,30 @@ impl<Method: HashMethodBounds> AccumulatingTransform for TransformPartialGroupBy
234245
|| GLOBAL_MEM_STAT.get_memory_usage() as usize >= self.settings.max_memory_usage)
235246
{
236247
if let HashTable::AggregateHashTable(v) = std::mem::take(&mut self.hash_table) {
248+
// perf
249+
{
250+
metrics_inc_group_by_partial_spill_count();
251+
metrics_inc_group_by_partial_spill_cell_count(1);
252+
metrics_inc_group_by_partial_hashtable_allocated_bytes(
253+
v.allocated_bytes() as u64,
254+
);
255+
}
256+
237257
let group_types = v.payload.group_types.clone();
238258
let aggrs = v.payload.aggrs.clone();
239-
let config = v.config.clone();
259+
v.config.update_current_max_radix_bits();
260+
let config = v
261+
.config
262+
.clone()
263+
.with_initial_radix_bits(v.config.max_radix_bits);
264+
let mut state = PayloadFlushState::default();
265+
266+
// repartition to max for normalization
267+
let partitioned_payload = v
268+
.payload
269+
.repartition(1 << config.max_radix_bits, &mut state);
240270
let blocks = vec![DataBlock::empty_with_meta(
241-
AggregateMeta::<Method, ()>::create_agg_spilling(v.payload),
271+
AggregateMeta::<Method, ()>::create_agg_spilling(partitioned_payload),
242272
)];
243273

244274
let arena = Arc::new(Bump::new());

0 commit comments

Comments
 (0)