Commit 196e626

fix: fix spill flaky tests (#13294)
1 parent c8ab979 commit 196e626

File tree

5 files changed: +46 additions, -46 deletions

src/query/service/src/pipelines/processors/transforms/hash_join/build_spill/build_spill_coordinator.rs

Lines changed: 10 additions & 13 deletions
@@ -26,6 +26,7 @@ use common_exception::Result;
 use common_expression::DataBlock;
 use log::info;
 use parking_lot::Mutex;
+use parking_lot::RwLock;

 /// Coordinate all hash join build processors to spill.
 /// It's shared by all hash join build processors.
@@ -41,7 +42,7 @@ pub struct BuildSpillCoordinator {
     /// Spill tasks, the size is the same as the total active processor count.
     pub(crate) spill_tasks: Mutex<VecDeque<Vec<(u8, DataBlock)>>>,
     /// When a build processor won't trigger spill, the field will plus one
-    pub(crate) non_spill_processors: AtomicUsize,
+    pub(crate) non_spill_processors: RwLock<usize>,
     /// If there is the last active processor, send true to watcher channel
     pub(crate) ready_spill_watcher: Sender<bool>,
     pub(crate) dummy_ready_spill_receiver: Receiver<bool>,
@@ -74,15 +75,18 @@ impl BuildSpillCoordinator {
                 .send(false)
                 .map_err(|_| ErrorCode::TokioError("ready_spill_watcher channel is closed"))?;
         }
-        let old_val = self.waiting_spill_count.fetch_add(1, Ordering::Relaxed);
+        let non_spill_processors = self.non_spill_processors.read();
+        let old_val = self.waiting_spill_count.fetch_add(1, Ordering::Release);
         let waiting_spill_count = old_val + 1;
-        let non_spill_processors = self.non_spill_processors.load(Ordering::Relaxed);
         info!(
             "waiting_spill_count: {:?}, non_spill_processors: {:?}, total_builder_count: {:?}",
-            waiting_spill_count, non_spill_processors, self.total_builder_count
+            waiting_spill_count, *non_spill_processors, self.total_builder_count
         );

-        if waiting_spill_count + non_spill_processors == self.total_builder_count {
+        if (waiting_spill_count + *non_spill_processors == self.total_builder_count)
+            && self.get_need_spill()
+        {
+            self.no_need_spill();
             // Reset waiting_spill_count
             self.waiting_spill_count.store(0, Ordering::Relaxed);
             // No need to wait spill, the processor is the last one
@@ -116,13 +120,6 @@ impl BuildSpillCoordinator {

     // Get active processor count
     pub fn active_processor_num(&self) -> usize {
-        self.total_builder_count - self.non_spill_processors.load(Ordering::Relaxed)
-    }
-
-    // Add one to `non_spill_processors`
-    // Return value after adding
-    pub fn increase_non_spill_processors(&self) -> usize {
-        let old = self.non_spill_processors.fetch_add(1, Ordering::Relaxed);
-        old + 1
+        self.total_builder_count - *self.non_spill_processors.read()
     }
 }
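
The key change above: `non_spill_processors` moves from a free-running AtomicUsize to an RwLock<usize>, and the last-active-builder branch is now also gated on the coordinator's need-spill flag so it can fire at most once per spill round. A minimal, self-contained sketch of that shape (an illustration, not the Databend code; std::sync::RwLock stands in for parking_lot::RwLock, and an AtomicBool swap stands in for get_need_spill()/no_need_spill()):

use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::RwLock;

// Hypothetical stand-in for BuildSpillCoordinator.
struct Coordinator {
    total_builder_count: usize,
    waiting_spill_count: AtomicUsize,
    non_spill_processors: RwLock<usize>,
    need_spill: AtomicBool, // stands in for get_need_spill()/no_need_spill()
}

impl Coordinator {
    // Returns true only for the last active builder, which then triggers the
    // spill round for everyone else.
    fn wait_spill(&self) -> bool {
        // Hold the read guard across the check so a concurrent writer that is
        // bumping non_spill_processors cannot interleave with the comparison.
        let non_spill = self.non_spill_processors.read().unwrap();
        let waiting = self.waiting_spill_count.fetch_add(1, Ordering::Release) + 1;
        if waiting + *non_spill == self.total_builder_count
            // swap(false) both checks and clears the flag, so this branch can
            // fire at most once per spill round.
            && self.need_spill.swap(false, Ordering::AcqRel)
        {
            self.waiting_spill_count.store(0, Ordering::Relaxed);
            return true;
        }
        false
    }
}

fn main() {
    let c = Coordinator {
        total_builder_count: 2,
        waiting_spill_count: AtomicUsize::new(0),
        non_spill_processors: RwLock::new(1),
        need_spill: AtomicBool::new(true),
    };
    // 1 waiting builder + 1 non-spill builder == 2 total, so this call is last.
    assert!(c.wait_spill());
}

Holding the read guard across the comparison prevents a concurrent increment of non_spill_processors from slipping in between the counter update and the equality check, which appears to be one of the races behind the flaky tests.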

src/query/service/src/pipelines/processors/transforms/hash_join/hash_join_build_state.rs

Lines changed: 2 additions & 2 deletions
@@ -258,7 +258,7 @@ impl HashJoinBuildState {
                 .store(true, Ordering::Relaxed);
             self.hash_join_state
                 .build_done_watcher
-                .send(self.send_val.load(Ordering::Relaxed))
+                .send(self.send_val.load(Ordering::Acquire))
                 .map_err(|_| ErrorCode::TokioError("build_done_watcher channel is closed"))?;
             return Ok(());
         }
@@ -615,7 +615,7 @@ impl HashJoinBuildState {
         }
         self.hash_join_state
             .build_done_watcher
-            .send(self.send_val.load(Ordering::Relaxed))
+            .send(self.send_val.load(Ordering::Acquire))
             .map_err(|_| ErrorCode::TokioError("build_done_watcher channel is closed"))?;
         }
         Ok(())
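
Both send sites now load `send_val` with Ordering::Acquire, pairing with the Ordering::Release store of 2 in transform_hash_join_build.rs, so a task that observes the new value also observes the state reset that preceded the store. A small stand-alone sketch of that Release/Acquire pairing (illustration only, with a hypothetical `round` variable standing in for the other per-round state):

use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    // send_val mirrors the field of the same name; `round` is a hypothetical
    // stand-in for the per-round state that gets reset before publishing.
    let send_val = Arc::new(AtomicU8::new(1));
    let round = Arc::new(AtomicUsize::new(1));

    let publisher = {
        let (send_val, round) = (Arc::clone(&send_val), Arc::clone(&round));
        thread::spawn(move || {
            round.store(2, Ordering::Relaxed); // reset state for the next round
            send_val.store(2, Ordering::Release); // publish: pairs with Acquire below
        })
    };

    // Forwarder side: spin until the Acquire load observes the published value,
    // as the real code does before build_done_watcher.send(...).
    while send_val.load(Ordering::Acquire) != 2 {
        std::hint::spin_loop();
    }
    // Because the Acquire load paired with the Release store, the earlier
    // reset is guaranteed to be visible here as well.
    assert_eq!(round.load(Ordering::Relaxed), 2);
    publisher.join().unwrap();
}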

src/query/service/src/pipelines/processors/transforms/hash_join/transform_hash_join_build.rs

Lines changed: 27 additions & 22 deletions
@@ -92,8 +92,6 @@ impl TransformHashJoinBuild {
             if wait {
                 self.step = HashJoinBuildStep::WaitSpill;
             } else {
-                // Make `need_spill` to false for `SpillCoordinator`
-                spill_state.spill_coordinator.no_need_spill();
                 // Before notify all processors to spill, we need to collect all buffered data in `RowSpace` and `Chunks`
                 // Partition all rows and stat how many partitions and rows in each partition.
                 // Then choose the largest partitions(which contain rows that can avoid oom exactly) to spill.
@@ -116,12 +114,17 @@ impl TransformHashJoinBuild {

     // Called after processor read spilled data
     // It means next round build will start, need to reset some variables.
-    fn reset(&mut self) -> Result<()> {
+    async fn reset(&mut self) -> Result<()> {
         self.finalize_finished = false;
         self.from_spill = true;
         // Only need to reset the following variables once
-        if self.build_state.row_space_builders.load(Ordering::Relaxed) == 0 {
-            self.build_state.send_val.store(2, Ordering::Relaxed);
+        if self
+            .build_state
+            .row_space_builders
+            .fetch_add(1, Ordering::Acquire)
+            == 0
+        {
+            self.build_state.send_val.store(2, Ordering::Release);
             // Before build processors into `WaitProbe` state, set the channel message to false.
             // Then after all probe processors are ready, the last one will send true to channel and wake up all build processors.
             self.build_state
@@ -130,16 +133,14 @@ impl TransformHashJoinBuild {
                 .send(false)
                 .map_err(|_| ErrorCode::TokioError("continue_build_watcher channel is closed"))?;
             let worker_num = self.build_state.build_worker_num.load(Ordering::Relaxed) as usize;
-            self.build_state
-                .row_space_builders
-                .store(worker_num, Ordering::Relaxed);
             self.build_state
                 .hash_join_state
                 .hash_table_builders
                 .store(worker_num, Ordering::Relaxed);
             self.build_state.hash_join_state.reset();
         }
         self.step = HashJoinBuildStep::Running;
+        self.build_state.restore_barrier.wait().await;
         Ok(())
     }
 }
@@ -166,11 +167,13 @@ impl Processor for TransformHashJoinBuild {
                 // The processor won't be triggered spill, because there won't be data from input port
                 // Add the processor to `non_spill_processors`
                 let spill_coordinator = &spill_state.spill_coordinator;
-                let waiting_spill_count = spill_coordinator.waiting_spill_count.load(Ordering::Relaxed);
-                let non_spill_processors = spill_coordinator.increase_non_spill_processors();
-                info!("waiting_spill_count: {:?}, non_spill_processors: {:?}, total_builder_count: {:?}", waiting_spill_count, non_spill_processors, spill_state.spill_coordinator.total_builder_count);
-                if waiting_spill_count != 0 && non_spill_processors + waiting_spill_count == spill_state.spill_coordinator.total_builder_count {
+                let mut non_spill_processors = spill_coordinator.non_spill_processors.write();
+                *non_spill_processors += 1;
+                let waiting_spill_count = spill_coordinator.waiting_spill_count.load(Ordering::Acquire);
+                info!("waiting_spill_count: {:?}, non_spill_processors: {:?}, total_builder_count: {:?}", waiting_spill_count, *non_spill_processors, spill_state.spill_coordinator.total_builder_count);
+                if (waiting_spill_count != 0 && *non_spill_processors + waiting_spill_count == spill_state.spill_coordinator.total_builder_count) && spill_coordinator.get_need_spill() {
                     spill_coordinator.no_need_spill();
+                    drop(non_spill_processors);
                     let mut spill_task = spill_coordinator.spill_tasks.lock();
                     spill_state.split_spill_tasks(spill_coordinator.active_processor_num(), &mut spill_task)?;
                     spill_coordinator.waiting_spill_count.store(0, Ordering::Relaxed);
@@ -201,15 +204,7 @@ impl Processor for TransformHashJoinBuild {
             true => {
                 // If join spill is enabled, we should wait probe to spill.
                 // Then restore data from disk and build hash table, util all spilled data are processed.
-                if let Some(spill_state) = &mut self.spill_state {
-                    // Send spilled partition to `HashJoinState`, used by probe spill.
-                    // The method should be called only once.
-                    if !self.send_partition_set {
-                        self.build_state
-                            .hash_join_state
-                            .set_spilled_partition(&spill_state.spiller.spilled_partition_set);
-                        self.send_partition_set = true;
-                    }
+                if self.spill_state.is_some() {
                     self.step = HashJoinBuildStep::WaitProbe;
                     Ok(Event::Async)
                 } else {
@@ -273,6 +268,16 @@ impl Processor for TransformHashJoinBuild {
                     self.build_state.finalize(task)
                 } else {
                     self.finalize_finished = true;
+                    if let Some(spill_state) = &mut self.spill_state {
+                        // Send spilled partition to `HashJoinState`, used by probe spill.
+                        // The method should be called only once.
+                        if !self.send_partition_set {
+                            self.build_state
+                                .hash_join_state
+                                .set_spilled_partition(&spill_state.spiller.spilled_partition_set);
+                            self.send_partition_set = true;
+                        }
+                    }
                     self.build_state.build_done()
                 }
             }
@@ -364,7 +369,7 @@ impl Processor for TransformHashJoinBuild {
                     self.input_data = Some(DataBlock::concat(&spilled_data)?);
                 }
                 self.build_state.restore_barrier.wait().await;
-                self.reset()?;
+                self.reset().await?;
             }
             _ => unreachable!(),
         }
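
Two things change in reset() above: each worker now registers itself with fetch_add on `row_space_builders`, so the one that observes the old value 0 (the first to return from spill) performs the one-time reset and the explicit store of worker_num is no longer needed, and every worker then waits on `restore_barrier` before the next round. The sketch below illustrates just the first-worker-wins part with plain std atomics and threads (hypothetical names, not the real processor):

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    // row_space_builders mirrors the shared counter; `resets` only counts how
    // many workers ran the one-time reset, to show it happens exactly once.
    let row_space_builders = Arc::new(AtomicUsize::new(0));
    let resets = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..4)
        .map(|_| {
            let counter = Arc::clone(&row_space_builders);
            let resets = Arc::clone(&resets);
            thread::spawn(move || {
                // Register this worker for the next round; fetch_add returns
                // the value *before* the increment.
                if counter.fetch_add(1, Ordering::Acquire) == 0 {
                    // Exactly one worker (the first to arrive) lands here and
                    // performs the shared reset.
                    resets.fetch_add(1, Ordering::Relaxed);
                }
            })
        })
        .collect();

    for h in handles {
        h.join().unwrap();
    }
    assert_eq!(row_space_builders.load(Ordering::Relaxed), 4); // all registered
    assert_eq!(resets.load(Ordering::Relaxed), 1); // reset ran once
}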

src/query/service/src/pipelines/processors/transforms/hash_join/transform_hash_join_probe.rs

Lines changed: 4 additions & 6 deletions
@@ -223,7 +223,7 @@ impl TransformHashJoinProbe {
         Ok(Event::NeedData)
     }

-    fn reset(&mut self) -> Result<()> {
+    async fn reset(&mut self) -> Result<()> {
         self.step = HashJoinProbeStep::Running;
         // self.probe_state.reset();
         if (self.join_probe_state.hash_join_state.need_outer_scan()
@@ -238,7 +238,7 @@ impl TransformHashJoinProbe {
         if self
             .join_probe_state
             .final_probe_workers
-            .load(Ordering::Relaxed)
+            .fetch_add(1, Ordering::Acquire)
             == 0
         {
             // Before probe processor into `WaitBuild` state, send `1` to channel
@@ -248,11 +248,9 @@ impl TransformHashJoinProbe {
                 .build_done_watcher
                 .send(1)
                 .map_err(|_| ErrorCode::TokioError("build_done_watcher channel is closed"))?;
-            self.join_probe_state
-                .final_probe_workers
-                .store(self.join_probe_state.processor_count, Ordering::Relaxed);
         }
         self.outer_scan_finished = false;
+        self.join_probe_state.restore_barrier.wait().await;
         Ok(())
     }
 }
@@ -516,7 +514,7 @@ impl Processor for TransformHashJoinProbe {
             }
         }
         self.join_probe_state.restore_barrier.wait().await;
-        self.reset()?;
+        self.reset().await?;
     }
     HashJoinProbeStep::FinalScan | HashJoinProbeStep::FastReturn => unreachable!(),
 };
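
The probe side mirrors the build side: reset() becomes async so it can end with `restore_barrier.wait().await`, and `final_probe_workers` is re-armed through the same fetch_add registration instead of a separate store. A rough sketch of the barrier rendezvous, assuming tokio's sync::Barrier as a stand-in for the project's restore_barrier (illustration only):

use std::sync::Arc;
use tokio::sync::Barrier;

// Hypothetical reset: per-processor cleanup, then a rendezvous so no
// processor races ahead into the next round before the others have reset.
async fn reset(id: usize, restore_barrier: Arc<Barrier>) {
    // ... per-processor reset work would go here ...
    println!("processor {id}: reset done, waiting at barrier");
    restore_barrier.wait().await;
    println!("processor {id}: all processors reset, next round starts");
}

#[tokio::main]
async fn main() {
    let processor_count = 4;
    let restore_barrier = Arc::new(Barrier::new(processor_count));
    let tasks: Vec<_> = (0..processor_count)
        .map(|id| tokio::spawn(reset(id, Arc::clone(&restore_barrier))))
        .collect();
    for t in tasks {
        t.await.unwrap();
    }
}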

tests/sqllogictests/suites/query/spill.test

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@ statement ok
 set disable_join_reorder = 1;

 statement ok
-set join_spilling_threshold = 0;
+set join_spilling_threshold = 1;

 statement ok
 create table t3(a int);
@@ -65,7 +65,7 @@ statement ok
 create table t3 as select number as a from numbers(1000000);

 statement ok
-set join_spilling_threshold = 0;
+set join_spilling_threshold = 100;

 query I
 select count() from t3 inner join numbers(1000000) on t3.a = number;
@@ -74,7 +74,7 @@ select count() from t3 inner join numbers(1000000) on t3.a = number;

 onlyif mysql
 statement ok
-set join_spilling_threshold = 0;
+set join_spilling_threshold = 1024 * 1024 * 1;

 onlyif mysql
 query I
