From 7345e8584109d8f264a1f8a302f820e71a8ca613 Mon Sep 17 00:00:00 2001
From: Pepijn Van Eeckhoudt
Date: Fri, 6 Jun 2025 16:18:54 +0200
Subject: [PATCH] feat: use spawned tasks to reduce call stack depth and avoid
 busy waiting

---
 datafusion/common/src/error.rs                |   6 +
 .../physical-plan/src/aggregates/mod.rs       |  56 ++---
 .../src/aggregates/no_grouping.rs             | 200 +++++++-----------
 .../physical-plan/src/aggregates/row_hash.rs  |  12 +-
 .../src/aggregates/topk/priority_map.rs       |   4 -
 .../src/aggregates/topk_stream.rs             |  71 ++++---
 .../physical-plan/src/joins/cross_join.rs     |  10 +-
 .../physical-plan/src/joins/hash_join.rs      |  25 ++-
 .../src/joins/nested_loop_join.rs             |  10 +-
 datafusion/physical-plan/src/sorts/sort.rs    |  67 ++++--
 datafusion/physical-plan/src/stream.rs        |  53 ++++-
 datafusion/physical-plan/src/test.rs          |  72 ++++++-
 12 files changed, 354 insertions(+), 232 deletions(-)

diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs
index b4a537fdce7e..359abfd0df21 100644
--- a/datafusion/common/src/error.rs
+++ b/datafusion/common/src/error.rs
@@ -350,6 +350,12 @@ impl From for DataFusionError {
     }
 }

+impl From<JoinError> for DataFusionError {
+    fn from(e: JoinError) -> Self {
+        DataFusionError::ExecutionJoin(e)
+    }
+}
+
 impl Display for DataFusionError {
     fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
         let error_prefix = self.error_prefix();
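The `From<JoinError>` impl above is what lets the spawned-task call sites later in this patch propagate join failures with `?`. A minimal sketch of the pattern, assuming this patch is applied (the `run_spawned` helper is hypothetical, not part of the patch):

    use datafusion_common::Result;
    use tokio::task::JoinHandle;

    // Awaiting a spawned task yields Result<Result<T>, JoinError>; the new
    // From impl lets `?` turn the JoinError layer into a DataFusionError
    // (DataFusionError::ExecutionJoin), leaving the inner Result<T> as the
    // function's return value.
    async fn run_spawned<T>(handle: JoinHandle<Result<T>>) -> Result<T> {
        handle.await?
    }
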
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index 14b2d0a932c2..d13f61dd0b6c 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -21,10 +21,6 @@ use std::any::Any;
 use std::sync::Arc;

 use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
-use crate::aggregates::{
-    no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
-    topk_stream::GroupedTopKAggregateStream,
-};
 use crate::execution_plan::{CardinalityEffect, EmissionType};
 use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::windows::get_ordered_partition_by_indices;
@@ -358,21 +354,10 @@ impl PartialEq for PhysicalGroupBy {
     }
 }

-#[allow(clippy::large_enum_variant)]
 enum StreamType {
-    AggregateStream(AggregateStream),
-    GroupedHash(GroupedHashAggregateStream),
-    GroupedPriorityQueue(GroupedTopKAggregateStream),
-}
-
-impl From<StreamType> for SendableRecordBatchStream {
-    fn from(stream: StreamType) -> Self {
-        match stream {
-            StreamType::AggregateStream(stream) => Box::pin(stream),
-            StreamType::GroupedHash(stream) => Box::pin(stream),
-            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
-        }
-    }
+    AggregateStream(SendableRecordBatchStream),
+    GroupedHash(SendableRecordBatchStream),
+    GroupedPriorityQueue(SendableRecordBatchStream),
 }

 /// Hash aggregate execution plan
@@ -608,7 +593,7 @@ impl AggregateExec {
     ) -> Result<StreamType> {
         // no group by at all
         if self.group_by.expr.is_empty() {
-            return Ok(StreamType::AggregateStream(AggregateStream::new(
+            return Ok(StreamType::AggregateStream(no_grouping::aggregate_stream(
                 self, context, partition,
             )?));
         }

         // grouping by an expression that has a sort/limit upstream
         if let Some(limit) = self.limit {
             if !self.is_unordered_unfiltered_group_by_distinct() {
                 return Ok(StreamType::GroupedPriorityQueue(
-                    GroupedTopKAggregateStream::new(self, context, partition, limit)?,
+                    topk_stream::aggregate_stream(self, context, partition, limit)?,
                 ));
             }
         }

         // grouping by something else and we need to just materialize all results
-        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
+        Ok(StreamType::GroupedHash(row_hash::aggregate_stream(
             self, context, partition,
         )?))
     }
@@ -998,8 +983,11 @@ impl ExecutionPlan for AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        self.execute_typed(partition, context)
-            .map(|stream| stream.into())
+        match self.execute_typed(partition, context)? {
+            StreamType::AggregateStream(s) => Ok(s),
+            StreamType::GroupedHash(s) => Ok(s),
+            StreamType::GroupedPriorityQueue(s) => Ok(s),
+        }
     }

     fn metrics(&self) -> Option<MetricsSet> {
@@ -1274,7 +1262,7 @@ pub fn create_accumulators(
 /// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial)
 pub fn finalize_aggregation(
     accumulators: &mut [AccumulatorItem],
-    mode: &AggregateMode,
+    mode: AggregateMode,
 ) -> Result<Vec<ArrayRef>> {
     match mode {
         AggregateMode::Partial => {
@@ -2105,20 +2093,20 @@ mod tests {
         let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?;

         // ensure that we really got the version we wanted
-        match version {
-            0 => {
-                assert!(matches!(stream, StreamType::AggregateStream(_)));
+        let stream = match stream {
+            StreamType::AggregateStream(s) => {
+                assert_eq!(version, 0);
+                s
             }
-            1 => {
-                assert!(matches!(stream, StreamType::GroupedHash(_)));
+            StreamType::GroupedHash(s) => {
+                assert!(version == 1 || version == 2);
+                s
             }
-            2 => {
-                assert!(matches!(stream, StreamType::GroupedHash(_)));
+            StreamType::GroupedPriorityQueue(_) => {
+                panic!("Unexpected GroupedPriorityQueue stream type");
             }
-            _ => panic!("Unknown version: {version}"),
-        }
+        };

-        let stream: SendableRecordBatchStream = stream.into();
         let err = collect(stream).await.unwrap_err();

         // error root cause traversal is a bit complicated, see #4172.
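All three `aggregate_stream` constructors referenced above now follow one shape: a `Future` that performs the whole accumulation phase and resolves to the output stream. A standalone sketch of that shape using plain `futures` types (names and the `String` error type are illustrative, not part of the patch):

    use futures::stream::{self, BoxStream};
    use futures::{Stream, TryStreamExt};

    type Item = std::result::Result<i32, String>;

    // The heavy phase runs inside this future; only when it completes does a
    // stream of results exist.
    async fn build_then_stream() -> Result<BoxStream<'static, Item>, String> {
        // ... consume all input here ...
        Ok(Box::pin(stream::iter(vec![Ok(1), Ok(2)])))
    }

    // once(..).try_flatten() exposes the inner stream's items only after the
    // build future finishes, so the consumer is woken once instead of being
    // busy-polled through the build phase.
    fn as_stream() -> impl Stream<Item = Item> {
        stream::once(build_then_stream()).try_flatten()
    }
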
diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs
index 9474a5f88c92..9863683ae11f 100644
--- a/datafusion/physical-plan/src/aggregates/no_grouping.rs
+++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs
@@ -17,42 +17,47 @@

 //! Aggregate without grouping columns

+use super::AggregateExec;
 use crate::aggregates::{
     aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
     AggregateMode,
 };
+use crate::filter::batch_filter;
 use crate::metrics::{BaselineMetrics, RecordOutput};
-use crate::{RecordBatchStream, SendableRecordBatchStream};
+use crate::stream::RecordBatchStreamAdapter;
+use crate::SendableRecordBatchStream;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::Result;
+use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::PhysicalExpr;
-use futures::stream::BoxStream;
+use futures::stream;
+use futures::stream::StreamExt;
 use std::borrow::Cow;
+use std::future::Future;
+use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
-
-use crate::filter::batch_filter;
-use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
-use futures::stream::{Stream, StreamExt};
-
-use super::AggregateExec;
-
-/// stream struct for aggregation without grouping columns
-pub(crate) struct AggregateStream {
-    stream: BoxStream<'static, Result<RecordBatch>>,
-    schema: SchemaRef,
+use std::task::{ready, Context, Poll};
+
+pub fn aggregate_stream(
+    agg: &AggregateExec,
+    context: Arc<TaskContext>,
+    partition: usize,
+) -> Result<SendableRecordBatchStream> {
+    let aggregate = Aggregate::new(agg, context, partition)?;
+
+    // Spawn a task the first time the stream is polled for the aggregation phase.
+    // This ensures the consumer of the aggregate does not poll unnecessarily
+    // while the aggregation is ongoing
+    Ok(crate::stream::create_async_then_emit(
+        Arc::clone(&agg.schema),
+        aggregate,
+    ))
 }

-/// Actual implementation of [`AggregateStream`].
-///
-/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
-/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
-/// [`futures::stream::unfold`].
-///
-/// The latter requires a state object, which is [`AggregateStreamInner`].
-struct AggregateStreamInner {
+/// The state of the aggregation.
+struct Aggregate {
     schema: SchemaRef,
     mode: AggregateMode,
     input: SendableRecordBatchStream,
@@ -61,17 +66,14 @@ struct AggregateStreamInner {
     filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
     accumulators: Vec<AccumulatorItem>,
     reservation: MemoryReservation,
-    finished: bool,
 }

-impl AggregateStream {
-    /// Create a new AggregateStream
-    pub fn new(
+impl Aggregate {
+    fn new(
         agg: &AggregateExec,
         context: Arc<TaskContext>,
         partition: usize,
     ) -> Result<Self> {
-        let agg_schema = Arc::clone(&agg.schema);
         let agg_filter_expr = agg.filter_expr.clone();

         let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
@@ -91,7 +93,7 @@ impl AggregateStream {
         let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]"))
             .register(context.memory_pool());

-        let inner = AggregateStreamInner {
+        Ok(Self {
             schema: Arc::clone(&agg.schema),
             mode: agg.mode,
             input,
@@ -100,91 +102,55 @@ impl AggregateStream {
             filter_expressions,
             accumulators,
             reservation,
-            finished: false,
-        };
-        let stream = futures::stream::unfold(inner, |mut this| async move {
-            if this.finished {
-                return None;
-            }
-
-            let elapsed_compute = this.baseline_metrics.elapsed_compute();
-
-            loop {
-                let result = match this.input.next().await {
-                    Some(Ok(batch)) => {
-                        let timer = elapsed_compute.timer();
-                        let result = aggregate_batch(
-                            &this.mode,
-                            batch,
-                            &mut this.accumulators,
-                            &this.aggregate_expressions,
-                            &this.filter_expressions,
-                        );
-
-                        timer.done();
-
-                        // allocate memory
-                        // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
-                        // overshooting a bit. Also this means we either store the whole record batch or not.
-                        match result
-                            .and_then(|allocated| this.reservation.try_grow(allocated))
-                        {
-                            Ok(_) => continue,
-                            Err(e) => Err(e),
-                        }
-                    }
-                    Some(Err(e)) => Err(e),
-                    None => {
-                        this.finished = true;
-                        let timer = this.baseline_metrics.elapsed_compute().timer();
-                        let result =
-                            finalize_aggregation(&mut this.accumulators, &this.mode)
-                                .and_then(|columns| {
-                                    RecordBatch::try_new(
-                                        Arc::clone(&this.schema),
-                                        columns,
-                                    )
-                                    .map_err(Into::into)
-                                })
-                                .record_output(&this.baseline_metrics);
-
-                        timer.done();
-
-                        result
-                    }
-                };
-
-                this.finished = true;
-                return Some((result, this));
-            }
-        });
-
-        // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
-        let stream = stream.fuse();
-        let stream = Box::pin(stream);
-
-        Ok(Self {
-            schema: agg_schema,
-            stream,
         })
     }
 }

-impl Stream for AggregateStream {
-    type Item = Result<RecordBatch>;
+impl Future for Aggregate {
+    type Output = Result<SendableRecordBatchStream>;

-    fn poll_next(
-        mut self: std::pin::Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        let this = &mut *self;
-        this.stream.poll_next_unpin(cx)
-    }
-}
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+
+        loop {
+            match ready!(self.input.poll_next_unpin(cx)) {
+                Some(Ok(batch)) => {
+                    let timer = elapsed_compute.timer();
+
+                    let result = aggregate_batch(&mut self, &batch);
+
+                    timer.done();

-impl RecordBatchStream for AggregateStream {
-    fn schema(&self) -> SchemaRef {
-        Arc::clone(&self.schema)
+                    // allocate memory
+                    // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
+                    // overshooting a bit. Also this means we either store the whole record batch or not.
+                    match result
+                        .and_then(|allocated| self.reservation.try_grow(allocated))
+                    {
+                        Ok(_) => continue,
+                        Err(e) => return Poll::Ready(Err(e)),
+                    }
+                }
+                Some(Err(e)) => return Poll::Ready(Err(e)),
+                None => {
+                    let timer = elapsed_compute.timer();
+                    let mode = self.mode;
+                    let result = finalize_aggregation(&mut self.accumulators, mode)
+                        .and_then(|columns| {
+                            RecordBatch::try_new(Arc::clone(&self.schema), columns)
+                                .map_err(Into::into)
+                        })
+                        .record_output(&self.baseline_metrics);
+
+                    timer.done();
+
+                    return Poll::Ready(Ok(Box::pin(RecordBatchStreamAdapter::new(
+                        Arc::clone(&self.schema),
+                        stream::iter(vec![result]),
+                    ))));
+                }
+            };
+        }
     }
 }

@@ -193,13 +159,7 @@ impl RecordBatchStream for AggregateStream {
 /// If successful, this returns the additional number of bytes that were allocated during this process.
 ///
 /// TODO: Make this a member function
-fn aggregate_batch(
-    mode: &AggregateMode,
-    batch: RecordBatch,
-    accumulators: &mut [AccumulatorItem],
-    expressions: &[Vec<Arc<dyn PhysicalExpr>>],
-    filters: &[Option<Arc<dyn PhysicalExpr>>],
-) -> Result<usize> {
+fn aggregate_batch(agg: &mut Aggregate, batch: &RecordBatch) -> Result<usize> {
     let mut allocated = 0usize;

     // 1.1 iterate accumulators and respective expressions together
@@ -208,15 +168,15 @@ fn aggregate_batch(
     // 1.4 update / merge accumulators with the expressions' values

     // 1.1
-    accumulators
+    agg.accumulators
         .iter_mut()
-        .zip(expressions)
-        .zip(filters)
+        .zip(&agg.aggregate_expressions)
+        .zip(&agg.filter_expressions)
         .try_for_each(|((accum, expr), filter)| {
             // 1.2
             let batch = match filter {
-                Some(filter) => Cow::Owned(batch_filter(&batch, filter)?),
-                None => Cow::Borrowed(&batch),
+                Some(filter) => Cow::Owned(batch_filter(batch, filter)?),
+                None => Cow::Borrowed(batch),
             };

             let n_rows = batch.num_rows();
@@ -229,7 +189,7 @@ fn aggregate_batch(
             // 1.4
             let size_pre = accum.size();
-            let res = match mode {
+            let res = match agg.mode {
                 AggregateMode::Partial
                 | AggregateMode::Single
                 | AggregateMode::SinglePartitioned => accum.update_batch(&values),
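`Aggregate::poll` above is an instance of a general pattern: a `Future` that drains an inner stream with `ready!` and only resolves once the stream is exhausted. A stripped-down, self-contained sketch of that pattern (illustrative only; not code from this patch):

    use std::pin::Pin;
    use std::task::{ready, Context, Poll};

    use futures::{Stream, StreamExt};

    // A future that folds every item of an inner stream and resolves to the
    // final total once the stream ends.
    struct Sum<S> {
        input: S,
        total: i64,
    }

    impl<S: Stream<Item = i64> + Unpin> std::future::Future for Sum<S> {
        type Output = i64;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<i64> {
            loop {
                // ready! returns Poll::Pending to the caller until the inner
                // stream has an item, so the loop never busy-waits.
                match ready!(self.input.poll_next_unpin(cx)) {
                    Some(v) => self.total += v,
                    None => return Poll::Ready(self.total),
                }
            }
        }
    }

Awaiting `Sum { input: futures::stream::iter(vec![1, 2, 3]), total: 0 }` yields `6`; the aggregation futures in this patch do the same, except their final value is a stream holding the result batch.
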
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 1d659d728084..65c41f997a62 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -54,6 +54,16 @@ use futures::ready;
 use futures::stream::{Stream, StreamExt};
 use log::debug;

+pub fn aggregate_stream(
+    agg: &AggregateExec,
+    context: Arc<TaskContext>,
+    partition: usize,
+) -> Result<SendableRecordBatchStream> {
+    Ok(Box::pin(GroupedHashAggregateStream::new(
+        agg, context, partition,
+    )?))
+}
+
 #[derive(Debug, Clone)]
 /// This object tracks the aggregation phase (input/output)
 pub(crate) enum ExecutionState {
@@ -338,7 +348,7 @@ impl SkipAggregationProbe {
 /// │ 2 │ 2 │ 3.0 │ │ 2 │ 2 │ 3.0 │ └────────────┘
 /// └─────────────────┘ └─────────────────┘
 /// ```
-pub(crate) struct GroupedHashAggregateStream {
+struct GroupedHashAggregateStream {
     // ========================================================================
     // PROPERTIES:
     // These fields are initialized at the start and remain constant throughout

diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs
index a09d70f7471f..9bd52bcf52b2 100644
--- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs
+++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs
@@ -99,10 +99,6 @@ impl PriorityMap {
         let ids = unsafe { self.map.take_all(map_idxs) };
         Ok(vec![ids, vals])
     }
-
-    pub fn is_empty(&self) -> bool {
-        self.map.len() == 0
-    }
 }

 #[cfg(test)]
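The `row_hash::aggregate_stream` wrapper above only erases the concrete stream type behind the `SendableRecordBatchStream` alias. The same conversion in miniature, assuming a stream that already implements `RecordBatchStream` (the `erase` helper is hypothetical):

    use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};

    // SendableRecordBatchStream is Pin<Box<dyn RecordBatchStream + Send>>, so
    // pinning the concrete stream into a box is the entire conversion.
    fn erase<S>(stream: S) -> SendableRecordBatchStream
    where
        S: RecordBatchStream + Send + 'static,
    {
        Box::pin(stream)
    }
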
diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs
index bf02692486cc..c63b43321321 100644
--- a/datafusion/physical-plan/src/aggregates/topk_stream.rs
+++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs
@@ -22,7 +22,8 @@
 use crate::aggregates::{
     aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
     PhysicalGroupBy,
 };
-use crate::{RecordBatchStream, SendableRecordBatchStream};
+use crate::stream::RecordBatchStreamAdapter;
+use crate::SendableRecordBatchStream;
 use arrow::array::{Array, ArrayRef, RecordBatch};
 use arrow::datatypes::SchemaRef;
 use arrow::util::pretty::print_batches;
 use datafusion_common::DataFusionError;
 use datafusion_common::Result;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::PhysicalExpr;
-use futures::stream::{Stream, StreamExt};
+use futures::stream;
+use futures::stream::StreamExt;
 use log::{trace, Level};
+use std::future::Future;
 use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
+use std::task::{ready, Context, Poll};

-pub struct GroupedTopKAggregateStream {
+pub fn aggregate_stream(
+    aggr: &AggregateExec,
+    context: Arc<TaskContext>,
+    partition: usize,
+    limit: usize,
+) -> Result<SendableRecordBatchStream> {
+    let aggregate_top_k = AggregateTopK::new(aggr, context, partition, limit)?;
+
+    // Spawn a task the first time the stream is polled for the aggregation phase.
+    // This ensures the consumer of the aggregation does not poll unnecessarily
+    // while the aggregation is ongoing
+    Ok(crate::stream::create_async_then_emit(
+        Arc::clone(&aggr.schema),
+        aggregate_top_k,
+    ))
+}
+
+struct AggregateTopK {
     partition: usize,
     row_count: usize,
     started: bool,
@@ -47,7 +67,7 @@ pub struct GroupedTopKAggregateStream {
     priority_map: PriorityMap,
 }

-impl GroupedTopKAggregateStream {
+impl AggregateTopK {
     pub fn new(
         aggr: &AggregateExec,
         context: Arc<TaskContext>,
@@ -69,7 +89,7 @@ impl GroupedTopKAggregateStream {

         let priority_map = PriorityMap::new(kt, vt, limit, desc)?;

-        Ok(GroupedTopKAggregateStream {
+        Ok(AggregateTopK {
             partition,
             started: false,
             row_count: 0,
@@ -82,13 +102,7 @@ impl GroupedTopKAggregateStream {
     }
 }

-impl RecordBatchStream for GroupedTopKAggregateStream {
-    fn schema(&self) -> SchemaRef {
-        Arc::clone(&self.schema)
-    }
-}
-
-impl GroupedTopKAggregateStream {
+impl AggregateTopK {
     fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
         let len = ids.len();
         self.priority_map.set_batch(ids, Arc::clone(&vals));
@@ -104,15 +118,15 @@ impl GroupedTopKAggregateStream {
     }
 }

-impl Stream for GroupedTopKAggregateStream {
-    type Item = Result<RecordBatch>;
+impl Future for AggregateTopK {
+    type Output = Result<SendableRecordBatchStream>;

-    fn poll_next(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
-            match res {
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        loop {
+            match ready!(self.input.poll_next_unpin(cx)) {
+                Some(Err(e)) => {
+                    return Poll::Ready(Err(e));
+                }
                 // got a batch, convert to rows and append to our TreeMap
                 Some(Ok(batch)) => {
                     self.started = true;
@@ -153,10 +167,6 @@ impl Stream for GroupedTopKAggregateStream {
                 }
                 // inner is done, emit all rows and switch to producing output
                 None => {
-                    if self.priority_map.is_empty() {
-                        trace!("partition {} emit None", self.partition);
-                        return Poll::Ready(None);
-                    }
                     let cols = self.priority_map.emit()?;
                     let batch = RecordBatch::try_new(Arc::clone(&self.schema), cols)?;
                     trace!(
@@ -167,14 +177,13 @@ impl Stream for GroupedTopKAggregateStream {
                     if log::log_enabled!(Level::Trace) {
                         print_batches(std::slice::from_ref(&batch))?;
                     }
-                    return Poll::Ready(Some(Ok(batch)));
-                }
-                // inner had error, return to caller
-                Some(Err(e)) => {
-                    return Poll::Ready(Some(Err(e)));
+
+                    return Poll::Ready(Ok(Box::pin(RecordBatchStreamAdapter::new(
+                        Arc::clone(&self.schema),
+                        stream::iter(vec![Ok(batch)]),
+                    ))));
                 }
             }
         }
-        Poll::Pending
     }
 }
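Both aggregation futures finish by wrapping a single, already-computed batch as a stream. That last step in isolation, as a sketch (the `emit_one` helper is hypothetical; the patch inlines this at the `Poll::Ready` sites, and `RecordBatchStreamAdapter` is assumed to come from the physical-plan crate's `stream` module):

    use std::sync::Arc;

    use arrow::datatypes::SchemaRef;
    use arrow::record_batch::RecordBatch;
    use datafusion_common::Result;
    use datafusion_execution::SendableRecordBatchStream;
    use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
    use futures::stream;

    // stream::iter over a one-element Vec of Result<RecordBatch> also lets an
    // error from the final phase surface as the stream's single item.
    fn emit_one(schema: SchemaRef, batch: Result<RecordBatch>) -> SendableRecordBatchStream {
        Box::pin(RecordBatchStreamAdapter::new(schema, stream::iter(vec![batch])))
    }
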
diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs
index e4d554ceb62c..8462afa5b7d8 100644
--- a/datafusion/physical-plan/src/joins/cross_join.rs
+++ b/datafusion/physical-plan/src/joins/cross_join.rs
@@ -32,7 +32,7 @@ use crate::projection::{
     physical_to_column_exprs, ProjectionExec,
 };
 use crate::{
-    handle_state, ColumnStatistics, DisplayAs, DisplayFormatType, Distribution,
+    handle_state, stream, ColumnStatistics, DisplayAs, DisplayFormatType, Distribution,
     ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
     SendableRecordBatchStream, Statistics,
 };
@@ -303,12 +303,8 @@ impl ExecutionPlan for CrossJoinExec {

         let left_fut = self.left_fut.try_once(|| {
             let left_stream = self.left.execute(0, context)?;
-
-            Ok(load_left_input(
-                left_stream,
-                join_metrics.clone(),
-                reservation,
-            ))
+            let task = load_left_input(left_stream, join_metrics.clone(), reservation);
+            Ok(stream::spawn_deferred(task))
         })?;

         if enforce_batch_size_in_joins {

diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 770399290dca..34245c128127 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -39,7 +39,6 @@ use crate::projection::{
     ProjectionExec,
 };
 use crate::spill::get_record_batch_memory_size;
-use crate::ExecutionPlanProperties;
 use crate::{
     common::can_project,
     handle_state,
@@ -56,6 +55,7 @@
     DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
     PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
+use crate::{stream, ExecutionPlanProperties};

 use arrow::array::{
     cast::downcast_array, Array, ArrayRef, BooleanArray, BooleanBufferBuilder,
@@ -82,8 +82,9 @@ use datafusion_physical_expr::PhysicalExprRef;
 use datafusion_physical_expr_common::datum::compare_op_for_nested;

 use ahash::RandomState;
+use datafusion_common_runtime::SpawnedTask;
 use datafusion_physical_expr_common::physical_expr::fmt_sql;
-use futures::{ready, Stream, StreamExt, TryStreamExt};
+use futures::{ready, FutureExt, Stream, StreamExt, TryStreamExt};
 use parking_lot::Mutex;

 /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions.
@@ -809,7 +810,7 @@ impl ExecutionPlan for HashJoinExec {
                 let reservation =
                     MemoryConsumer::new("HashJoinInput").register(context.memory_pool());

-                Ok(collect_left_input(
+                let task = collect_left_input(
                     self.random_state.clone(),
                     left_stream,
                     on_left.clone(),
@@ -817,7 +818,12 @@ impl ExecutionPlan for HashJoinExec {
                     reservation,
                     need_produce_result_in_final(self.join_type),
                     self.right().output_partitioning().partition_count(),
-                ))
+                );
+
+                // Spawn a task the first time the stream is polled for the build phase.
+                // This ensures the consumer of the join does not poll unnecessarily
+                // while the build is ongoing
+                Ok(stream::spawn_deferred(task))
             })?,
             PartitionMode::Partitioned => {
                 let left_stream = self.left.execute(partition, Arc::clone(&context))?;
@@ -826,7 +832,7 @@ impl ExecutionPlan for HashJoinExec {
                     MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
                         .register(context.memory_pool());

-                OnceFut::new(collect_left_input(
+                let task = collect_left_input(
                     self.random_state.clone(),
                     left_stream,
                     on_left.clone(),
@@ -834,7 +840,14 @@ impl ExecutionPlan for HashJoinExec {
                     reservation,
                     need_produce_result_in_final(self.join_type),
                     1,
-                ))
+                );
+
+                OnceFut::new(async move {
+                    // Spawn a task the first time the stream is polled for the build phase.
+                    // This ensures the consumer of the join does not poll unnecessarily
+                    // while the build is ongoing
+                    SpawnedTask::spawn(task).map(|r| r?).await
+                })
             }
             PartitionMode::Auto => {
                 return plan_err!(
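`OnceFut`/`try_once` receive the *future* returned by `spawn_deferred`, so even the shared build side stays lazy: nothing is spawned until the first probe partition polls its stream. A standalone sketch of the sharing idea with plain tokio/futures (`spawn_shared` is hypothetical; `OnceFut` layers error handling on top of something like this):

    use futures::future::{FutureExt, Shared};
    use futures::Future;

    // Spawn `fut` once; every clone of the returned future resolves to the
    // same cloned result, like one hash table shared across probe partitions.
    // Because the tokio::spawn call sits inside the async block, it only runs
    // when the shared future is first polled.
    fn spawn_shared<T>(
        fut: impl Future<Output = T> + Send + 'static,
    ) -> Shared<impl Future<Output = T>>
    where
        T: Clone + Send + 'static,
    {
        async move { tokio::spawn(fut).await.expect("join error") }.shared()
    }

`Shared` requires the output to be `Clone`, which is why the real code shares an `Arc` of the build-side data rather than the data itself.
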
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index fcc1107a0e26..bc38b028f998 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -42,7 +42,7 @@ use crate::projection::{
     ProjectionExec,
 };
 use crate::{
-    handle_state, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
+    handle_state, stream, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
     ExecutionPlanProperties, PlanProperties, RecordBatchStream,
     SendableRecordBatchStream,
 };
@@ -500,13 +500,17 @@ impl ExecutionPlan for NestedLoopJoinExec {

         let inner_table = self.inner_table.try_once(|| {
             let stream = self.left.execute(0, Arc::clone(&context))?;

-            Ok(collect_left_input(
+            let task = collect_left_input(
                 stream,
                 join_metrics.clone(),
                 load_reservation,
                 need_produce_result_in_final(self.join_type),
                 self.right().output_partitioning().partition_count(),
-            ))
+            );
+            // Spawn a task the first time the stream is polled for the build phase.
+            // This ensures the consumer of the join does not poll unnecessarily
+            // while the build is ongoing
+            Ok(stream::spawn_deferred(task))
         })?;

         let batch_size = context.session_config().batch_size();

diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index f941827dd036..ddb44650fc1e 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -40,9 +40,9 @@ use crate::spill::spill_manager::SpillManager;
 use crate::stream::RecordBatchStreamAdapter;
 use crate::topk::TopK;
 use crate::{
-    DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan,
-    ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream,
-    Statistics,
+    stream, DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream,
+    ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties,
+    SendableRecordBatchStream, Statistics,
 };

 use arrow::array::{Array, RecordBatch, RecordBatchOptions, StringViewArray};
@@ -58,6 +58,7 @@ use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr};
 use datafusion_physical_expr::LexOrdering;
 use datafusion_physical_expr::PhysicalExpr;

+use datafusion_common_runtime::SpawnedTask;
 use futures::{StreamExt, TryStreamExt};
 use log::{debug, trace};

@@ -1154,20 +1155,17 @@ impl ExecutionPlan for SortExec {
                     &self.metrics_set,
                     self.filter.clone(),
                 )?;
-                Ok(Box::pin(RecordBatchStreamAdapter::new(
-                    self.schema(),
-                    futures::stream::once(async move {
-                        while let Some(batch) = input.next().await {
-                            let batch = batch?;
-                            topk.insert_batch(batch)?;
-                            if topk.finished {
-                                break;
-                            }
+
+                Ok(stream::create_async_then_emit(self.schema(), async move {
+                    while let Some(batch) = input.next().await {
+                        let batch = batch?;
+                        topk.insert_batch(batch)?;
+                        if topk.finished {
+                            break;
                         }
-                        topk.emit()
-                    })
-                    .try_flatten(),
-                )))
+                    }
+                    topk.emit()
+                }))
             }
             (false, None) => {
                 let mut sorter = ExternalSorter::new(
@@ -1184,11 +1182,17 @@ impl ExecutionPlan for SortExec {
                 Ok(Box::pin(RecordBatchStreamAdapter::new(
                     self.schema(),
                     futures::stream::once(async move {
-                        while let Some(batch) = input.next().await {
-                            let batch = batch?;
-                            sorter.insert_batch(batch).await?;
-                        }
-                        sorter.sort().await
+                        // Spawn a task the first time the stream is polled for the sort phase.
+                        // This ensures the consumer of the sort does not poll unnecessarily
+                        // while the sort is ongoing
+                        SpawnedTask::spawn(async move {
+                            while let Some(batch) = input.next().await {
+                                let batch = batch?;
+                                sorter.insert_batch(batch).await?;
+                            }
+                            sorter.sort().await
+                        })
+                        .await?
                     })
                    .try_flatten(),
                 )))
@@ -1296,9 +1300,9 @@ mod tests {
     use crate::execution_plan::Boundedness;
     use crate::expressions::col;
     use crate::test;
-    use crate::test::assert_is_pending;
     use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
     use crate::test::TestMemoryExec;
+    use crate::test::{assert_is_pending, panic_exec};

     use arrow::array::*;
     use arrow::compute::SortOptions;
@@ -2046,4 +2050,23 @@ mod tests {
         "#);
         Ok(())
     }
+
+    #[tokio::test]
+    async fn unpolled_sort_does_not_start_eagerly() -> Result<()> {
+        let task_ctx = Arc::new(TaskContext::default());
+        let source = panic_exec(1);
+        let schema = source.schema();
+
+        let sort_exec = Arc::new(SortExec::new(
+            [PhysicalSortExpr {
+                expr: col("i", &schema)?,
+                options: SortOptions::default(),
+            }]
+            .into(),
+            source,
+        ));
+
+        let _ = sort_exec.execute(1, task_ctx);
+        Ok(())
+    }
 }
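The test above, and the `spawn_deferred` helper below, both rely on a core property of Rust futures: the body of an `async fn` or `async` block runs only when the returned future is first polled, so building the future eagerly spawns nothing. A self-contained demonstration (assumes a tokio runtime; illustrative only):

    use std::sync::atomic::{AtomicBool, Ordering};

    static STARTED: AtomicBool = AtomicBool::new(false);

    // Nothing in this body executes until the returned future is polled.
    async fn deferred() {
        STARTED.store(true, Ordering::SeqCst);
    }

    #[tokio::main]
    async fn main() {
        let fut = deferred();
        assert!(!STARTED.load(Ordering::SeqCst)); // constructed, but not run
        fut.await;
        assert!(STARTED.load(Ordering::SeqCst)); // ran on first poll
    }
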
diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs
index 338ac7d048a3..ce6b3b271bc3 100644
--- a/datafusion/physical-plan/src/stream.rs
+++ b/datafusion/physical-plan/src/stream.rs
@@ -28,11 +28,11 @@ use crate::displayable;
 use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
 use datafusion_common::{exec_err, Result};
-use datafusion_common_runtime::JoinSet;
+use datafusion_common_runtime::{JoinSet, SpawnedTask};
 use datafusion_execution::TaskContext;

 use futures::stream::BoxStream;
-use futures::{Future, Stream, StreamExt};
+use futures::{Future, Stream, StreamExt, TryStreamExt};
 use log::debug;
 use pin_project_lite::pin_project;
 use tokio::sync::mpsc::{Receiver, Sender};
@@ -522,6 +522,55 @@ impl Stream for ObservedStream {
     }
 }

+/// Returns a stream that on first poll spawns a task that drives the `create_stream` future to
+/// completion and once the future is complete produces the values of the created stream.
+///
+/// This construct ensures any intermittent pending results returned by `create_stream` are hidden
+/// from the consumer of the returned stream. Instead, the stream consumer will get a pending result
+/// once and be woken when the stream creation future has completed. This avoids unnecessarily
+/// waking the stream consumer.
+///
+/// When `create_stream` is complete, production of `RecordBatch` instances may or may not be
+/// multithreaded depending on the `Stream` returned by the future. The stream created by this
+/// function will inherit whatever characteristics the stream created by the future has.
+pub fn create_async_then_emit<F>(
+    schema: SchemaRef,
+    create_stream: F,
+) -> SendableRecordBatchStream
+where
+    F: Future<Output = Result<SendableRecordBatchStream>> + Send + 'static,
+{
+    // First create a future that on first poll starts and then awaits a spawned task
+    // which will drive the `create_stream` future to completion.
+    // After this statement the `create_stream_deferred` future has not been polled yet
+    // and so the task has not been spawned yet.
+    // `create_stream_deferred` awaits the join handle of the spawned task, so the final result
+    // of this future is the result of the `create_stream` future which is available once the spawned task
+    // completes.
+    let create_stream_deferred = spawn_deferred(create_stream);
+
+    // Convert the future into a stream consisting of the result of the `create_stream_deferred`
+    // future which itself is the result of `create_stream`. In other words `create_stream_stream`
+    // is a stream containing a single stream.
+    // Since the stream created by `once` is lazy wrt the future it is given, the task still has not
+    // been spawned when this statement completes.
+    let create_stream_stream = futures::stream::once(create_stream_deferred);
+
+    // Flatten the stream of streams to get a stream of record batches.
+    // try_flatten is also lazy, so the task still has not been spawned.
+    let emit_stream = create_stream_stream.try_flatten();
+
+    Box::pin(RecordBatchStreamAdapter::new(schema, emit_stream))
+}
+
+pub(crate) async fn spawn_deferred<F, R>(task: F) -> Result<R>
+where
+    F: Future<Output = Result<R>> + Send + 'static,
+    R: Send + 'static,
+{
+    SpawnedTask::spawn(task).await?
+}
+
 #[cfg(test)]
 mod test {
     use super::*;

diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs
index 5e6410a0171e..32f23a7d80fb 100644
--- a/datafusion/physical-plan/src/test.rs
+++ b/datafusion/physical-plan/src/test.rs
@@ -47,7 +47,7 @@ use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, Partitioning};

-use futures::{Future, FutureExt};
+use futures::{stream, Future, FutureExt};

 pub mod exec;

@@ -515,10 +515,78 @@ impl PartitionStream for TestPartitionStream {
         &self.schema
     }
     fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
-        let stream = futures::stream::iter(self.batches.clone().into_iter().map(Ok));
+        let stream = stream::iter(self.batches.clone().into_iter().map(Ok));
         Box::pin(RecordBatchStreamAdapter::new(
             Arc::clone(&self.schema),
             stream,
         ))
     }
 }
+
+/// Returns an `ExecutionPlan` that returns a stream that panics if it is ever polled.
+/// This can be used to test that execution plan implementations do not eagerly start
+/// processing data when `ExecutionPlan::execute` is called.
+pub fn panic_exec(partitions: usize) -> Arc<dyn ExecutionPlan> {
+    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+    Arc::new(PanicExec::new(schema, partitions))
+}
+
+#[derive(Debug)]
+struct PanicExec {
+    properties: PlanProperties,
+}
+
+impl PanicExec {
+    fn new(schema: SchemaRef, partitions: usize) -> Self {
+        Self {
+            properties: PlanProperties::new(
+                EquivalenceProperties::new(Arc::clone(&schema)),
+                Partitioning::UnknownPartitioning(partitions),
+                EmissionType::Incremental,
+                Boundedness::Bounded,
+            ),
+        }
+    }
+}
+
+impl DisplayAs for PanicExec {
+    fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
+        write!(f, "Panic")
+    }
+}
+
+impl ExecutionPlan for PanicExec {
+    fn name(&self) -> &str {
+        "PanicExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.properties
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::<Self>::clone(&self))
+    }
+
+    fn execute(
+        &self,
+        _: usize,
+        _: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            self.schema(),
+            stream::once(async { panic!() }),
+        )))
+    }
+}
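A usage sketch for the new `create_async_then_emit`, assuming this patch is applied (the `lazy_batches` function and its schema are illustrative, not part of the patch):

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion_execution::SendableRecordBatchStream;
    use datafusion_physical_plan::stream::{create_async_then_emit, RecordBatchStreamAdapter};
    use futures::stream;

    // The future passed in may take arbitrarily many polls to finish its work;
    // the consumer of the returned stream is woken only once batches exist.
    fn lazy_batches() -> SendableRecordBatchStream {
        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
        let task_schema = Arc::clone(&schema);
        create_async_then_emit(schema, async move {
            // ... expensive preparation would happen here, on a spawned task ...
            let batch = RecordBatch::try_new(
                Arc::clone(&task_schema),
                vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
            )?;
            Ok(Box::pin(RecordBatchStreamAdapter::new(
                task_schema,
                stream::iter(vec![Ok(batch)]),
            )) as SendableRecordBatchStream)
        })
    }

Dropping the result of `lazy_batches()` without polling it never runs the future, which is the same property `unpolled_sort_does_not_start_eagerly` verifies for `SortExec`.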