Commit c00deb2

Refactor the merge stream to coalesce batches rather than keep in memory
1 parent e305353 commit c00deb2
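
The change in a sentence: instead of returning each `in_progress` batch as soon as it is built, the merge stream now pushes finished batches into a `BatchCoalescer` (created from the schema, `batch_size`, and `fetch`) via `push_batch` and emits the combined result from `finish_batch`, so merged output can be emitted in roughly `batch_size`-row chunks rather than being kept in memory as many small batches. Below is a minimal sketch of that push/finish buffering pattern; `SimpleCoalescer`, `has_completed_batch`, and the `arrow::compute::concat_batches` call are assumptions made for this illustration, not the actual `crate::coalesce::BatchCoalescer` (which, among other things, also enforces the `fetch` row limit).

```rust
// Illustrative sketch only (not the real `crate::coalesce::BatchCoalescer`):
// buffer small RecordBatches and concatenate them into one batch of roughly
// `target_batch_size` rows. `SimpleCoalescer` and `has_completed_batch` are
// names invented for this example.
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::compute::concat_batches;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

struct SimpleCoalescer {
    schema: SchemaRef,
    target_batch_size: usize,
    buffered: Vec<RecordBatch>,
    buffered_rows: usize,
}

impl SimpleCoalescer {
    fn new(schema: SchemaRef, target_batch_size: usize) -> Self {
        Self {
            schema,
            target_batch_size,
            buffered: Vec::new(),
            buffered_rows: 0,
        }
    }

    /// Buffer a (typically small) merged batch instead of emitting it directly.
    fn push_batch(&mut self, batch: RecordBatch) {
        self.buffered_rows += batch.num_rows();
        self.buffered.push(batch);
    }

    /// True once enough rows are buffered to form a full-sized output batch.
    fn has_completed_batch(&self) -> bool {
        self.buffered_rows >= self.target_batch_size
    }

    /// Concatenate everything buffered so far into a single output batch.
    fn finish_batch(&mut self) -> Result<RecordBatch, ArrowError> {
        let out = concat_batches(&self.schema, self.buffered.iter())?;
        self.buffered.clear();
        self.buffered_rows = 0;
        Ok(out)
    }
}

fn main() -> Result<(), ArrowError> {
    let schema: SchemaRef = Arc::new(Schema::new(vec![Field::new(
        "v",
        DataType::Int64,
        false,
    )]));
    let mut coalescer = SimpleCoalescer::new(Arc::clone(&schema), 4);

    // Two 2-row batches (as a merge step might produce) coalesce into one 4-row batch.
    for values in [vec![1_i64, 2], vec![3, 4]] {
        let column: ArrayRef = Arc::new(Int64Array::from(values));
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![column])?;
        coalescer.push_batch(batch);
    }
    assert!(coalescer.has_completed_batch());
    println!("coalesced rows: {}", coalescer.finish_batch()?.num_rows());
    Ok(())
}
```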

datafusion/physical-plan/src/sorts/merge.rs

Lines changed: 61 additions & 35 deletions
@@ -23,6 +23,7 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{ready, Context, Poll};
 
+use crate::coalesce::BatchCoalescer;
 use crate::metrics::BaselineMetrics;
 use crate::sorts::builder::BatchBuilder;
 use crate::sorts::cursor::{Cursor, CursorValues};
@@ -44,6 +45,8 @@ type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch
 pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
     in_progress: BatchBuilder,
 
+    coalescer: BatchCoalescer,
+
     /// The sorted input streams to merge together
     streams: CursorStream<C>,
 
@@ -163,7 +166,13 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
         let stream_count = streams.partitions();
 
         Self {
-            in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation),
+            in_progress: BatchBuilder::new(
+                schema.clone(),
+                stream_count,
+                batch_size,
+                reservation,
+            ),
+            coalescer: BatchCoalescer::new(schema, batch_size, fetch),
             streams,
             metrics,
             aborted: false,
@@ -213,38 +222,37 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
         if self.aborted {
             return Poll::Ready(None);
         }
+
+        while let Some(&partition_idx) = self.uninitiated_partitions.front() {
+            match self.maybe_poll_stream(cx, partition_idx) {
+                Poll::Ready(Err(e)) => {
+                    self.aborted = true;
+                    return Poll::Ready(Some(Err(e)));
+                }
+                Poll::Pending => {
+                    // If a partition returns Poll::Pending, to avoid continuously polling it
+                    // and potentially increasing upstream buffer sizes, we move it to the
+                    // back of the polling queue.
+                    self.uninitiated_partitions.rotate_left(1);
+
+                    // This function could remain in a pending state, so we manually wake it here.
+                    // However, this approach can be investigated further to find a more natural way
+                    // to avoid disrupting the runtime scheduler.
+                    cx.waker().wake_by_ref();
+                    return Poll::Pending;
+                }
+                _ => {
+                    // If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None),
+                    // we remove this partition from the queue so it is not polled again.
+                    self.uninitiated_partitions.pop_front();
+                }
+            }
+        }
+
         // Once all partitions have set their corresponding cursors for the loser tree,
         // we skip the following block. Until then, this function may be called multiple
         // times and can return Poll::Pending if any partition returns Poll::Pending.
         if self.loser_tree.is_empty() {
-            while let Some(&partition_idx) = self.uninitiated_partitions.front() {
-                match self.maybe_poll_stream(cx, partition_idx) {
-                    Poll::Ready(Err(e)) => {
-                        self.aborted = true;
-                        return Poll::Ready(Some(Err(e)));
-                    }
-                    Poll::Pending => {
-                        // If a partition returns Poll::Pending, to avoid continuously polling it
-                        // and potentially increasing upstream buffer sizes, we move it to the
-                        // back of the polling queue.
-                        self.uninitiated_partitions.rotate_left(1);
-
-                        // This function could remain in a pending state, so we manually wake it here.
-                        // However, this approach can be investigated further to find a more natural way
-                        // to avoid disrupting the runtime scheduler.
-                        cx.waker().wake_by_ref();
-                        return Poll::Pending;
-                    }
-                    _ => {
-                        // If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None),
-                        // we remove this partition from the queue so it is not polled again.
-                        self.uninitiated_partitions.pop_front();
-                    }
-                }
-            }
-
-            // Claim the memory for the uninitiated partitions
-            self.uninitiated_partitions.shrink_to_fit();
             self.init_loser_tree();
         }
 
@@ -254,13 +262,18 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
         let _timer = elapsed_compute.timer();
 
         loop {
+            // If we have reached the end of one of our cursors
+            if !self.uninitiated_partitions.is_empty() {
+                cx.waker().wake_by_ref();
+                if let Some(batch) = self.in_progress.build_record_batch()? {
+                    self.produced += batch.num_rows();
+                    self.coalescer.push_batch(batch);
+                }
+                return Poll::Pending;
+            }
+
             // Adjust the loser tree if necessary, returning control if needed
             if !self.loser_tree_adjusted {
-                let winner = self.loser_tree[0];
-                if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
-                    self.aborted = true;
-                    return Poll::Ready(Some(Err(e)));
-                }
                 self.update_loser_tree();
             }
 
@@ -279,7 +292,17 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
 
             self.produced += self.in_progress.len();
 
-            return Poll::Ready(self.in_progress.build_record_batch().transpose());
+            if let Some(batch) = self.in_progress.build_record_batch()? {
+                self.coalescer.push_batch(batch);
+            }
+
+            let next_batch = self.coalescer.finish_batch()?;
+
+            if next_batch.num_rows() != 0 {
+                return Poll::Ready(Some(Ok(next_batch)));
+            } else {
+                return Poll::Ready(None);
+            }
         }
     }
 
@@ -321,6 +344,9 @@ impl<C: CursorValues> SortPreservingMergeStream<C> {
             if cursor.is_finished() {
                 // Take the current cursor, leaving `None` in its place
                 self.prev_cursors[stream_idx] = self.cursors[stream_idx].take();
+
+                // Make it poll next run
+                self.uninitiated_partitions.push_front(stream_idx);
             }
             true
         } else {
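
The rewritten poll loop also relies on the scheduling trick described in its comments: a partition that returns Poll::Pending is rotated to the back of the `uninitiated_partitions` queue and the task wakes itself, so one slow input cannot be polled repeatedly while the others wait. A minimal standalone sketch of that rotate-and-wake pattern follows; `DemoStream`, `poll_partitions`, and the bool readiness slice are assumptions made for illustration, not DataFusion APIs, and `Waker::noop()` needs Rust 1.85+.

```rust
// Illustrative sketch only: the "rotate a pending partition to the back of the
// queue and self-wake" pattern, shown outside the merge stream.
use std::collections::VecDeque;
use std::task::{Context, Poll, Waker};

struct DemoStream {
    /// Partitions that still need a first successful poll.
    uninitiated_partitions: VecDeque<usize>,
}

impl DemoStream {
    fn poll_partitions(&mut self, cx: &mut Context<'_>, ready: &[bool]) -> Poll<()> {
        while let Some(&partition_idx) = self.uninitiated_partitions.front() {
            if ready[partition_idx] {
                // Initialized: never poll this partition from the queue again.
                self.uninitiated_partitions.pop_front();
            } else {
                // Not ready yet: rotate it to the back so the other partitions
                // get a turn first, and self-wake so the task is re-polled.
                self.uninitiated_partitions.rotate_left(1);
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        }
        Poll::Ready(())
    }
}

fn main() {
    let waker = Waker::noop();
    let mut cx = Context::from_waker(waker);
    let mut s = DemoStream {
        uninitiated_partitions: VecDeque::from([0, 1, 2]),
    };
    // Partition 0 is not ready, so it is rotated to the back and we get Pending.
    assert!(s.poll_partitions(&mut cx, &[false, true, true]).is_pending());
    assert_eq!(s.uninitiated_partitions, VecDeque::from([1, 2, 0]));
}
```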
