From 7b3a0a2d0ba7ce20a4fbac03d8232cbf22623895 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 14:43:32 +0200 Subject: [PATCH 01/24] Optimize topk with filter --- datafusion/physical-plan/src/topk/mod.rs | 101 +++++++++++++++++++---- 1 file changed, 86 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 405aa52fe04e..08e1f8d53ebb 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -18,9 +18,11 @@ //! TopK: Combination of Sort / LIMIT use arrow::{ - compute::interleave, + array::Scalar, + compute::{and, interleave, FilterBuilder}, row::{RowConverter, Rows, SortField}, }; +use arrow_ord::cmp::lt; use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; @@ -29,8 +31,8 @@ use crate::spill::get_record_batch_memory_size; use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; use arrow::array::{Array, ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; -use datafusion_common::Result; use datafusion_common::{internal_datafusion_err, HashMap}; +use datafusion_common::{internal_err, Result}; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, runtime_env::RuntimeEnv, @@ -193,7 +195,7 @@ impl TopK { let baseline = self.metrics.baseline.clone(); let _timer = baseline.elapsed_compute().timer(); - let sort_keys: Vec<ArrayRef> = self + let mut sort_keys: Vec<ArrayRef> = self .expr .iter() .map(|expr| { @@ -202,24 +204,93 @@ }) .collect::<Result<Vec<_>>>()?; + // selected indices + let mut selected_rows = None; + + // If the heap doesn't have k elements yet, we can't create thresholds + match self.heap.max() { + Some(max_row) => { + // Get the batch that contains the max row + let batch_entry = match self.heap.store.get(max_row.batch_id) { + Some(entry) => entry, + None => return internal_err!("Invalid batch ID in TopKRow"), + }; + + // Extract threshold values for each sort expression + let mut scalar_values = Vec::with_capacity(self.expr.len()); + for sort_expr in self.expr.iter() { + // Extract the value for this column from the max row + let expr = Arc::clone(&sort_expr.expr); + let value = + expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; + + // Convert to scalar value - should be a single value since we're evaluating on a single row batch + let scalar = Scalar::new(value.to_array(1)?); + scalar_values.push(scalar); + } + // Create a filter for each sort key + let filter = sort_keys + .iter() + .zip(scalar_values.iter()) + .map(|(expr, scalar)| { + let filter = lt(expr, scalar).expect("Should be valid filter"); + Ok(filter) + }) + .collect::<Result<Vec<_>>>()?; + // Combine the masks into a single filter + let filter = filter + .iter() + .fold(filter[0].clone(), |acc, filter| and(&acc, filter).unwrap()); + let filter_predicate = FilterBuilder::new(&filter); + let filter_predicate = if sort_keys.len() > 1 { + filter_predicate.optimize().build() + } else { + filter_predicate.build() + }; + selected_rows = Some(filter); + + sort_keys = sort_keys + .iter() + .map(|key| filter_predicate.filter(key).map_err(|x| x.into())) + .collect::<Result<Vec<_>>>()?; + } + None => {} + } + // reuse existing `Rows` to avoid reallocations let rows = &mut self.scratch_rows; rows.clear(); self.row_converter.append(rows, &sort_keys)?; - // TODO make this algorithmically better?: - // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`) - // this avoids some work and also might be better vectorizable. let mut batch_entry = self.heap.register_batch(batch.clone()); - for (index, row) in rows.iter().enumerate() { - match self.heap.max() { - // heap has k items, and the new row is greater than the - // current max in the heap ==> it is not a new topk - Some(max_row) if row.as_ref() >= max_row.row() => {} - // don't yet have k items or new item is lower than the currently k low values - None | Some(_) => { - self.heap.add(&mut batch_entry, row, index); - self.metrics.row_replacements.add(1); + + match selected_rows { + Some(filter) => { + for (index, row) in filter.values().set_indices().zip(rows.iter()) { + match self.heap.max() { + // heap has k items, and the new row is greater than the + // current max in the heap ==> it is not a new topk + Some(max_row) if row.as_ref() >= max_row.row() => {} + // don't yet have k items or new item is lower than the currently k low values + None | Some(_) => { + self.heap.add(&mut batch_entry, row, index); + self.metrics.row_replacements.add(1); + } + } + } + } + None => { + for (index, row) in rows.iter().enumerate() { + match self.heap.max() { + // heap has k items, and the new row is greater than the + // current max in the heap ==> it is not a new topk + Some(max_row) if row.as_ref() >= max_row.row() => {} + // don't yet have k items or new item is lower than the currently k low values + None | Some(_) => { + self.heap.add(&mut batch_entry, row, index); + self.metrics.row_replacements.add(1); + } + } } } }
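The idea behind PATCH 01: once the TopK heap holds k rows, any incoming row that sorts greater than or equal to the current heap maximum can never enter the top k, so such rows can be dropped with vectorized Arrow kernels before the comparatively expensive row conversion. The std-only sketch below shows that invariant without the Arrow machinery; `top_k_smallest` is an illustrative name, not DataFusion API, and a plain `BinaryHeap` stands in for the row heap.

    use std::collections::BinaryHeap;

    /// Keep the k smallest values seen so far; `peek()` is the current max.
    fn top_k_smallest(values: impl Iterator<Item = i64>, k: usize) -> Vec<i64> {
        let mut heap = BinaryHeap::with_capacity(k); // max-heap
        for v in values {
            if heap.len() < k {
                heap.push(v);
            } else if let Some(&max) = heap.peek() {
                // The pre-filter in this patch is the vectorized form of this
                // comparison: rows >= the current max can never enter the top k.
                if v < max {
                    heap.pop();
                    heap.push(v);
                }
            }
        }
        heap.into_sorted_vec()
    }

    fn main() {
        let got = top_k_smallest([5, 1, 9, 3, 7, 2].into_iter(), 3);
        assert_eq!(got, vec![1, 2, 3]);
        println!("{got:?}");
    }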
From 75acd0748b74eb8b44de6fcec46f11b424603a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 15:12:13 +0200 Subject: [PATCH 02/24] add sort_tpch_limit bench --- benchmarks/bench.sh | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 5d3ad3446ddb..fd99d408c4d1 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -212,6 +212,10 @@ main() { # same data as for tpch data_tpch "1" ;; + sort_tpch_limit) + # same data as for tpch + data_tpch "1" + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" usage @@ -251,6 +255,7 @@ main() { run_cancellation run_parquet run_sort + run_sort_tpch_limit run_clickbench_1 run_clickbench_partitioned run_clickbench_extended @@ -320,6 +325,9 @@ main() { sort_tpch) run_sort_tpch ;; + sort_tpch_limit) + run_sort_tpch_limit + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -918,6 +926,15 @@ run_sort_tpch() { $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" } +# Runs the sort tpch integration benchmark with limit +run_sort_tpch_limit() { + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/sort_tpch_limit.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running sort tpch limit benchmark..."
+ + $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" --limit 100 +} compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" From 32f11dcbbb0eb6f47e2dbfdeb09b2e2b50cd29ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 16:30:23 +0200 Subject: [PATCH 03/24] early return --- datafusion/physical-plan/src/topk/mod.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 08e1f8d53ebb..c51dc5ce5a2c 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -241,6 +241,10 @@ impl TopK { let filter = filter .iter() .fold(filter[0].clone(), |acc, filter| and(&acc, filter).unwrap()); + if filter.true_count() == 0 { + // No rows are less than the max row, so we can skip this batch + return Ok(()); + } let filter_predicate = FilterBuilder::new(&filter); let filter_predicate = if sort_keys.len() > 1 { filter_predicate.optimize().build() @@ -264,6 +268,8 @@ impl TopK { let mut batch_entry = self.heap.register_batch(batch.clone()); + let mut replacements = 0; + match selected_rows { Some(filter) => { for (index, row) in filter.values().set_indices().zip(rows.iter()) { @@ -274,7 +280,7 @@ impl TopK { // don't yet have k items or new item is lower than the currently k low values None | Some(_) => { self.heap.add(&mut batch_entry, row, index); - self.metrics.row_replacements.add(1); + replacements += 1; } } } @@ -288,12 +294,14 @@ impl TopK { // don't yet have k items or new item is lower than the currently k low values None | Some(_) => { self.heap.add(&mut batch_entry, row, index); - self.metrics.row_replacements.add(1); + replacements += 1; } } } } } + + self.metrics.row_replacements.add(replacements); self.heap.insert_batch_entry(batch_entry); // conserve memory @@ -306,7 +314,6 @@ impl TopK { // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K, // which means the top K won't change and the computation can be finished early. 
self.attempt_early_completion(&batch)?; - Ok(()) } From 351663be2ea60e7f0a4bd4c1ba0a0c4542cbe4c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 16:43:49 +0200 Subject: [PATCH 04/24] Clippy --- datafusion/physical-plan/src/topk/mod.rs | 96 ++++++++++++------------ 1 file changed, 46 insertions(+), 50 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 6afbd70a9cd6..b3b2ac9488bd 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -208,57 +208,53 @@ impl TopK { let mut selected_rows = None; // If the heap doesn't have k elements yet, we can't create thresholds - match self.heap.max() { - Some(max_row) => { - // Get the batch that contains the max row - let batch_entry = match self.heap.store.get(max_row.batch_id) { - Some(entry) => entry, - None => return internal_err!("Invalid batch ID in TopKRow"), - }; - - // Extract threshold values for each sort expression - let mut scalar_values = Vec::with_capacity(self.expr.len()); - for sort_expr in self.expr.iter() { - // Extract the value for this column from the max row - let expr = Arc::clone(&sort_expr.expr); - let value = - expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; - - // Convert to scalar value - should be a single value since we're evaluating on a single row batch - let scalar = Scalar::new(value.to_array(1)?); - scalar_values.push(scalar); - } - // Create a filter for each sort key - let filter = sort_keys - .iter() - .zip(scalar_values.iter()) - .map(|(expr, scalar)| { - let filter = lt(expr, scalar).expect("Should be valid filter"); - Ok(filter) - }) - .collect::<Result<Vec<_>>>()?; - // Combine the masks into a single filter - let filter = filter - .iter() - .fold(filter[0].clone(), |acc, filter| and(&acc, filter).unwrap()); - if filter.true_count() == 0 { - // No rows are less than the max row, so we can skip this batch - return Ok(()); - } - let filter_predicate = FilterBuilder::new(&filter); - let filter_predicate = if sort_keys.len() > 1 { - filter_predicate.optimize().build() - } else { - filter_predicate.build() - }; - selected_rows = Some(filter); - - sort_keys = sort_keys - .iter() - .map(|key| filter_predicate.filter(key).map_err(|x| x.into())) - .collect::<Result<Vec<_>>>()?; + if let Some(max_row) = self.heap.max() { + // Get the batch that contains the max row + let batch_entry = match self.heap.store.get(max_row.batch_id) { + Some(entry) => entry, + None => return internal_err!("Invalid batch ID in TopKRow"), + }; + + // Extract threshold values for each sort expression + let mut scalar_values = Vec::with_capacity(self.expr.len()); + for sort_expr in self.expr.iter() { + // Extract the value for this column from the max row + let expr = Arc::clone(&sort_expr.expr); + let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; + + // Convert to scalar value - should be a single value since we're evaluating on a single row batch + let scalar = Scalar::new(value.to_array(1)?); + scalar_values.push(scalar); } - None => {} + // Create a filter for each sort key + let filter = sort_keys + .iter() + .zip(scalar_values.iter()) + .map(|(expr, scalar)| { + let filter = lt(expr, scalar).expect("Should be valid filter"); + Ok(filter) + }) + .collect::<Result<Vec<_>>>()?; + // Combine the masks into a single filter + let filter = filter + .iter() + .fold(filter[0].clone(), |acc, filter| and(&acc, filter).unwrap()); + if filter.true_count() == 0 { + // No rows are less than the max row, so we can skip this batch + return Ok(()); + } + let filter_predicate = FilterBuilder::new(&filter); + let filter_predicate = if sort_keys.len() > 1 { + filter_predicate.optimize().build() + } else { + filter_predicate.build() + }; + selected_rows = Some(filter); + + sort_keys = sort_keys + .iter() + .map(|key| filter_predicate.filter(key).map_err(|x| x.into())) + .collect::<Result<Vec<_>>>()?; + }
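PATCH 05 below drops the per-key AND-combined filters from PATCH 01 because they are not equivalent to a lexicographic comparison against the heap's max row: AND-ing column-wise `<` discards rows that are below the threshold on the first key but not on later keys. A tiny self-contained example of the mismatch, with plain tuples standing in for sort keys:

    // Lexicographic `<` is not the conjunction of column-wise `<`.
    fn lex_lt(a: (i64, i64), b: (i64, i64)) -> bool {
        a.0 < b.0 || (a.0 == b.0 && a.1 < b.1)
    }

    fn main() {
        let threshold = (5, 5);
        let row = (4, 9);
        // Lexicographically below the threshold, so it must stay a candidate ...
        assert!(lex_lt(row, threshold));
        // ... yet the AND of column-wise `<` would have filtered it out.
        assert!(!(row.0 < threshold.0 && row.1 < threshold.1));
        // Safe with multiple keys: a conservative filter on the first key only.
        assert!(row.0 <= threshold.0);
        println!("ok");
    }

This is why, from the next patch on, only the first sort key gets a (conservative) filter.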
From eeb8ce4fdd3c18fb8d210018ddbbad89203336f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 19:12:35 +0200 Subject: [PATCH 05/24] Respect lexicographical ordering, only apply first filter --- datafusion/physical-plan/src/topk/mod.rs | 25 +++++++++--------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index b3b2ac9488bd..db856975d716 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -19,7 +19,7 @@ use arrow::{ array::Scalar, - compute::{and, interleave_record_batch, FilterBuilder}, + compute::{interleave_record_batch, FilterBuilder}, row::{RowConverter, Rows, SortField}, }; use arrow_ord::cmp::lt; @@ -216,29 +216,22 @@ impl TopK { }; // Extract threshold values for each sort expression - let mut scalar_values = Vec::with_capacity(self.expr.len()); - for sort_expr in self.expr.iter() { + // TODO: create a filter for each key that respects lexical ordering + // in the form of col0 < threshold0 || col0 == threshold0 && (col1 < threshold1 || ...) + // This could use BinaryExpr to benefit from short circuiting and early evaluation + let mut thresholds = Vec::with_capacity(self.expr.len()); + + for sort_expr in self.expr[..1].iter() { // Extract the value for this column from the max row let expr = Arc::clone(&sort_expr.expr); let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; // Convert to scalar value - should be a single value since we're evaluating on a single row batch let scalar = Scalar::new(value.to_array(1)?); - scalar_values.push(scalar); + thresholds.push(scalar); } // Create a filter for each sort key - let filter = sort_keys - .iter() - .zip(scalar_values.iter()) - .map(|(expr, scalar)| { - let filter = lt(expr, scalar).expect("Should be valid filter"); - Ok(filter) - }) - .collect::<Result<Vec<_>>>()?; - // Combine the masks into a single filter - let filter = filter - .iter() - .fold(filter[0].clone(), |acc, filter| and(&acc, filter).unwrap()); + let filter = lt(&sort_keys[0], &thresholds[0])?; if filter.true_count() == 0 { // No rows are less than the max row, so we can skip this batch return Ok(()); } From f0290c4371d2628bc633b3c233bcbb75f7462894 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 19:35:02 +0200 Subject: [PATCH 06/24] Respect lexicographical ordering --- datafusion/physical-plan/src/topk/mod.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index db856975d716..1ec7a2139555 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -22,7 +22,7 @@ use arrow::{ compute::{interleave_record_batch, FilterBuilder}, row::{RowConverter, Rows, SortField}, }; -use arrow_ord::cmp::lt; +use arrow_ord::cmp::{gt_eq, lt_eq}; use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; @@ -231,7 +231,10 @@ impl TopK { thresholds.push(scalar); } //
Create a filter for each sort key - let filter = lt(&sort_keys[0], &thresholds[0])?; + let filter = match self.expr[0].options.descending{ + true => gt_eq(&sort_keys[0], &thresholds[0])?, + false => lt_eq(&sort_keys[0], &thresholds[0])? + }; if filter.true_count() == 0 { // No rows are less than the max row, so we can skip this batch return Ok(()); From 67aa03a2c0ac3b97f4e9fdd0e0e787df0e1c685d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 19:39:32 +0200 Subject: [PATCH 07/24] Respect lexicographical ordering, only apply first filter --- datafusion/physical-plan/src/topk/mod.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 1ec7a2139555..fa20f0269cf6 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -22,7 +22,7 @@ use arrow::{ compute::{interleave_record_batch, FilterBuilder}, row::{RowConverter, Rows, SortField}, }; -use arrow_ord::cmp::{gt_eq, lt_eq}; +use arrow_ord::cmp::{gt, gt_eq, lt, lt_eq}; use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; @@ -231,9 +231,12 @@ impl TopK { thresholds.push(scalar); } // Create a filter for each sort key - let filter = match self.expr[0].options.descending{ - true => gt_eq(&sort_keys[0], &thresholds[0])?, - false => lt_eq(&sort_keys[0], &thresholds[0])? + let is_multi_col = self.expr[0].len()> 1; + let filter = match (is_multi_col, self.expr[0].options.descending){ + (true, true) => gt_eq(&sort_keys[0], &thresholds[0])?, + (true, false) => lt_eq(&sort_keys[0], &thresholds[0])?, + (false, true) => gt(&sort_keys[0], &thresholds[0])?, + (false, false) => lt(&sort_keys[0], &thresholds[0])? }; if filter.true_count() == 0 { // No rows are less than the max row, so we can skip this batch From 559b789de36cd8887010461ec0694b5834949d59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 19:44:27 +0200 Subject: [PATCH 08/24] Respect lexicographical ordering, only apply first filter --- datafusion/physical-plan/src/topk/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index fa20f0269cf6..ffccb5e61dd4 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -231,12 +231,12 @@ impl TopK { thresholds.push(scalar); } // Create a filter for each sort key - let is_multi_col = self.expr[0].len()> 1; - let filter = match (is_multi_col, self.expr[0].options.descending){ + let is_multi_col = self.expr.len() > 1; + let filter = match (is_multi_col, self.expr[0].options.descending) { (true, true) => gt_eq(&sort_keys[0], &thresholds[0])?, (true, false) => lt_eq(&sort_keys[0], &thresholds[0])?, (false, true) => gt(&sort_keys[0], &thresholds[0])?, - (false, false) => lt(&sort_keys[0], &thresholds[0])? 
+ (false, false) => lt(&sort_keys[0], &thresholds[0])?, }; if filter.true_count() == 0 { // No rows are less than the max row, so we can skip this batch From 4a24e75fb9bbb82c0774da29881e3bcd98e0222b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 20:26:03 +0200 Subject: [PATCH 09/24] Simplify and add link --- datafusion/physical-plan/src/topk/mod.rs | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index ffccb5e61dd4..6642c84202d4 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -219,24 +219,21 @@ impl TopK { // TODO: create a filter for each key that respects lexical ordering // in the form of col0 < threshold0 || col0 == threshold0 && (col1 < threshold1 || ...) // This could use BinaryExpr to benefit from short circuiting and early evaluation - let mut thresholds = Vec::with_capacity(self.expr.len()); + // https://github.com/apache/datafusion/issues/15698 + // Extract the value for this column from the max row + let expr = Arc::clone(&self.expr[0].expr); + let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; - for sort_expr in self.expr[..1].iter() { - // Extract the value for this column from the max row - let expr = Arc::clone(&sort_expr.expr); - let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; + // Convert to scalar value - should be a single value since we're evaluating on a single row batch + let threshold = Scalar::new(value.to_array(1)?); - // Convert to scalar value - should be a single value since we're evaluating on a single row batch - let scalar = Scalar::new(value.to_array(1)?); - thresholds.push(scalar); - } // Create a filter for each sort key let is_multi_col = self.expr.len() > 1; let filter = match (is_multi_col, self.expr[0].options.descending) { - (true, true) => gt_eq(&sort_keys[0], &thresholds[0])?, - (true, false) => lt_eq(&sort_keys[0], &thresholds[0])?, - (false, true) => gt(&sort_keys[0], &thresholds[0])?, - (false, false) => lt(&sort_keys[0], &thresholds[0])?, + (true, true) => gt_eq(&sort_keys[0], &threshold)?, + (true, false) => lt_eq(&sort_keys[0], &threshold)?, + (false, true) => gt(&sort_keys[0], &threshold)?, + (false, false) => lt(&sort_keys[0], &threshold)?, }; if filter.true_count() == 0 { // No rows are less than the max row, so we can skip this batch From 5d42ee74b19872593837e68fe3d343641447b41e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 13 Apr 2025 21:52:57 +0200 Subject: [PATCH 10/24] Still run early completion --- datafusion/physical-plan/src/topk/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 6642c84202d4..f06e781635c8 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -237,6 +237,9 @@ impl TopK { }; if filter.true_count() == 0 { // No rows are less than the max row, so we can skip this batch + // Early completion is still possible, as last row might be greater + self.attempt_early_completion(&batch)?; + return Ok(()); } let filter_predicate = FilterBuilder::new(&filter); From 7003aed2e615117ec9317a36c29515550596086c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 14 Apr 2025 08:00:11 +0200 Subject: [PATCH 11/24] Keep null values --- datafusion/physical-plan/src/topk/mod.rs | 21 
++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index f06e781635c8..bbcab2f385af 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -19,10 +19,11 @@ use arrow::{ array::Scalar, - compute::{interleave_record_batch, FilterBuilder}, + compute::{interleave_record_batch, is_null, or, FilterBuilder}, row::{RowConverter, Rows, SortField}, }; use arrow_ord::cmp::{gt, gt_eq, lt, lt_eq}; +use datafusion_expr::BinaryExpr; use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; @@ -229,11 +230,20 @@ impl TopK { // Create a filter for each sort key let is_multi_col = self.expr.len() > 1; + let filter = match (is_multi_col, self.expr[0].options.descending) { - (true, true) => gt_eq(&sort_keys[0], &threshold)?, - (true, false) => lt_eq(&sort_keys[0], &threshold)?, - (false, true) => gt(&sort_keys[0], &threshold)?, - (false, false) => lt(&sort_keys[0], &threshold)?, + (true, true) => { + or(&gt_eq(&sort_keys[0], &threshold)?, &is_null(&sort_keys[0])?)? + } + (true, false) => { + or(&lt_eq(&sort_keys[0], &threshold)?, &is_null(&sort_keys[0])?)? + } + (false, true) => { + or(&gt(&sort_keys[0], &threshold)?, &is_null(&sort_keys[0])?)? + } + (false, false) => { + or(&lt(&sort_keys[0], &threshold)?, &is_null(&sort_keys[0])?)? + } }; if filter.true_count() == 0 { // No rows are less than the max row, so we can skip this batch @@ -242,6 +252,7 @@ impl TopK { return Ok(()); } + let filter_predicate = FilterBuilder::new(&filter);
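PATCH 11 above ORs `is_null` into the filter, and PATCH 12 below reworks the same idea: comparison kernels propagate nulls, and a filter treats a null slot as "do not keep", so null sort keys would be silently dropped even though their top-k membership depends on the NULLS FIRST/LAST option. A miniature of the problem with `Option<i64>` standing in for a nullable Arrow column (`sql_lt` and `keep` are illustrative names):

    // SQL-style: comparing with NULL is unknown, neither true nor false.
    fn sql_lt(a: Option<i64>, b: Option<i64>) -> Option<bool> {
        match (a, b) {
            (Some(a), Some(b)) => Some(a < b),
            _ => None, // unknown -> dropped if fed to a filter directly
        }
    }

    fn keep(row: Option<i64>, threshold: Option<i64>) -> bool {
        // Mirrors `or(lt(..), is_null(..))`: keep definite matches and all nulls.
        sql_lt(row, threshold).unwrap_or(false) || row.is_none()
    }

    fn main() {
        assert_eq!(sql_lt(None, Some(5)), None); // a plain `lt` filter would drop this row
        assert!(keep(None, Some(5)));            // kept once nulls are OR-ed back in
        assert!(keep(Some(3), Some(5)));
        assert!(!keep(Some(7), Some(5)));
        println!("ok");
    }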
From 1610f789217dc95fb5df6f3e3f0f8a6096e86642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 14 Apr 2025 08:58:34 +0200 Subject: [PATCH 12/24] Keep null values --- datafusion/physical-plan/src/topk/mod.rs | 41 +++++++++++++++--------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index bbcab2f385af..a51875e0b2ce 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -18,13 +18,12 @@ //! TopK: Combination of Sort / LIMIT use arrow::{ - array::Scalar, + array::{BooleanArray, Scalar}, compute::{interleave_record_batch, is_null, or, FilterBuilder}, row::{RowConverter, Rows, SortField}, }; use arrow_ord::cmp::{gt, gt_eq, lt, lt_eq}; -use datafusion_expr::BinaryExpr; -use std::mem::size_of; +use std::{clone, mem::size_of}; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; @@ -231,20 +230,30 @@ impl TopK { // Create a filter for each sort key let is_multi_col = self.expr.len() > 1; - let filter = match (is_multi_col, self.expr[0].options.descending) { - (true, true) => { - or(&gt_eq(&sort_keys[0], &threshold)?, &is_null(&sort_keys[0])?)? - } - (true, false) => { - or(&lt_eq(&sort_keys[0], &threshold)?, &is_null(&sort_keys[0])?)? - } - (false, true) => { - or(&gt(&sort_keys[0], &threshold)?, &is_null(&sort_keys[0])?)? - } - (false, false) => { - or(&lt(&sort_keys[0], &threshold)?, &is_null(&sort_keys[0])?)? - } + let mut filter = match (is_multi_col, self.expr[0].options.descending) { + (true, true) => BooleanArray::new( + gt_eq(&sort_keys[0], &threshold)?.values().clone(), + None, + ), + (true, false) => BooleanArray::new( + lt_eq(&sort_keys[0], &threshold)?.values().clone(), + None, + ), + (false, true) => BooleanArray::new( + gt(&sort_keys[0], &threshold)?.values().clone(), + None, + ), + (false, false) => BooleanArray::new( + lt(&sort_keys[0], &threshold)?.values().clone(), + None, + ), }; + if sort_keys[0].is_nullable() { + // Keep any null values + // TODO it is possible to optimize this based on the current threshold value + // and the nulls first/last option and the number of following sort keys + filter = or(&filter, &is_null(&sort_keys[0])?)?; + } if filter.true_count() == 0 { // No rows are less than the max row, so we can skip this batch // Early completion is still possible, as last row might be greater From b046a73b1c3efdbf528164fd78762fcc148bae3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 14 Apr 2025 09:00:55 +0200 Subject: [PATCH 13/24] Update datafusion/physical-plan/src/topk/mod.rs Co-authored-by: Yongting You <2010youy01@gmail.com> --- datafusion/physical-plan/src/topk/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index a51875e0b2ce..0b77b508e4a0 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -204,7 +204,9 @@ impl TopK { }) .collect::<Result<Vec<_>>>()?; - // selected indices + // Selected indices in the input batch. + // Some indices may be pre-filtered if they exceed the heap’s current max value. + let mut selected_rows = None; // If the heap doesn't have k elements yet, we can't create thresholds From fe6fc48669582370e09c1b8d1edb70bce06fc1d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 14 Apr 2025 10:10:38 +0200 Subject: [PATCH 14/24] Clippy --- datafusion/physical-plan/src/topk/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 0b77b508e4a0..485bf4c7026b 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -23,7 +23,7 @@ use arrow::{ row::{RowConverter, Rows, SortField}, }; use arrow_ord::cmp::{gt, gt_eq, lt, lt_eq}; -use std::{clone, mem::size_of}; +use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
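PATCH 15 below collapses the two near-identical heap-insert loops into one helper driven by an index iterator: the set bits of the filter mask in one path, a plain `0..len` range in the other. The shape of that refactor, reduced to a sketch with made-up data (`process` is an illustrative name):

    fn process(items: impl Iterator<Item = usize>, rows: &[&str]) -> Vec<String> {
        items.map(|i| rows[i].to_uppercase()).collect()
    }

    fn main() {
        let rows = ["a", "b", "c", "d"];
        // Filtered path: indices come from the set bits of a boolean mask.
        assert_eq!(process([1usize, 3].into_iter(), &rows), ["B", "D"]);
        // Unfiltered path: plain 0..len range.
        assert_eq!(process(0..rows.len(), &rows), ["A", "B", "C", "D"]);
        println!("ok");
    }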
From f457ce8e76efe84224ddc4704285e36436b24be0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 14 Apr 2025 21:07:55 +0200 Subject: [PATCH 15/24] Refactor --- datafusion/physical-plan/src/topk/mod.rs | 55 +++++++++++++----------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 485bf4c7026b..0f9f5d33c497 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -289,32 +289,18 @@ impl TopK { match selected_rows { Some(filter) => { - for (index, row) in filter.values().set_indices().zip(rows.iter()) { - match self.heap.max() { - // heap has k items, and the new row is greater than the - // current max in the heap ==> it is not a new topk - Some(max_row) if row.as_ref() >= max_row.row() => {} - // don't yet have k items or new item is lower than the currently k low values - None | Some(_) => { - self.heap.add(&mut batch_entry, row, index); - replacements += 1; - } - } - } + self.find_new_topk_items( + filter.values().set_indices(), + &mut batch_entry, + &mut replacements, + ); } None => { - for (index, row) in rows.iter().enumerate() { - match self.heap.max() { - // heap has k items, and the new row is greater than the - // current max in the heap ==> it is not a new topk - Some(max_row) if row.as_ref() >= max_row.row() => {} - // don't yet have k items or new item is lower than the currently k low values - None | Some(_) => { - self.heap.add(&mut batch_entry, row, index); - replacements += 1; - } - } - } + self.find_new_topk_items( + 0..sort_keys[0].len(), + &mut batch_entry, + &mut replacements, + ); } } @@ -334,6 +320,27 @@ impl TopK { Ok(()) } + fn find_new_topk_items( + &mut self, + items: impl Iterator<Item = usize>, + batch_entry: &mut RecordBatchEntry, + replacements: &mut usize, + ) { + let rows = &mut self.scratch_rows; + for (index, row) in items.zip(rows.iter()) { + match self.heap.max() { + // heap has k items, and the new row is greater than the + // current max in the heap ==> it is not a new topk + Some(max_row) if row.as_ref() >= max_row.row() => {} + // don't yet have k items or new item is lower than the currently k low values + None | Some(_) => { + self.heap.add(batch_entry, row, index); + *replacements += 1; + } + } + } + } + /// If input ordering shares a common sort prefix with the TopK, and if the TopK's heap is full, /// check if the computation can be finished early. /// This is the case if the last row of the current batch is strictly greater than the max row in the heap,
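PATCH 16 below skips the pre-filter entirely when the extracted threshold itself is null, because every comparison against a null threshold is unknown and the `true_count() == 0` early return would otherwise discard batches it cannot actually rule out. A miniature of the failure mode, assuming SQL-style null semantics and plain `Option<i64>` values:

    fn main() {
        let threshold: Option<i64> = None;
        let batch = [Some(1i64), None, Some(42)];
        let true_count = batch
            .iter()
            .copied()
            .filter(|&v| match (v, threshold) {
                (Some(a), Some(t)) => a < t,
                _ => false, // NULL comparison is unknown -> not a filter hit
            })
            .count();
        // Every comparison against a NULL threshold comes back unknown, so a
        // naive `true_count == 0` check would wrongly skip the whole batch.
        assert_eq!(true_count, 0);
        println!("threshold is null: skip the pre-filter, keep all rows");
    }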
From 40dc1d9cc39ea6df0a3149113b8ebf31b4ef0ab3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 14 Apr 2025 21:13:54 +0200 Subject: [PATCH 16/24] Ignore null threshold --- datafusion/physical-plan/src/topk/mod.rs | 102 +++++++++++++---------- 1 file changed, 56 insertions(+), 46 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 0f9f5d33c497..bc43a247770a 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -23,6 +23,7 @@ use arrow::{ row::{RowConverter, Rows, SortField}, }; use arrow_ord::cmp::{gt, gt_eq, lt, lt_eq}; +use datafusion_expr::ColumnarValue; use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; @@ -226,56 +227,65 @@ impl TopK { let expr = Arc::clone(&self.expr[0].expr); let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; - // Convert to scalar value - should be a single value since we're evaluating on a single row batch - let threshold = Scalar::new(value.to_array(1)?); - - // Create a filter for each sort key - let is_multi_col = self.expr.len() > 1; - - let mut filter = match (is_multi_col, self.expr[0].options.descending) { - (true, true) => BooleanArray::new( - gt_eq(&sort_keys[0], &threshold)?.values().clone(), - None, - ), - (true, false) => BooleanArray::new( - lt_eq(&sort_keys[0], &threshold)?.values().clone(), - None, - ), - (false, true) => BooleanArray::new( - gt(&sort_keys[0], &threshold)?.values().clone(), - None, - ), - (false, false) => BooleanArray::new( - lt(&sort_keys[0], &threshold)?.values().clone(), - None, - ), + let scalar_is_null = if let ColumnarValue::Scalar(scalar_value) = &value { + scalar_value.is_null() + } else { + false }; - if sort_keys[0].is_nullable() { - // Keep any null values - // TODO it is possible to optimize this based on the current threshold value - // and the nulls first/last option and the number of following sort keys - filter = or(&filter, &is_null(&sort_keys[0])?)?; - } - if filter.true_count() == 0 { - // No rows are less than the max row, so we can skip this batch - // Early completion is still possible, as last row might be greater - self.attempt_early_completion(&batch)?; - return Ok(()); - } + // skip filtering if threshold is null + if !scalar_is_null { + // Convert to scalar value - should be a single value since we're evaluating on a single row batch + let threshold = Scalar::new(value.to_array(1)?); + + // Create a filter for each sort key + let is_multi_col = self.expr.len() > 1; + + let mut filter = match (is_multi_col, self.expr[0].options.descending) { + (true, true) => BooleanArray::new( + gt_eq(&sort_keys[0], &threshold)?.values().clone(), + None, + ), + (true, false) => BooleanArray::new( + lt_eq(&sort_keys[0], &threshold)?.values().clone(), + None, + ), + (false, true) => BooleanArray::new( + gt(&sort_keys[0], &threshold)?.values().clone(), + None, + ), + (false, false) => BooleanArray::new( + lt(&sort_keys[0], &threshold)?.values().clone(), + None, + ), + }; + if sort_keys[0].is_nullable() { + // Keep any null values + // TODO it is possible to optimize this based on the current threshold value + // and the nulls first/last option and the number of following sort keys + filter = or(&filter, &is_null(&sort_keys[0])?)?; + } + if filter.true_count() == 0 { + // No rows are less than the max row, so we can skip this batch + // Early completion is still possible, as last row might be greater + self.attempt_early_completion(&batch)?; + return Ok(()); + } - let filter_predicate = FilterBuilder::new(&filter); - let filter_predicate = if sort_keys.len() > 1 { - filter_predicate.optimize().build() - } else { - filter_predicate.build() - }; - selected_rows = Some(filter); + let filter_predicate = FilterBuilder::new(&filter); + let filter_predicate = if sort_keys.len() > 1 { + filter_predicate.optimize().build() + } else { + filter_predicate.build() + }; + selected_rows = Some(filter); - sort_keys = sort_keys - .iter() - .map(|key| filter_predicate.filter(key).map_err(|x| x.into())) - .collect::<Result<Vec<_>>>()?; + sort_keys = sort_keys + .iter() + .map(|key| filter_predicate.filter(key).map_err(|x| x.into())) + .collect::<Result<Vec<_>>>()?; + } } From 8d1bfe3ab5626924363e0b057f45298ba847ca55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Tue, 15 Apr 2025 03:13:39 +0200 Subject: [PATCH 17/24] Update datafusion/physical-plan/src/topk/mod.rs Co-authored-by: Andrew Lamb --- datafusion/physical-plan/src/topk/mod.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index bc43a247770a..f099685c85b6 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -213,9 +213,8 @@ impl TopK { // If the heap doesn't have k elements yet, we can't create thresholds if let Some(max_row) = self.heap.max() { // Get the batch that contains the max row - let batch_entry = match self.heap.store.get(max_row.batch_id) { - Some(entry) => entry, - None => return internal_err!("Invalid batch ID in TopKRow"), + let Some(entry) = self.heap.store.get(max_row.batch_id) else { + return internal_err!("Invalid batch ID in TopKRow"), }; // Extract threshold values for each sort expression
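PATCH 17 above applied a review suggestion but left a syntax error behind: the `let ... else` body ends with a trailing comma instead of a diverging statement and `;`, and the binding is named `entry` while later code uses `batch_entry`. PATCH 18 below repairs both. For reference, a minimal compiling form of the construct, with a hypothetical `lookup` helper standing in for the heap store:

    fn lookup(store: &[&str], id: usize) -> Result<String, String> {
        // `let ... else` needs a diverging block ending in `;`, and the
        // binding name must match its later uses.
        let Some(batch_entry) = store.get(id) else {
            return Err("Invalid batch ID in TopKRow".to_string());
        };
        Ok(batch_entry.to_string())
    }

    fn main() {
        assert_eq!(lookup(&["a", "b"], 1).unwrap(), "b");
        assert!(lookup(&["a", "b"], 9).is_err());
    }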
From f735f645ec47e0c472eb18d626107aef305a538b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Tue, 15 Apr 2025 09:18:10 +0200 Subject: [PATCH 18/24] Fix --- datafusion/physical-plan/src/topk/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index f099685c85b6..e42744ed723a 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -213,9 +213,8 @@ impl TopK { // If the heap doesn't have k elements yet, we can't create thresholds if let Some(max_row) = self.heap.max() { // Get the batch that contains the max row - let Some(entry) = self.heap.store.get(max_row.batch_id) else { - return internal_err!("Invalid batch ID in TopKRow"), + let Some(batch_entry) = self.heap.store.get(max_row.batch_id) else { + return internal_err!("Invalid batch ID in TopKRow"); }; // Extract threshold values for each sort expression From 923089f6ffeadedb7773b409544e355f48c40774 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Tue, 15 Apr 2025 09:30:07 +0200 Subject: [PATCH 19/24] Minor improvements --- datafusion/physical-plan/src/topk/mod.rs | 28 ++++++++---------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index e42744ed723a..5fc952a95f88 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -274,6 +274,7 @@ impl TopK { let filter_predicate = FilterBuilder::new(&filter); let filter_predicate = if sort_keys.len() > 1 { + // Optimize filter when it has multiple sort keys filter_predicate.optimize().build() } else { filter_predicate.build() @@ -294,24 +295,12 @@ impl TopK { let mut batch_entry = self.heap.register_batch(batch.clone()); - let mut replacements = 0; - - match selected_rows { + let replacements = match selected_rows { Some(filter) => { - self.find_new_topk_items( - filter.values().set_indices(), - &mut batch_entry, - &mut replacements, - ); + self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry) } - None => { - self.find_new_topk_items( - 0..sort_keys[0].len(), - &mut batch_entry, - &mut replacements, - ); - } - } + None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry), }; self.metrics.row_replacements.add(replacements); self.heap.insert_batch_entry(batch_entry); @@ -333,8 +322,8 @@ impl TopK { &mut self, items: impl Iterator<Item = usize>, batch_entry: &mut RecordBatchEntry, - replacements: &mut usize, - ) { + ) -> usize { + let mut replacements = 0; let rows = &mut self.scratch_rows; for (index, row) in items.zip(rows.iter()) { match self.heap.max() { @@ -344,10 +333,11 @@ impl TopK { // don't yet have k items or new item is lower than the currently k low values None | Some(_) => { self.heap.add(batch_entry, row, index); - *replacements += 1; + replacements += 1; } } } + replacements } /// If input ordering shares a common sort prefix with the TopK, and if the TopK's heap is full, From 9a99cce6eaed473baf07006ab11d2765742f29d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 18 Apr 2025 22:50:28 +0200 Subject: [PATCH 20/24] Fix --- Cargo.lock | 8 +- datafusion/physical-plan/src/sorts/sort.rs | 126 +++++++++++++++++++-- docs/source/user-guide/sql/explain.md | 4 +- 3 files changed, 121 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2b3eeecf5d9b..ee6dda88a0e2 100644 ---
a/Cargo.lock +++ b/Cargo.lock @@ -3966,9 +3966,9 @@ dependencies = [ [[package]] name = "libz-rs-sys" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "902bc563b5d65ad9bba616b490842ef0651066a1a1dc3ce1087113ffcb873c8d" +checksum = "6489ca9bd760fe9642d7644e827b0c9add07df89857b0416ee15c1cc1a3b8c5a" dependencies = [ "zlib-rs", ] @@ -7512,9 +7512,9 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b20717f0917c908dc63de2e44e97f1e6b126ca58d0e391cee86d504eb8fbd05" +checksum = "868b928d7949e09af2f6086dfc1e01936064cc7a819253bce650d4e2a2d63ba8" [[package]] name = "zstd" diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 5cc7512f8813..8c0c6a7e8ea9 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -51,7 +51,7 @@ use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays, SortColumn use arrow::datatypes::{DataType, SchemaRef}; use arrow::row::{RowConverter, Rows, SortField}; use datafusion_common::{ - exec_datafusion_err, internal_datafusion_err, internal_err, Result, + exec_datafusion_err, internal_datafusion_err, internal_err, DataFusionError, Result, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -323,13 +323,8 @@ impl ExternalSorter { } self.reserve_memory_for_merge()?; - - let size = get_reserved_byte_for_record_batch(&input); - if self.reservation.try_grow(size).is_err() { - self.sort_and_spill_in_mem_batches().await?; - // After spilling all in-memory batches, the retry should succeed - self.reservation.try_grow(size)?; - } + self.reserve_memory_for_batch_and_maybe_spill(&input) + .await?; self.in_mem_batches.push(input); Ok(()) @@ -529,6 +524,12 @@ impl ExternalSorter { /// Sorts the in-memory batches and merges them into a single sorted run, then writes /// the result to spill files. async fn sort_and_spill_in_mem_batches(&mut self) -> Result<()> { + if self.in_mem_batches.is_empty() { + return internal_err!( + "in_mem_batches must not be empty when attempting to sort and spill" + ); + } + // Release the memory reserved for merge back to the pool so // there is some left when `in_mem_sort_stream` requests an // allocation. At the end of this function, memory will be @@ -678,7 +679,8 @@ impl ExternalSorter { let batch = concat_batches(&self.schema, &self.in_mem_batches)?; self.in_mem_batches.clear(); self.reservation - .try_resize(get_reserved_byte_for_record_batch(&batch))?; + .try_resize(get_reserved_byte_for_record_batch(&batch)) + .map_err(Self::err_with_oom_context)?; let reservation = self.reservation.take(); return self.sort_batch_stream(batch, metrics, reservation); } @@ -759,12 +761,51 @@ impl ExternalSorter { if self.runtime.disk_manager.tmp_files_enabled() { let size = self.sort_spill_reservation_bytes; if self.merge_reservation.size() != size { - self.merge_reservation.try_resize(size)?; + self.merge_reservation + .try_resize(size) + .map_err(Self::err_with_oom_context)?; } } Ok(()) } + + /// Reserves memory to be able to accommodate the given batch. + /// If memory is scarce, tries to spill current in-memory batches to disk first. 
+ async fn reserve_memory_for_batch_and_maybe_spill( + &mut self, + input: &RecordBatch, + ) -> Result<()> { + let size = get_reserved_byte_for_record_batch(input); + + match self.reservation.try_grow(size) { + Ok(_) => Ok(()), + Err(e) => { + if self.in_mem_batches.is_empty() { + return Err(Self::err_with_oom_context(e)); + } + + // Spill and try again. + self.sort_and_spill_in_mem_batches().await?; + self.reservation + .try_grow(size) + .map_err(Self::err_with_oom_context) + } + } + } + + /// Wraps the error with a context message suggesting settings to tweak. + /// This is meant to be used with DataFusionError::ResourcesExhausted only. + fn err_with_oom_context(e: DataFusionError) -> DataFusionError { + match e { + DataFusionError::ResourcesExhausted(_) => e.context( + "Not enough memory to continue external sort. \ + Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes" + ), + // This is not an OOM error, so just return it as is. + _ => e, + } + } } /// Estimate how much memory is needed to sort a `RecordBatch`. @@ -1327,7 +1368,7 @@ mod tests { use arrow::datatypes::*; use datafusion_common::cast::as_primitive_array; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{Result, ScalarValue}; + use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::RecordBatchStream; @@ -1552,6 +1593,69 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_batch_reservation_error() -> Result<()> { + // Pick a memory limit and sort_spill_reservation that make the first batch reservation fail. + // These values assume that the ExternalSorter will reserve 800 bytes for the first batch. + let expected_batch_reservation = 800; + let merge_reservation: usize = 0; // Set to 0 for simplicity + let memory_limit: usize = expected_batch_reservation + merge_reservation - 1; // Just short of what we need + + let session_config = + SessionConfig::new().with_sort_spill_reservation_bytes(merge_reservation); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .build_arc()?; + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + let plan = test::scan_partitioned(1); + + // Read the first record batch to assert that our memory limit and sort_spill_reservation + // settings trigger the test scenario. + { + let mut stream = plan.execute(0, Arc::clone(&task_ctx))?; + let first_batch = stream.next().await.unwrap()?; + let batch_reservation = get_reserved_byte_for_record_batch(&first_batch); + + assert_eq!(batch_reservation, expected_batch_reservation); + assert!(memory_limit < (merge_reservation + batch_reservation)); + } + + let sort_exec = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("i", &plan.schema())?, + options: SortOptions::default(), + }]), + plan, + )); + + let result = collect( + Arc::clone(&sort_exec) as Arc<dyn ExecutionPlan>, + Arc::clone(&task_ctx), + ) + .await; + + let err = result.unwrap_err(); + assert!( + matches!(err, DataFusionError::Context(..)), + "Assertion failed: expected a Context error, but got: {:?}", + err + ); + + // Assert that the context error is wrapping a resources exhausted error.
+ assert!( + matches!(err.find_root(), DataFusionError::ResourcesExhausted(_)), + "Assertion failed: expected a ResourcesExhausted error, but got: {:?}", + err + ); + + Ok(()) + } + #[tokio::test] async fn test_sort_spill_utf8_strings() -> Result<()> { let session_config = SessionConfig::new() diff --git a/docs/source/user-guide/sql/explain.md b/docs/source/user-guide/sql/explain.md index 39d42f1c4982..9984de147ecc 100644 --- a/docs/source/user-guide/sql/explain.md +++ b/docs/source/user-guide/sql/explain.md @@ -71,7 +71,7 @@ to see the high level structure of the plan | | ┌─────────────┴─────────────┐ | | | │ RepartitionExec │ | | | │ -------------------- │ | -| | │ output_partition_count: │ | +| | │ input_partition_count: │ | | | │ 16 │ | | | │ │ | | | │ partitioning_scheme: │ | @@ -80,7 +80,7 @@ to see the high level structure of the plan | | ┌─────────────┴─────────────┐ | | | │ RepartitionExec │ | | | │ -------------------- │ | -| | │ output_partition_count: │ | +| | │ input_partition_count: │ | | | │ 1 │ | | | │ │ | | | │ partitioning_scheme: │ | From ea480b7f01eb058ba62aceb5e643d803cf19f414 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 18 Apr 2025 22:51:59 +0200 Subject: [PATCH 21/24] Fix --- docs/source/user-guide/sql/explain.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/user-guide/sql/explain.md b/docs/source/user-guide/sql/explain.md index 9984de147ecc..67532dbee0da 100644 --- a/docs/source/user-guide/sql/explain.md +++ b/docs/source/user-guide/sql/explain.md @@ -71,7 +71,7 @@ to see the high level structure of the plan | | ┌─────────────┴─────────────┐ | | | │ RepartitionExec │ | | | │ -------------------- │ | -| | │ input_partition_count: │ | +| | │ output_partition_count: │ | | | │ 16 │ | | | │ │ | | | │ partitioning_scheme: │ | @@ -80,7 +80,7 @@ to see the high level structure of the plan | | ┌─────────────┴─────────────┐ | | | │ RepartitionExec │ | | | │ -------------------- │ | -| | │ input_partition_count: │ | +| | │ output_partition_count: │ | | | │ 1 │ | | | │ │ | | | │ partitioning_scheme: │ | From 386596097b950015b54e77fa7ecc3dcfef89c201 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 18 Apr 2025 22:52:24 +0200 Subject: [PATCH 22/24] Fix --- docs/source/user-guide/sql/explain.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/user-guide/sql/explain.md b/docs/source/user-guide/sql/explain.md index 67532dbee0da..39d42f1c4982 100644 --- a/docs/source/user-guide/sql/explain.md +++ b/docs/source/user-guide/sql/explain.md @@ -71,7 +71,7 @@ to see the high level structure of the plan | | ┌─────────────┴─────────────┐ | | | │ RepartitionExec │ | | | │ -------------------- │ | -| | │ output_partition_count: │ | +| | │ output_partition_count: │ | | | │ 16 │ | | | │ │ | | | │ partitioning_scheme: │ | @@ -80,7 +80,7 @@ to see the high level structure of the plan | | ┌─────────────┴─────────────┐ | | | │ RepartitionExec │ | | | │ -------------------- │ | -| | │ output_partition_count: │ | +| | │ output_partition_count: │ | | | │ 1 │ | | | │ │ | | | │ partitioning_scheme: │ | From 2425b9368077ba644d7d933b5616f7daa755f14a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 18 Apr 2025 23:10:20 +0200 Subject: [PATCH 23/24] Add scalarvalue api --- datafusion/physical-plan/src/topk/mod.rs | 46 +++++++++++++++++------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs 
b/datafusion/physical-plan/src/topk/mod.rs index 5fc952a95f88..76839b7deb0a 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -24,15 +24,15 @@ use arrow::{ }; use arrow_ord::cmp::{gt, gt_eq, lt, lt_eq}; use datafusion_expr::ColumnarValue; -use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; +use std::{mem::size_of, sync::RwLock}; use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder}; use crate::spill::get_record_batch_memory_size; use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; -use datafusion_common::{internal_datafusion_err, HashMap}; +use datafusion_common::{internal_datafusion_err, HashMap, ScalarValue}; use datafusion_common::{internal_err, Result}; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, runtime_env::RuntimeEnv, @@ -120,6 +120,8 @@ pub struct TopK { /// to be greater (by byte order, after row conversion) than the top K, /// which means the top K won't change and the computation can be finished early. pub(crate) finished: bool, + + thresholds: Arc<RwLock<Vec<Option<ScalarValue>>>>, } // Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter @@ -173,7 +175,7 @@ impl TopK { build_sort_fields(&common_sort_prefix, &schema)?; Some(RowConverter::new(input_sort_fields)?) }; - + let num_exprs = expr.len(); Ok(Self { schema: Arc::clone(&schema), metrics: TopKMetrics::new(metrics, partition_id), @@ -186,6 +188,7 @@ impl TopK { common_sort_prefix_converter: prefix_row_converter, common_sort_prefix: Arc::from(common_sort_prefix), finished: false, + thresholds: Arc::new(RwLock::new(vec![None; num_exprs])), }) } @@ -223,19 +226,31 @@ impl TopK { // This could use BinaryExpr to benefit from short circuiting and early evaluation // https://github.com/apache/datafusion/issues/15698 // Extract the value for this column from the max row - let expr = Arc::clone(&self.expr[0].expr); - let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; - - let scalar_is_null = if let ColumnarValue::Scalar(scalar_value) = &value { - scalar_value.is_null() - } else { - false - }; + let thresholds: Vec<_> = self + .expr + .iter() + .map(|expr| { + let value = expr + .expr + .evaluate(&batch_entry.batch.slice(max_row.index, 1))?; + Ok(Some(match value { + ColumnarValue::Array(array) => { + ScalarValue::try_from_array(&array, 0)?
+ } + ColumnarValue::Scalar(scalar_value) => scalar_value, + })) + }) + .collect::<Result<_>>()?; + self.thresholds + .write() + .expect("Write lock should succeed") + .clone_from(&thresholds); + let threshold0 = thresholds[0].as_ref().unwrap(); // skip filtering if threshold is null - if !scalar_is_null { + if !threshold0.is_null() { // Convert to scalar value - should be a single value since we're evaluating on a single row batch - let threshold = Scalar::new(value.to_array(1)?); + let threshold = Scalar::new(threshold0.to_array_of_size(1)?); // Create a filter for each sort key let is_multi_col = self.expr.len() > 1; @@ -429,6 +444,7 @@ impl TopK { common_sort_prefix_converter: _, common_sort_prefix: _, finished: _, + thresholds: _, } = self; let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop @@ -461,6 +477,10 @@ impl TopK { + self.scratch_rows.size() + self.heap.size() } + + pub fn thresholds(&self) -> &Arc<RwLock<Vec<Option<ScalarValue>>>> { + &self.thresholds + } } struct TopKMetrics {
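PATCH 24 below refines what PATCH 23 introduced: the operator-visible thresholds, stored behind an `Arc<RwLock<Vec<Option<ScalarValue>>>>`, are now written only when the heap actually changed, and read back at the start of each batch. A minimal publish/observe model of that API, with `i64` standing in for `ScalarValue` (all names here are illustrative):

    use std::sync::{Arc, RwLock};

    fn main() {
        // Writer side: the TopK operator owns this handle.
        let thresholds: Arc<RwLock<Vec<Option<i64>>>> = Arc::new(RwLock::new(vec![None]));
        // Reader side: whatever wants to observe the current threshold.
        let reader = Arc::clone(&thresholds);

        // Publish a new threshold only when the heap actually changed.
        let replacements = 3;
        if replacements > 0 {
            thresholds
                .write()
                .expect("Write lock should succeed")
                .clone_from(&vec![Some(42)]);
        }

        let current = reader.read().expect("Read lock should succeed")[0];
        assert_eq!(current, Some(42));
        println!("observed threshold: {current:?}");
    }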
- } - ColumnarValue::Scalar(scalar_value) => scalar_value, - })) - }) - .collect::>()?; - self.thresholds - .write() - .expect("Write lock should succeed") - .clone_from(&thresholds); - let threshold0 = thresholds[0].as_ref().unwrap(); + let threshold0 = self + .thresholds + .read() + .expect("Read lock should succeed")[0].clone(); + // If the heap doesn't have k elements yet, we can't create thresholds + if let Some(threshold0) = threshold0 { + let threshold0 = threshold0.clone(); // skip filtering if threshold is null if !threshold0.is_null() { // Convert to scalar value - should be a single value since we're evaluating on a single row batch @@ -318,6 +292,34 @@ impl TopK { }; self.metrics.row_replacements.add(replacements); + + if replacements > 0 { + // Extract threshold values for each sort expression + // TODO: create a filter for each key that respects lexical ordering + // in the form of col0 < threshold0 || col0 == threshold0 && (col1 < threshold1 || ...) + // This could use BinaryExpr to benefit from short circuiting and early evaluation + // https://github.com/apache/datafusion/issues/15698 + // Extract the value for this column from the max row + let thresholds: Vec<_> = self + .expr + .iter() + .map(|expr| { + let value = expr + .expr + .evaluate(&batch_entry.batch.slice(self.heap.max().unwrap().index, 1))?; + Ok(Some(match value { + ColumnarValue::Array(array) => { + ScalarValue::try_from_array(&array, 0)? + } + ColumnarValue::Scalar(scalar_value) => scalar_value, + })) + }) + .collect::>()?; + self.thresholds + .write() + .expect("Write lock should succeed") + .clone_from(&thresholds); + } self.heap.insert_batch_entry(batch_entry); // conserve memory