diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs
index 2b42457635f7..0c18a3b6c703 100644
--- a/datafusion/physical-plan/src/sorts/merge.rs
+++ b/datafusion/physical-plan/src/sorts/merge.rs
@@ -18,7 +18,6 @@
 //! Merge that deals with an arbitrary size of streaming inputs.
 //! This is an order-preserving merge.
 
-use std::collections::VecDeque;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{ready, Context, Poll};
@@ -143,11 +142,8 @@ pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
     /// number of rows produced
     produced: usize,
 
-    /// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`,
-    /// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the
-    /// vector to ensure the next iteration starts with a different partition, preventing the same partition
-    /// from being continuously polled.
-    uninitiated_partitions: VecDeque<usize>,
+    /// This vector contains the indices of the partitions that have not started emitting yet.
+    uninitiated_partitions: Vec<usize>,
 }
 
 impl<C: CursorValues> SortPreservingMergeStream<C> {
@@ -216,36 +212,50 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
         // Once all partitions have set their corresponding cursors for the loser tree,
         // we skip the following block. Until then, this function may be called multiple
         // times and can return Poll::Pending if any partition returns Poll::Pending.
+
         if self.loser_tree.is_empty() {
-            while let Some(&partition_idx) = self.uninitiated_partitions.front() {
+            // Manual indexing since we're iterating over the vector and shrinking it in the loop
+            let mut idx = 0;
+            while idx < self.uninitiated_partitions.len() {
+                let partition_idx = self.uninitiated_partitions[idx];
                 match self.maybe_poll_stream(cx, partition_idx) {
                     Poll::Ready(Err(e)) => {
                         self.aborted = true;
                         return Poll::Ready(Some(Err(e)));
                     }
                     Poll::Pending => {
-                        // If a partition returns Poll::Pending, to avoid continuously polling it
-                        // and potentially increasing upstream buffer sizes, we move it to the
-                        // back of the polling queue.
-                        self.uninitiated_partitions.rotate_left(1);
-
-                        // This function could remain in a pending state, so we manually wake it here.
-                        // However, this approach can be investigated further to find a more natural way
-                        // to avoid disrupting the runtime scheduler.
-                        cx.waker().wake_by_ref();
-                        return Poll::Pending;
+                        // The polled stream is pending, which means we're already set up
+                        // to be woken when necessary.
+                        // Try the next stream.
+                        idx += 1;
                     }
                     _ => {
-                        // If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None),
-                        // we remove this partition from the queue so it is not polled again.
-                        self.uninitiated_partitions.pop_front();
+                        // The polled stream is ready.
+                        // Remove it from uninitiated_partitions.
+                        // Don't bump idx here, since a new element will have taken its
+                        // place, which we'll try in the next loop iteration.
+                        // swap_remove will change the partition poll order, but that shouldn't
+                        // make a difference since we're waiting for all streams to be ready.
+                        self.uninitiated_partitions.swap_remove(idx);
                    }
                 }
             }
 
-            // Claim the memory for the uninitiated partitions
-            self.uninitiated_partitions.shrink_to_fit();
-            self.init_loser_tree();
+            if self.uninitiated_partitions.is_empty() {
+                // If there are no more uninitiated partitions, set up the loser tree and
+                // continue to the next phase.
+
+                // Claim the memory for the uninitiated partitions
+                self.uninitiated_partitions.shrink_to_fit();
+                self.init_loser_tree();
+            } else {
+                // There are still uninitiated partitions, so return pending.
+                // We only get here if we've polled all uninitiated streams and at least
+                // one of them returned pending itself. That means we will be woken as
+                // soon as one of the streams would like to be polled again.
+                // There is no need to reschedule ourselves eagerly.
+                return Poll::Pending;
+            }
         }
 
         // NB timer records time taken on drop, so there are no
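The crux of the merge.rs change: instead of rotating a pending partition to the back of a queue and calling `cx.waker().wake_by_ref()` (a self-wake that keeps rescheduling the task through the runtime), the loop now polls every uninitiated partition once per call. Each pending source registers the task's waker during its own poll, so the merge can simply return `Poll::Pending` and wait to be woken. A simplified, self-contained sketch of that polling discipline (a hypothetical `poll_all_ready` helper over `futures` streams, not DataFusion's actual code):

    use std::task::{Context, Poll};

    use futures::{Stream, StreamExt};

    /// Polls every not-yet-ready source exactly once per call. Pending sources
    /// register `cx`'s waker while being polled, so returning `Poll::Pending`
    /// here needs no manual wake_by_ref.
    fn poll_all_ready<S: Stream + Unpin>(
        sources: &mut Vec<S>,
        cx: &mut Context<'_>,
    ) -> Poll<()> {
        let mut idx = 0;
        while idx < sources.len() {
            match sources[idx].poll_next_unpin(cx) {
                // Ready (item or end of stream): drop the source. Don't bump
                // idx; swap_remove moved the last element into this slot, and
                // the next iteration polls it.
                Poll::Ready(_) => {
                    sources.swap_remove(idx);
                }
                // Pending: waker registered; move on to the next source.
                Poll::Pending => idx += 1,
            }
        }
        if sources.is_empty() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }

`swap_remove` keeps removal O(1) at the cost of ordering, which is fine here because nothing is emitted until every source has reported ready at least once.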
diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index 272b8f6d75e0..2944ac230f38 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -378,10 +378,11 @@ impl ExecutionPlan for SortPreservingMergeExec {
 
 #[cfg(test)]
 mod tests {
+    use std::collections::HashSet;
     use std::fmt::Formatter;
     use std::pin::Pin;
     use std::sync::Mutex;
-    use std::task::{Context, Poll};
+    use std::task::{ready, Context, Poll, Waker};
     use std::time::Duration;
 
     use super::*;
@@ -1285,13 +1286,50 @@ mod tests {
         "#);
     }
 
+    #[derive(Debug)]
+    struct CongestionState {
+        wakers: Vec<Waker>,
+        unpolled_partitions: HashSet<usize>,
+    }
+
+    #[derive(Debug)]
+    struct Congestion {
+        congestion_state: Mutex<CongestionState>,
+    }
+
+    impl Congestion {
+        fn new(partition_count: usize) -> Self {
+            Congestion {
+                congestion_state: Mutex::new(CongestionState {
+                    wakers: vec![],
+                    unpolled_partitions: (0usize..partition_count).collect(),
+                }),
+            }
+        }
+
+        fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
+            let mut state = self.congestion_state.lock().unwrap();
+
+            state.unpolled_partitions.remove(&partition);
+
+            if state.unpolled_partitions.is_empty() {
+                state.wakers.iter().for_each(|w| w.wake_by_ref());
+                state.wakers.clear();
+                Poll::Ready(())
+            } else {
+                state.wakers.push(cx.waker().clone());
+                Poll::Pending
+            }
+        }
+    }
+
-    /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
-    /// partition is exhausted from the start, and if it is polled more than one, it panics.
+    /// Every congested partition returns pending until all partitions have been polled at
+    /// least once. The 1st partition is exhausted from the start; polling it twice panics.
     #[derive(Debug, Clone)]
     struct CongestedExec {
         schema: Schema,
         cache: PlanProperties,
-        congestion_cleared: Arc<Mutex<bool>>,
+        congestion: Arc<Congestion>,
     }
 
     impl CongestedExec {
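The `Congestion` helper above is a one-shot latch keyed by partition index: every caller that arrives while some partition is still unpolled parks its waker and gets `Poll::Pending`; the caller that drains the set wakes all parked tasks, and the latch stays open afterwards. A standalone illustration of those semantics (a hypothetical test, using `noop_waker_ref` from the `futures` crate):

    #[test]
    fn congestion_latch_semantics() {
        use futures::task::noop_waker_ref;

        // Three partitions: the first two callers park (Pending); the third
        // drains the set and releases everyone (Ready).
        let congestion = Congestion::new(3);
        let mut cx = Context::from_waker(noop_waker_ref());

        assert_eq!(congestion.check_congested(0, &mut cx), Poll::Pending);
        assert_eq!(congestion.check_congested(1, &mut cx), Poll::Pending);
        assert_eq!(congestion.check_congested(2, &mut cx), Poll::Ready(()));

        // Once drained, the latch stays open and no caller blocks again.
        assert_eq!(congestion.check_congested(1, &mut cx), Poll::Ready(()));
    }

This is exactly the shape of congestion the old rotate-and-self-wake loop handled poorly: no partition becomes ready until every partition has been polled at least once.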
@@ -1346,7 +1384,7 @@
             Ok(Box::pin(CongestedStream {
                 schema: Arc::new(self.schema.clone()),
                 none_polled_once: false,
-                congestion_cleared: Arc::clone(&self.congestion_cleared),
+                congestion: Arc::clone(&self.congestion),
                 partition,
             }))
         }
     }
@@ -1373,7 +1411,7 @@
     pub struct CongestedStream {
         schema: SchemaRef,
         none_polled_once: bool,
-        congestion_cleared: Arc<Mutex<bool>>,
+        congestion: Arc<Congestion>,
         partition: usize,
     }
 
@@ -1381,31 +1419,22 @@ impl Stream for CongestedStream {
     type Item = Result<RecordBatch>;
     fn poll_next(
         mut self: Pin<&mut Self>,
-        _cx: &mut Context<'_>,
+        cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         match self.partition {
             0 => {
+                let _ = self.congestion.check_congested(self.partition, cx);
                 if self.none_polled_once {
-                    panic!("Exhausted stream is polled more than one")
+                    panic!("Exhausted stream is polled more than once")
                 } else {
                     self.none_polled_once = true;
                     Poll::Ready(None)
                 }
             }
-            1 => {
-                let cleared = self.congestion_cleared.lock().unwrap();
-                if *cleared {
-                    Poll::Ready(None)
-                } else {
-                    Poll::Pending
-                }
-            }
-            2 => {
-                let mut cleared = self.congestion_cleared.lock().unwrap();
-                *cleared = true;
+            _ => {
+                ready!(self.congestion.check_congested(self.partition, cx));
                 Poll::Ready(None)
             }
-            _ => unreachable!(),
         }
     }
 }
@@ -1420,10 +1449,16 @@
     async fn test_spm_congestion() -> Result<()> {
         let task_ctx = Arc::new(TaskContext::default());
         let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
+        let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
+        let &partition_count = match properties.output_partitioning() {
+            Partitioning::RoundRobinBatch(partitions) => partitions,
+            Partitioning::Hash(_, partitions) => partitions,
+            Partitioning::UnknownPartitioning(partitions) => partitions,
+        };
         let source = CongestedExec {
             schema: schema.clone(),
-            cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
-            congestion_cleared: Arc::new(Mutex::new(false)),
+            cache: properties,
+            congestion: Arc::new(Congestion::new(partition_count)),
         };
         let spm = SortPreservingMergeExec::new(
             [PhysicalSortExpr::new_default(Arc::new(Column::new(