Add tests for BatchCoalescer::push_batch_with_filter, fix bug #7774


Merged: 3 commits, Jul 16, 2025

236 changes: 226 additions & 10 deletions arrow-select/src/coalesce.rs
@@ -342,7 +342,10 @@ impl BatchCoalescer {
 fn create_in_progress_array(data_type: &DataType, batch_size: usize) -> Box<dyn InProgressArray> {
     macro_rules! instantiate_primitive {
         ($t:ty) => {
-            Box::new(InProgressPrimitiveArray::<$t>::new(batch_size))
+            Box::new(InProgressPrimitiveArray::<$t>::new(
+                batch_size,
+                data_type.clone(),
+            ))
         };
     }

@@ -391,9 +394,11 @@ mod tests {
     use arrow_array::builder::StringViewBuilder;
     use arrow_array::cast::AsArray;
     use arrow_array::{
-        BinaryViewArray, RecordBatchOptions, StringArray, StringViewArray, UInt32Array,
+        BinaryViewArray, Int64Array, RecordBatchOptions, StringArray, StringViewArray,
+        TimestampNanosecondArray, UInt32Array,
     };
     use arrow_schema::{DataType, Field, Schema};
+    use rand::{Rng, SeedableRng};
     use std::ops::Range;

#[test]
@@ -484,6 +489,98 @@ mod tests {
.run();
}

/// Coalesce multiple batches, 80k rows, with a 0.1% selectivity filter
Contributor Author:

The idea is to get filter coverage prior to writing some specialized code for it
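To illustrate what these tests exercise, here is a minimal sketch of pushing a filtered batch through the coalescer. It assumes the public `finish_buffered_batch` / `next_completed_batch` API, which is not part of this diff:

use std::sync::Arc;
use arrow_array::{ArrayRef, BooleanArray, Int64Array, RecordBatch};
use arrow_select::coalesce::BatchCoalescer;

fn push_filtered_example() -> Result<(), arrow_schema::ArrowError> {
    let batch = RecordBatch::try_from_iter(vec![(
        "a",
        Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef,
    )])?;
    // keep rows 0 and 2, drop row 1
    let filter = BooleanArray::from(vec![true, false, true]);

    let mut coalescer = BatchCoalescer::new(batch.schema(), 1024);
    // logically equivalent to filter_record_batch followed by push_batch
    coalescer.push_batch_with_filter(batch, &filter)?;
    coalescer.finish_buffered_batch()?;

    let out = coalescer.next_completed_batch().expect("one completed batch");
    assert_eq!(out.num_rows(), 2);
    Ok(())
}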

#[test]
fn test_coalesce_filtered_001() {
let mut filter_builder = RandomFilterBuilder {
num_rows: 8000,
selectivity: 0.001,
seed: 0,
};

// add 10 batches of 8000 rows each
// 80k rows, selecting 0.1% means 80 rows
// not exactly 80, since the rows are chosen at random
let mut test = Test::new();
for _ in 0..10 {
test = test
.with_batch(multi_column_batch(0..8000))
.with_filter(filter_builder.next_filter())
}
test.with_batch_size(15)
.with_expected_output_sizes(vec![15, 15, 15, 13])
.run();
}

/// Coalesce multiple batches, 80k rows, with a 1% selectivity filter
#[test]
fn test_coalesce_filtered_01() {
let mut filter_builder = RandomFilterBuilder {
num_rows: 8000,
selectivity: 0.01,
seed: 0,
};

// add 10 batches of 8000 rows each
// 80k rows, selecting 1% means 800 rows
// not exactly 800, since the rows are chosen at random
let mut test = Test::new();
for _ in 0..10 {
test = test
.with_batch(multi_column_batch(0..8000))
.with_filter(filter_builder.next_filter())
}
test.with_batch_size(128)
.with_expected_output_sizes(vec![128, 128, 128, 128, 128, 128, 15])
.run();
}

/// Coalesce multiple batches, 80k rows, with a 10% selectivity filter
#[test]
fn test_coalesce_filtered_1() {
let mut filter_builder = RandomFilterBuilder {
num_rows: 8000,
selectivity: 0.1,
seed: 0,
};

// add 10 batches of 8000 rows each
// 80k rows, selecting 10% means 8000 rows
// not exactly 8000, since the rows are chosen at random
let mut test = Test::new();
for _ in 0..10 {
test = test
.with_batch(multi_column_batch(0..8000))
.with_filter(filter_builder.next_filter())
}
test.with_batch_size(1024)
.with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 1024, 1024, 840])
.run();
}

/// Coalesce multiple batches, 8k rows, with a 90% selectivity filter
#[test]
fn test_coalesce_filtered_90() {
let mut filter_builder = RandomFilterBuilder {
num_rows: 800,
selectivity: 0.90,
seed: 0,
};

// add 10 batches of 800 rows each
// 8k rows, selecting 90% means 7200 rows
// not exactly 7200, since the rows are chosen at random
let mut test = Test::new();
for _ in 0..10 {
test = test
.with_batch(multi_column_batch(0..800))
.with_filter(filter_builder.next_filter())
}
test.with_batch_size(1024)
.with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 1024, 1024, 13])
.run();
}

#[test]
fn test_coalesce_non_null() {
Test::new()
@@ -862,6 +959,11 @@ mod tests {
struct Test {
/// Batches to feed to the coalescer.
input_batches: Vec<RecordBatch>,
/// Filters to apply to the corresponding input batches.
///
/// If there is no filter for an input batch, the batch is pushed
/// as is.
filters: Vec<BooleanArray>,
/// The schema. If not provided, the first batch's schema is used.
schema: Option<SchemaRef>,
/// Expected output sizes of the resulting batches
@@ -874,6 +976,7 @@
fn default() -> Self {
Self {
input_batches: vec![],
filters: vec![],
schema: None,
expected_output_sizes: vec![],
target_batch_size: 1024,
@@ -898,6 +1001,12 @@
self
}

/// Append `filter` to the filters
fn with_filter(mut self, filter: BooleanArray) -> Self {
self.filters.push(filter);
self
}

/// Extends the input batches with `batches`
fn with_batches(mut self, batches: impl IntoIterator<Item = RecordBatch>) -> Self {
self.input_batches.extend(batches);
@@ -920,23 +1029,29 @@
     ///
     /// Returns the resulting output batches
     fn run(self) -> Vec<RecordBatch> {
+        let expected_output = self.expected_output();
+        let schema = self.schema();
+
         let Self {
             input_batches,
-            schema,
+            filters,
+            schema: _,
             target_batch_size,
             expected_output_sizes,
         } = self;
 
-        let schema = schema.unwrap_or_else(|| input_batches[0].schema());
-
-        // create a single large input batch for output comparison
-        let single_input_batch = concat_batches(&schema, &input_batches).unwrap();
+        let had_input = input_batches.iter().any(|b| b.num_rows() > 0);
 
         let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), target_batch_size);
 
-        let had_input = input_batches.iter().any(|b| b.num_rows() > 0);
+        // feed input batches and filters to the coalescer
+        let mut filters = filters.into_iter();
         for batch in input_batches {
-            coalescer.push_batch(batch).unwrap();
+            if let Some(filter) = filters.next() {
+                coalescer.push_batch_with_filter(batch, &filter).unwrap();
+            } else {
+                coalescer.push_batch(batch).unwrap();
+            }
         }
         assert_eq!(schema, coalescer.schema());

@@ -976,7 +1091,7 @@
         for (i, (expected_size, batch)) in iter {
             // compare the contents of the batch after normalization (using
             // `==` compares the underlying memory layout too)
-            let expected_batch = single_input_batch.slice(starting_idx, *expected_size);
+            let expected_batch = expected_output.slice(starting_idx, *expected_size);
let expected_batch = normalize_batch(expected_batch);
let batch = normalize_batch(batch.clone());
assert_eq!(
@@ -988,6 +1103,36 @@
}
output_batches
}

/// Return the expected output schema. If not overridden by `with_schema`, it
/// returns the schema of the first input batch.
fn schema(&self) -> SchemaRef {
self.schema
.clone()
.unwrap_or_else(|| Arc::clone(&self.input_batches[0].schema()))
}

/// Returns the expected output as a single `RecordBatch`
fn expected_output(&self) -> RecordBatch {
let schema = self.schema();
if self.filters.is_empty() {
return concat_batches(&schema, &self.input_batches).unwrap();
}

let mut filters = self.filters.iter();
let filtered_batches = self
.input_batches
.iter()
.map(|batch| {
if let Some(filter) = filters.next() {
filter_record_batch(batch, filter).unwrap()
} else {
batch.clone()
}
})
.collect::<Vec<_>>();
concat_batches(&schema, &filtered_batches).unwrap()
}
}

/// Return a RecordBatch with a UInt32Array with the specified range and
@@ -1063,6 +1208,77 @@ mod tests {
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap()
}

/// Return a multi-column RecordBatch with one row per value in `range`
fn multi_column_batch(range: Range<i32>) -> RecordBatch {
let int64_array = Int64Array::from_iter(range.clone().map(|v| {
if v % 5 == 0 {
None
} else {
Some(v as i64)
}
}));
let string_view_array = StringViewArray::from_iter(range.clone().map(|v| {
if v % 5 == 0 {
None
} else if v % 7 == 0 {
Some(format!("This is a string longer than 12 bytes{v}"))
} else {
Some(format!("Short {v}"))
}
}));
let string_array = StringArray::from_iter(range.clone().map(|v| {
if v % 11 == 0 {
None
} else {
Some(format!("Value {v}"))
}
}));
let timestamp_array = TimestampNanosecondArray::from_iter(range.map(|v| {
if v % 3 == 0 {
None
} else {
Some(v as i64 * 1000) // arbitrary non-null nanosecond timestamps
}
}))
.with_timezone("America/New_York");

RecordBatch::try_from_iter(vec![
("int64", Arc::new(int64_array) as ArrayRef),
("stringview", Arc::new(string_view_array) as ArrayRef),
("string", Arc::new(string_array) as ArrayRef),
("timestamp", Arc::new(timestamp_array) as ArrayRef),
])
.unwrap()
}

/// Builds boolean filter arrays that randomly select rows with the
/// given `selectivity`.
///
/// For example, a `selectivity` of 0.1 keeps roughly 10% of the rows
/// (filtering out the other 90%).
#[derive(Debug)]
struct RandomFilterBuilder {
num_rows: usize,
selectivity: f64,
/// Seed for the random number generator; incremented by one each time
/// `next_filter` is called
seed: u64,
}
impl RandomFilterBuilder {
/// Build the next filter with the current seed and increment the seed
/// by one.
fn next_filter(&mut self) -> BooleanArray {
assert!(self.selectivity >= 0.0 && self.selectivity <= 1.0);
let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
self.seed += 1;
BooleanArray::from_iter(
(0..self.num_rows)
.map(|_| rng.random_bool(self.selectivity))
.map(Some),
)
}
}

/// Returns the named column as a StringViewArray
fn col_as_string_view<'b>(name: &str, batch: &'b RecordBatch) -> &'b StringViewArray {
batch
11 changes: 8 additions & 3 deletions arrow-select/src/coalesce/primitive.rs
@@ -19,13 +19,15 @@ use crate::coalesce::InProgressArray;
 use arrow_array::cast::AsArray;
 use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
 use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
-use arrow_schema::ArrowError;
+use arrow_schema::{ArrowError, DataType};
 use std::fmt::Debug;
 use std::sync::Arc;
 
 /// InProgressArray for [`PrimitiveArray`]
 #[derive(Debug)]
 pub(crate) struct InProgressPrimitiveArray<T: ArrowPrimitiveType> {
+    /// Data type of the array
+    data_type: DataType,
Contributor Author:

Previously, primitive arrays like TimestampNanosecond would lose the timezone information; this fixes the issue.
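To make the failure mode concrete, a small sketch of arrow's `PrimitiveArray::try_new` / `with_data_type` behavior (an illustration, not code from this PR):

use arrow_array::types::TimestampNanosecondType;
use arrow_array::{Array, PrimitiveArray};
use arrow_buffer::ScalarBuffer;
use arrow_schema::{DataType, TimeUnit};

fn timezone_roundtrip_sketch() {
    let tz_type = DataType::Timestamp(TimeUnit::Nanosecond, Some("America/New_York".into()));

    // try_new only sees the Rust type parameter, so it produces the type's
    // default data type: Timestamp(Nanosecond, None). The timezone is lost.
    let array = PrimitiveArray::<TimestampNanosecondType>::try_new(
        ScalarBuffer::from(vec![1_i64, 2, 3]),
        None,
    )
    .unwrap();
    assert_eq!(
        array.data_type(),
        &DataType::Timestamp(TimeUnit::Nanosecond, None)
    );

    // restoring the stored data type, as finish() now does, brings it back
    let array = array.with_data_type(tz_type.clone());
    assert_eq!(array.data_type(), &tz_type);
}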

/// The current source, if any
source: Option<ArrayRef>,
/// the target batch size (and thus size for views allocation)
@@ -38,8 +40,9 @@ pub(crate) struct InProgressPrimitiveArray<T: ArrowPrimitiveType> {
 
 impl<T: ArrowPrimitiveType> InProgressPrimitiveArray<T> {
     /// Create a new `InProgressPrimitiveArray`
-    pub(crate) fn new(batch_size: usize) -> Self {
+    pub(crate) fn new(batch_size: usize, data_type: DataType) -> Self {
         Self {
+            data_type,
             batch_size,
             source: None,
             nulls: NullBufferBuilder::new(batch_size),
@@ -95,7 +98,9 @@ impl<T: ArrowPrimitiveType + Debug> InProgressArray for InProgressPrimitiveArray
         let nulls = self.nulls.finish();
         self.nulls = NullBufferBuilder::new(self.batch_size);
 
-        let array = PrimitiveArray::<T>::try_new(ScalarBuffer::from(values), nulls)?;
+        let array = PrimitiveArray::<T>::try_new(ScalarBuffer::from(values), nulls)?
+            // preserve timezone / precision+scale if applicable
+            .with_data_type(self.data_type.clone());
         Ok(Arc::new(array))
}
}