
bug: remove busy-wait while sort is ongoing #16322


Merged · 3 commits · Jun 12, 2025
Changes from all commits
56 changes: 33 additions & 23 deletions datafusion/physical-plan/src/sorts/merge.rs
@@ -18,7 +18,6 @@
//! Merge that deals with an arbitrary size of streaming inputs.
//! This is an order-preserving merge.

use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
@@ -143,11 +142,8 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
/// number of rows produced
produced: usize,

/// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`,
/// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the
/// vector to ensure the next iteration starts with a different partition, preventing the same partition
/// from being continuously polled.
uninitiated_partitions: VecDeque<usize>,
/// This vector contains the indices of the partitions that have not started emitting yet.
uninitiated_partitions: Vec<usize>,
}

impl<C: CursorValues> SortPreservingMergeStream<C> {
@@ -216,36 +212,50 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
// Once all partitions have set their corresponding cursors for the loser tree,
// we skip the following block. Until then, this function may be called multiple
// times and can return Poll::Pending if any partition returns Poll::Pending.

if self.loser_tree.is_empty() {
while let Some(&partition_idx) = self.uninitiated_partitions.front() {
// Manual indexing since we're iterating over the vector and shrinking it in the loop
let mut idx = 0;
while idx < self.uninitiated_partitions.len() {
let partition_idx = self.uninitiated_partitions[idx];
match self.maybe_poll_stream(cx, partition_idx) {
Poll::Ready(Err(e)) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
// If a partition returns Poll::Pending, to avoid continuously polling it
// and potentially increasing upstream buffer sizes, we move it to the
// back of the polling queue.
self.uninitiated_partitions.rotate_left(1);

// This function could remain in a pending state, so we manually wake it here.
// However, this approach can be investigated further to find a more natural way
// to avoid disrupting the runtime scheduler.
cx.waker().wake_by_ref();
return Poll::Pending;
// The polled stream is pending which means we're already set up to
// be woken when necessary
// Try the next stream
idx += 1;
}
_ => {
// If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None),
// we remove this partition from the queue so it is not polled again.
self.uninitiated_partitions.pop_front();
// The polled stream is ready
[Reviewer comment (Contributor): thank you for these comments 👍]

// Remove it from uninitiated_partitions
// Don't bump idx here, since a new element will have taken its
// place which we'll try in the next loop iteration
// swap_remove will change the partition poll order, but that shouldn't
// make a difference since we're waiting for all streams to be ready.
self.uninitiated_partitions.swap_remove(idx);
}
}
}

// Claim the memory for the uninitiated partitions
self.uninitiated_partitions.shrink_to_fit();
self.init_loser_tree();
if self.uninitiated_partitions.is_empty() {
// If there are no more uninitiated partitions, set up the loser tree and continue
// to the next phase.

// Claim the memory for the uninitiated partitions
self.uninitiated_partitions.shrink_to_fit();
self.init_loser_tree();
} else {
// There are still uninitiated partitions so return pending.
// We only get here if we've polled all uninitiated streams and at least one of them
// returned pending itself. That means we will be woken as soon as one of the
// streams would like to be polled again.
// There is no need to reschedule ourselves eagerly.
return Poll::Pending;
}
}

// NB timer records time taken on drop, so there are no
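The loop above replaces the old `VecDeque` rotation plus `wake_by_ref` busy-wait with manual indexing over a `Vec`, removing ready partitions with `swap_remove`. As a quick illustration of that bookkeeping (a minimal standalone sketch, not DataFusion code; `drain_ready` and `is_ready` are made-up names), note that the index only advances when nothing was removed at the current position, because `swap_remove` moves the last element into the freed slot:

```rust
// Standalone sketch of the "manual index + swap_remove" drain pattern used in
// SortPreservingMergeStream above. `drain_ready` and `is_ready` are hypothetical
// names for illustration only.
fn drain_ready(pending: &mut Vec<usize>, is_ready: impl Fn(usize) -> bool) {
    let mut idx = 0;
    while idx < pending.len() {
        if is_ready(pending[idx]) {
            // swap_remove is O(1) but moves the last element into `idx`,
            // so do not bump the index: the swapped-in element still needs a look.
            pending.swap_remove(idx);
        } else {
            // Nothing removed at this position; move on to the next element.
            idx += 1;
        }
    }
}

fn main() {
    let mut pending = vec![0, 1, 2, 3, 4];
    // Pretend the even-numbered partitions are ready.
    drain_ready(&mut pending, |p| p % 2 == 0);
    // The survivors are reordered by swap_remove, which is fine here: the merge
    // only needs to know which partitions are still uninitiated.
    assert_eq!(pending, vec![3, 1]);
    println!("still pending: {pending:?}");
}
```

Since every stream must produce its first cursor before `init_loser_tree` runs, the reordering introduced by `swap_remove` has no observable effect, as the in-diff comment notes.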
75 changes: 55 additions & 20 deletions datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -378,10 +378,11 @@ impl ExecutionPlan for SortPreservingMergeExec {

#[cfg(test)]
mod tests {
use std::collections::HashSet;
use std::fmt::Formatter;
use std::pin::Pin;
use std::sync::Mutex;
use std::task::{Context, Poll};
use std::task::{ready, Context, Poll, Waker};
use std::time::Duration;

use super::*;
@@ -1285,13 +1286,50 @@ mod tests {
"#);
}

#[derive(Debug)]
struct CongestionState {
wakers: Vec<Waker>,
unpolled_partitions: HashSet<usize>,
}

#[derive(Debug)]
struct Congestion {
congestion_state: Mutex<CongestionState>,
}

impl Congestion {
fn new(partition_count: usize) -> Self {
Congestion {
congestion_state: Mutex::new(CongestionState {
wakers: vec![],
unpolled_partitions: (0usize..partition_count).collect(),
}),
}
}

fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
let mut state = self.congestion_state.lock().unwrap();

state.unpolled_partitions.remove(&partition);

if state.unpolled_partitions.is_empty() {
state.wakers.iter().for_each(|w| w.wake_by_ref());
state.wakers.clear();
Poll::Ready(())
} else {
state.wakers.push(cx.waker().clone());
Poll::Pending
}
}
}

/// Every non-exhausted partition returns pending until all partitions have been polled at
/// least once. The 1st partition is exhausted from the start, and if it is polled more than
/// once, it panics.
#[derive(Debug, Clone)]
struct CongestedExec {
schema: Schema,
cache: PlanProperties,
congestion_cleared: Arc<Mutex<bool>>,
congestion: Arc<Congestion>,
}

impl CongestedExec {
@@ -1346,7 +1384,7 @@ mod tests {
Ok(Box::pin(CongestedStream {
schema: Arc::new(self.schema.clone()),
none_polled_once: false,
congestion_cleared: Arc::clone(&self.congestion_cleared),
congestion: Arc::clone(&self.congestion),
partition,
}))
}
@@ -1373,39 +1411,30 @@
pub struct CongestedStream {
schema: SchemaRef,
none_polled_once: bool,
congestion_cleared: Arc<Mutex<bool>>,
congestion: Arc<Congestion>,
partition: usize,
}

impl Stream for CongestedStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.partition {
0 => {
let _ = self.congestion.check_congested(self.partition, cx);
if self.none_polled_once {
panic!("Exhausted stream is polled more than one")
panic!("Exhausted stream is polled more than once")
} else {
self.none_polled_once = true;
Poll::Ready(None)
}
}
1 => {
let cleared = self.congestion_cleared.lock().unwrap();
if *cleared {
Poll::Ready(None)
} else {
Poll::Pending
}
}
2 => {
let mut cleared = self.congestion_cleared.lock().unwrap();
*cleared = true;
_ => {
ready!(self.congestion.check_congested(self.partition, cx));
Poll::Ready(None)
}
_ => unreachable!(),
}
}
}
@@ -1420,10 +1449,16 @@
async fn test_spm_congestion() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
let &partition_count = match properties.output_partitioning() {
Partitioning::RoundRobinBatch(partitions) => partitions,
Partitioning::Hash(_, partitions) => partitions,
Partitioning::UnknownPartitioning(partitions) => partitions,
};
let source = CongestedExec {
schema: schema.clone(),
cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
congestion_cleared: Arc::new(Mutex::new(false)),
cache: properties,
congestion: Arc::new(Congestion::new(partition_count)),
};
let spm = SortPreservingMergeExec::new(
[PhysicalSortExpr::new_default(Arc::new(Column::new(
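The `Congestion` helper in the tests is essentially a small waker-based barrier: each partition removes itself from the unpolled set, parks its `Waker` if any partition is still missing, and the last arrival wakes everyone that parked earlier. Here is a minimal standalone sketch of that idea (assuming the `tokio` crate with the `rt-multi-thread` and `macros` features; `Gate` and `poll_arrive` are hypothetical names, not DataFusion or test APIs):

```rust
use std::collections::HashSet;
use std::future::poll_fn;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};

/// A tiny waker-based barrier in the spirit of the `Congestion` test helper.
struct Gate {
    state: Mutex<GateState>,
}

struct GateState {
    wakers: Vec<Waker>,
    pending: HashSet<usize>,
}

impl Gate {
    fn new(participants: usize) -> Self {
        Gate {
            state: Mutex::new(GateState {
                wakers: vec![],
                pending: (0..participants).collect(),
            }),
        }
    }

    /// Mark `id` as arrived; ready only once every participant has arrived.
    fn poll_arrive(&self, id: usize, cx: &mut Context<'_>) -> Poll<()> {
        let mut state = self.state.lock().unwrap();
        state.pending.remove(&id);
        if state.pending.is_empty() {
            // Last arrival: wake every task that parked earlier.
            state.wakers.drain(..).for_each(|w| w.wake());
            Poll::Ready(())
        } else {
            // Park: register this task's waker and wait for the last arrival.
            state.wakers.push(cx.waker().clone());
            Poll::Pending
        }
    }
}

#[tokio::main]
async fn main() {
    let gate = Arc::new(Gate::new(2));
    let g0 = Arc::clone(&gate);
    let g1 = Arc::clone(&gate);
    // Neither task completes until both have been polled at least once, and
    // neither ever needs to reschedule itself with `wake_by_ref`.
    let a = tokio::spawn(async move { poll_fn(|cx| g0.poll_arrive(0, cx)).await });
    let b = tokio::spawn(async move { poll_fn(|cx| g1.poll_arrive(1, cx)).await });
    let _ = tokio::join!(a, b);
    println!("both tasks passed the gate");
}
```

This mirrors why the merge loop no longer needs the manual `cx.waker().wake_by_ref()` call: a stream that returned `Poll::Pending` has already stored the task's waker, so the merge is woken exactly when one of its inputs can make progress.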