diff --git a/Cargo.lock b/Cargo.lock index 2ce381711cd6..25b64c3cd720 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2228,6 +2228,7 @@ dependencies = [ "indexmap 2.9.0", "insta", "paste", + "rand 0.9.1", "recursive", "serde_json", "sqlparser", @@ -2509,6 +2510,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "rand 0.9.1", + "rand_distr", "rstest", "rstest_reuse", "tempfile", diff --git a/Cargo.toml b/Cargo.toml index 758bdfb510bd..8fce7434b01e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -164,6 +164,7 @@ pbjson-types = "0.7" insta = { version = "1.43.1", features = ["glob", "filters"] } prost = "0.13.1" rand = "0.9" +rand_distr = "0.5" recursive = "0.1.1" regex = "1.8" rstest = "0.25.0" diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 9747f4424060..5241dc253fd7 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -157,7 +157,7 @@ env_logger = { workspace = true } insta = { workspace = true } paste = "^1.0" rand = { workspace = true, features = ["small_rng"] } -rand_distr = "0.5" +rand_distr = { workspace = true } regex = { workspace = true } rstest = { workspace = true } serde_json = { workspace = true } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 02a18f22c916..6ab03b468d28 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2262,6 +2262,50 @@ impl DataFrame { let df = ctx.read_batch(batch)?; Ok(df) } + + /// Sample a fraction of the rows from the DataFrame. + /// + /// # Arguments + /// * `fraction` - The fraction of rows to sample. + /// * `with_replacement` - Whether to sample with replacement. + /// * `seed` - The seed for the random number generator. + /// + /// # Example + /// ``` + /// use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let df = dataframe!( + /// "id" => [1, 2, 3], + /// "name" => ["foo", "bar", "baz"] + /// )?; + /// let df = df.sample(0.5, Some(true), Some(42))?; + /// df.show().await?; + /// # Ok(()) + /// # } + /// // +----+------+ + /// // | id | name | + /// // +----+------+ + /// // | 3 | baz | + /// // +----+------+ + /// ``` + /// + pub fn sample( + self, + fraction: f64, + with_replacement: Option, + seed: Option, + ) -> Result { + let plan = LogicalPlanBuilder::from(self.plan) + .sample(fraction, with_replacement, seed)? + .build()?; + Ok(DataFrame { + session_state: self.session_state, + plan, + projection_requires_validation: self.projection_requires_validation, + }) + } } /// Macro for creating DataFrame. diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 8bf513a55a66..5fe594a3f3a0 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -79,7 +79,7 @@ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, - Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, + Filter, JoinType, RecursiveQuery, Sample, SkipType, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; @@ -93,6 +93,7 @@ use datafusion_physical_plan::execution_plan::InvariantLevel; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::recursive_query::RecursiveQueryExec; use datafusion_physical_plan::unnest::ListUnnest; +use datafusion_physical_plan::SampleExec; use sqlparser::ast::NullTreatment; use async_trait::async_trait; @@ -902,6 +903,23 @@ impl DefaultPhysicalPlanner { Arc::new(GlobalLimitExec::new(input, skip, fetch)) } + LogicalPlan::Sample(Sample { + lower_bound, + upper_bound, + seed, + with_replacement, + .. + }) => { + let input = children.one()?; + let sample = SampleExec::try_new( + input, + *lower_bound, + *upper_bound, + *with_replacement, + *seed, + )?; + Arc::new(sample) + } LogicalPlan::Unnest(Unnest { list_type_columns, struct_type_columns, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 8d60dbea3d01..a548675bb96f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -6137,3 +6137,73 @@ async fn test_dataframe_macro() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_dataframe_sample() -> Result<()> { + let df = dataframe!( + "a" => (1..40).collect::>(), + )?; + + // Test sampling 20% of rows with replacement + let df_sampled = df.clone().sample(0.2, Some(true), Some(42))?; + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+----+", + "| a |", + "+----+", + "| 8 |", + "| 10 |", + "| 19 |", + "| 29 |", + "| 29 |", + "| 36 |", + "+----+", + ], + &df_sampled.collect().await? + ); + + // Test sampling 20% of rows without replacement + let df_sampled = df.clone().sample(0.2, Some(false), Some(42))?; + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+----+", + "| a |", + "+----+", + "| 5 |", + "| 9 |", + "| 10 |", + "| 14 |", + "| 17 |", + "| 19 |", + "| 24 |", + "| 39 |", + "+----+", + ], + &df_sampled.collect().await? + ); + + // Test sampling with None parameters (should use defaults) + let df_sampled_default = df.clone().sample(0.2, None, Some(42))?; + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+----+", + "| a |", + "+----+", + "| 5 |", + "| 9 |", + "| 10 |", + "| 14 |", + "| 17 |", + "| 19 |", + "| 24 |", + "| 39 |", + "+----+", + ], + &df_sampled_default.collect().await? + ); + + Ok(()) +} diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 812544587bf9..ef25d2d4843f 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -52,6 +52,7 @@ datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } indexmap = { workspace = true } paste = "^1.0" +rand = { workspace = true } recursive = { workspace = true, optional = true } serde_json = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 93dd6c2b89fc..2e02216ae526 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -30,6 +30,7 @@ use crate::expr_rewriter::{ normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, rewrite_sort_cols_by_aggs, }; +use crate::logical_plan::plan::Sample; use crate::logical_plan::{ Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, @@ -1474,6 +1475,21 @@ impl LogicalPlanBuilder { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) } + + pub fn sample( + self, + fraction: f64, + with_replacement: Option, + seed: Option, + ) -> Result { + Ok(Self::new(LogicalPlan::Sample(Sample { + input: self.plan, + lower_bound: 0.0, + upper_bound: fraction, + with_replacement: with_replacement.unwrap_or(false), + seed: seed.unwrap_or_else(rand::random), + }))) + } } impl From for LogicalPlanBuilder { diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index f1e455f46db3..055aa338b893 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -20,6 +20,7 @@ use std::collections::HashMap; use std::fmt; +use crate::logical_plan::plan::Sample; use crate::{ expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr, Filter, Join, Limit, LogicalPlan, Partitioning, Projection, RecursiveQuery, @@ -650,6 +651,21 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "StructColumn": expr_vec_fmt!(struct_type_columns), }) } + LogicalPlan::Sample(Sample { + input: _, + lower_bound, + upper_bound, + with_replacement, + seed, + }) => { + json!({ + "Node Type": "Sample", + "Lower Bound": lower_bound, + "Upper Bound": upper_bound, + "With Replacement": with_replacement, + "Seed": seed, + }) + } } } } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a55f4d97b212..4f1b08b5863e 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -40,8 +40,8 @@ pub use plan::{ projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, ExplainFormat, Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, - Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + Projection, RecursiveQuery, Repartition, Sample, SkipType, Sort, StringifiedPlan, + Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ Deallocate, Execute, Prepare, SetVariable, Statement, TransactionAccessMode, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 876c14f1000f..43c993e26e89 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -287,6 +287,8 @@ pub enum LogicalPlan { Unnest(Unnest), /// A variadic query (e.g. "Recursive CTEs") RecursiveQuery(RecursiveQuery), + /// Sample the input table. This is used to implement SQL `SAMPLE` + Sample(Sample), } impl Default for LogicalPlan { @@ -347,6 +349,7 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, + LogicalPlan::Sample(Sample { input, .. }) => input.schema(), LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { // we take the schema of the static term as the schema of the entire recursive query static_term.schema() @@ -474,6 +477,7 @@ impl LogicalPlan { .. }) => vec![static_term, recursive_term], LogicalPlan::Statement(stmt) => stmt.inputs(), + LogicalPlan::Sample(Sample { input, .. }) => vec![input], // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } @@ -539,7 +543,8 @@ impl LogicalPlan { | LogicalPlan::Sort(Sort { input, .. }) | LogicalPlan::Limit(Limit { input, .. }) | LogicalPlan::Repartition(Repartition { input, .. }) - | LogicalPlan::Window(Window { input, .. }) => input.head_output_expr(), + | LogicalPlan::Window(Window { input, .. }) + | LogicalPlan::Sample(Sample { input, .. }) => input.head_output_expr(), LogicalPlan::Join(Join { left, right, @@ -739,6 +744,7 @@ impl LogicalPlan { LogicalPlan::EmptyRelation(_) => Ok(self), LogicalPlan::Statement(_) => Ok(self), LogicalPlan::DescribeTable(_) => Ok(self), + LogicalPlan::Sample(Sample { .. }) => Ok(self), LogicalPlan::Unnest(Unnest { input, exec_columns, @@ -892,6 +898,24 @@ impl LogicalPlan { fetch: *fetch, })) } + LogicalPlan::Sample(Sample { + with_replacement, + seed, + lower_bound, + upper_bound, + .. + }) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + + Ok(LogicalPlan::Sample(Sample { + input: Arc::new(input), + lower_bound: *lower_bound, + upper_bound: *upper_bound, + with_replacement: *with_replacement, + seed: *seed, + })) + } LogicalPlan::Join(Join { join_type, join_constraint, @@ -1362,6 +1386,7 @@ impl LogicalPlan { Ok(FetchType::Literal(s)) => s, _ => None, }, + LogicalPlan::Sample(Sample { input, .. }) => input.max_rows(), LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), ) => input.max_rows(), @@ -1952,6 +1977,9 @@ impl LogicalPlan { ) } }, + LogicalPlan::Sample(Sample { lower_bound, upper_bound, with_replacement, seed, .. }) => { + write!(f, "Sample: lower_bound={lower_bound}, upper_bound={upper_bound}, with_replacement={with_replacement}, seed={seed}") + } LogicalPlan::Limit(limit) => { // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. let skip_str = match limit.get_skip_type() { @@ -3995,6 +4023,39 @@ impl PartialOrd for Unnest { } } +/// Sample the input table. This is used to implement SQL `SAMPLE` +#[derive(Debug, Clone, PartialOrd)] +pub struct Sample { + /// The input table + pub input: Arc, + pub lower_bound: f64, + pub upper_bound: f64, + pub with_replacement: bool, + pub seed: u64, +} + +impl PartialEq for Sample { + fn eq(&self, other: &Self) -> bool { + self.input == other.input + && self.lower_bound.to_bits() == other.lower_bound.to_bits() + && self.upper_bound.to_bits() == other.upper_bound.to_bits() + && self.with_replacement == other.with_replacement + && self.seed == other.seed + } +} + +impl Eq for Sample {} + +impl Hash for Sample { + fn hash(&self, state: &mut H) { + self.input.hash(state); + self.lower_bound.to_bits().hash(state); + self.upper_bound.to_bits().hash(state); + self.with_replacement.hash(state); + self.seed.hash(state); + } +} + #[cfg(test)] mod tests { diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 527248ad39c2..cb33c620ccc2 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -37,6 +37,7 @@ //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions +use crate::logical_plan::plan::Sample; use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, @@ -148,6 +149,21 @@ impl TreeNode for LogicalPlan { LogicalPlan::Limit(Limit { skip, fetch, input }) => input .map_elements(f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), + LogicalPlan::Sample(Sample { + input, + lower_bound, + upper_bound, + with_replacement, + seed, + }) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::Sample(Sample { + input, + lower_bound, + upper_bound, + with_replacement, + seed, + }) + }), LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, @@ -471,6 +487,7 @@ impl LogicalPlan { | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) + | LogicalPlan::Sample(_) | LogicalPlan::DescribeTable(_) => Ok(TreeNodeRecursion::Continue), } } @@ -651,6 +668,7 @@ impl LogicalPlan { | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) + | LogicalPlan::Sample(_) | LogicalPlan::DescribeTable(_) => Transformed::no(self), }) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 6a49e5d22087..2679d7f67df6 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -554,6 +554,7 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Limit(_) + | LogicalPlan::Sample(_) | LogicalPlan::Ddl(_) | LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 33af52824a29..b2e48638259e 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -304,6 +304,16 @@ fn optimize_projections( .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) .collect::>()? } + LogicalPlan::Sample(_) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. + plan.inputs() + .into_iter() + .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) + .collect::>()? + } LogicalPlan::Copy(_) | LogicalPlan::Ddl(_) | LogicalPlan::Dml(_) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 095ee78cd0d6..f20e6212ac06 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -64,6 +64,8 @@ itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } parking_lot = { workspace = true } pin-project-lite = "^0.2.7" +rand = { workspace = true } +rand_distr = { workspace = true } tokio = { workspace = true } [dev-dependencies] @@ -71,7 +73,6 @@ criterion = { workspace = true, features = ["async_futures"] } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window = { workspace = true } insta = { workspace = true } -rand = { workspace = true } rstest = { workspace = true } rstest_reuse = "0.7.0" tempfile = "3.19.1" diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 4d3adebb91c6..c0796e486ce3 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -47,6 +47,7 @@ pub use crate::execution_plan::{ }; pub use crate::metrics::Metric; pub use crate::ordering::InputOrderMode; +pub use crate::sample::SampleExec; pub use crate::stream::EmptyRecordBatchStream; pub use crate::topk::TopK; pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; @@ -79,6 +80,7 @@ pub mod placeholder_row; pub mod projection; pub mod recursive_query; pub mod repartition; +pub mod sample; pub mod sorts; pub mod spill; pub mod stream; diff --git a/datafusion/physical-plan/src/sample.rs b/datafusion/physical-plan/src/sample.rs new file mode 100644 index 000000000000..e9c63a957407 --- /dev/null +++ b/datafusion/physical-plan/src/sample.rs @@ -0,0 +1,495 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the SAMPLE operator + +use rand_distr::{Distribution, Poisson}; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use super::{ + DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, Statistics, +}; +use crate::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + DisplayFormatType, ExecutionPlan, +}; + +use arrow::array::UInt32Array; +use arrow::compute; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{internal_err, plan_datafusion_err, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::EquivalenceProperties; + +use futures::stream::{Stream, StreamExt}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +trait Sampler: Send + Sync { + fn sample(&mut self, batch: &RecordBatch) -> Result; +} + +struct BernoulliSampler { + lower_bound: f64, + upper_bound: f64, + rng: StdRng, +} + +impl BernoulliSampler { + fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self { + Self { + lower_bound, + upper_bound, + rng: StdRng::seed_from_u64(seed), + } + } +} + +impl Sampler for BernoulliSampler { + fn sample(&mut self, batch: &RecordBatch) -> Result { + if self.upper_bound <= self.lower_bound { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let mut indices = Vec::new(); + + for i in 0..batch.num_rows() { + let rnd: f64 = self.rng.random(); + + if rnd >= self.lower_bound && rnd < self.upper_bound { + indices.push(i as u32); + } + } + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + let indices = UInt32Array::from(indices); + compute::take_record_batch(batch, &indices).map_err(|e| e.into()) + } +} + +struct PoissonSampler { + ratio: f64, + poisson: Poisson, + rng: StdRng, +} + +impl PoissonSampler { + fn try_new(ratio: f64, seed: u64) -> Result { + let poisson = Poisson::new(ratio).map_err(|e| plan_datafusion_err!("{}", e))?; + Ok(Self { + ratio, + poisson, + rng: StdRng::seed_from_u64(seed), + }) + } +} + +impl Sampler for PoissonSampler { + fn sample(&mut self, batch: &RecordBatch) -> Result { + if self.ratio <= 0.0 { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let mut indices = Vec::new(); + + for i in 0..batch.num_rows() { + let k = self.poisson.sample(&mut self.rng) as i32; + for _ in 0..k { + indices.push(i as u32); + } + } + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let indices = UInt32Array::from(indices); + compute::take_record_batch(batch, &indices).map_err(|e| e.into()) + } +} + +/// SampleExec samples rows from its input based on a sampling method. +/// This is used to implement SQL `SAMPLE` clause. +#[derive(Debug, Clone)] +pub struct SampleExec { + /// The input plan + input: Arc, + /// The lower bound of the sampling ratio + lower_bound: f64, + /// The upper bound of the sampling ratio + upper_bound: f64, + /// Whether to sample with replacement + with_replacement: bool, + /// Random seed for reproducible sampling + seed: u64, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Properties equivalence properties, partitioning, etc. + cache: PlanProperties, +} + +impl SampleExec { + /// Create a new SampleExec with a custom sampling method + pub fn try_new( + input: Arc, + lower_bound: f64, + upper_bound: f64, + with_replacement: bool, + seed: u64, + ) -> Result { + if lower_bound < 0.0 || upper_bound > 1.0 || lower_bound > upper_bound { + return internal_err!( + "Sampling bounds must be between 0.0 and 1.0, and lower_bound <= upper_bound, got [{}, {}]", + lower_bound, upper_bound + ); + } + + let cache = Self::compute_properties(&input); + + Ok(Self { + input, + lower_bound, + upper_bound, + with_replacement, + seed, + metrics: ExecutionPlanMetricsSet::new(), + cache, + }) + } + + fn create_sampler(&self, partition: usize) -> Result> { + if self.with_replacement { + Ok(Box::new(PoissonSampler::try_new( + self.upper_bound - self.lower_bound, + self.seed + partition as u64, + )?)) + } else { + Ok(Box::new(BernoulliSampler::new( + self.lower_bound, + self.upper_bound, + self.seed + partition as u64, + ))) + } + } + + /// Whether to sample with replacement + pub fn with_replacement(&self) -> bool { + self.with_replacement + } + + /// The lower bound of the sampling ratio + pub fn lower_bound(&self) -> f64 { + self.lower_bound + } + + /// The upper bound of the sampling ratio + pub fn upper_bound(&self) -> f64 { + self.upper_bound + } + + /// The random seed + pub fn seed(&self) -> u64 { + self.seed + } + + /// The input plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties(input: &Arc) -> PlanProperties { + PlanProperties::new( + EquivalenceProperties::new(input.schema()), + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + ) + } +} + +impl DisplayAs for SampleExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}", + self.lower_bound, self.upper_bound, self.with_replacement, self.seed + ) + } + DisplayFormatType::TreeRender => { + write!( + f, + "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}", + self.lower_bound, self.upper_bound, self.with_replacement, self.seed + ) + } + } + } +} + +impl ExecutionPlan for SampleExec { + fn name(&self) -> &'static str { + "SampleExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn maintains_input_order(&self) -> Vec { + vec![false] // Sampling does not maintain input order + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + // Get the ratio from the current method + Ok(Arc::new(SampleExec::try_new( + Arc::clone(&children[0]), + self.lower_bound, + self.upper_bound, + self.with_replacement, + self.seed, + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input_stream = self.input.execute(partition, context)?; + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + Ok(Box::pin(SampleExecStream { + input: input_stream, + sampler: self.create_sampler(partition)?, + baseline_metrics, + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stats = self.input.partition_statistics(partition)?; + + // Apply sampling ratio to statistics + let mut stats = input_stats; + let ratio = self.upper_bound - self.lower_bound; + + stats.num_rows = stats + .num_rows + .map(|nr| (nr as f64 * ratio) as usize) + .to_inexact(); + stats.total_byte_size = stats + .total_byte_size + .map(|tb| (tb as f64 * ratio) as usize) + .to_inexact(); + + Ok(stats) + } +} + +/// Stream for the SampleExec operator +struct SampleExecStream { + /// The input stream + input: SendableRecordBatchStream, + /// The sampling method + sampler: Box, + /// Runtime metrics recording + baseline_metrics: BaselineMetrics, +} + +impl Stream for SampleExecStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.input.poll_next_unpin(cx); + let baseline_metrics = &mut self.baseline_metrics; + + match poll { + Poll::Ready(Some(Ok(batch))) => { + let start = baseline_metrics.elapsed_compute().clone(); + let result = self.sampler.sample(&batch); + let _timer = start.timer(); + Poll::Ready(Some(result)) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for SampleExecStream { + fn schema(&self) -> SchemaRef { + self.input.schema() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int32Array; + use arrow::datatypes::{Field, Schema}; + use datafusion_common::assert_batches_eq; + use datafusion_execution::TaskContext; + use futures::TryStreamExt; + use std::sync::Arc; + + #[tokio::test] + async fn test_sample_exec_bernoulli() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "id", + arrow::datatypes::DataType::Int32, + false, + )])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + )?; + + let input = Arc::new(crate::test::TestMemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?); + + let sample_exec = SampleExec::try_new(input, 0.6, 1.0, false, 42)?; + + let context = Arc::new(TaskContext::default()); + let stream = sample_exec.execute(0, context)?; + + let batches = stream.try_collect::>().await?; + assert_batches_eq!( + &["+----+", "| id |", "+----+", "| 3 |", "+----+",], + &batches + ); + + Ok(()) + } + + #[tokio::test] + async fn test_sample_exec_poisson() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "id", + arrow::datatypes::DataType::Int32, + false, + )])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + )?; + + let input = Arc::new(crate::test::TestMemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?); + + let sample_exec = SampleExec::try_new(input, 0.0, 0.5, true, 42)?; + + let context = Arc::new(TaskContext::default()); + let stream = sample_exec.execute(0, context)?; + + let batches = stream.try_collect::>().await?; + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+----+", + "| id |", + "+----+", + "| 3 |", + "+----+", + ], + &batches + ); + + Ok(()) + } + + #[test] + fn test_sampler_trait() { + let mut bernoulli_sampler = BernoulliSampler::new(0.0, 0.5, 42); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "id", + arrow::datatypes::DataType::Int32, + false, + )])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + ) + .unwrap(); + + let result = bernoulli_sampler.sample(&batch).unwrap(); + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+----+", + "| id |", + "+----+", + "| 4 |", + "| 5 |", + "+----+", + ], + &[result] + ); + + let mut poisson_sampler = PoissonSampler::try_new(0.5, 42).unwrap(); + let result = poisson_sampler.sample(&batch).unwrap(); + assert_batches_eq!( + #[rustfmt::skip] + &[ + "+----+", + "| id |", + "+----+", + "| 3 |", + "+----+", + ], + &[result] + ); + } +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 64789f5de0d2..c60bca32cbe8 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -62,6 +62,7 @@ message LogicalPlanNode { RecursiveQueryNode recursive_query = 31; CteWorkTableScanNode cte_work_table_scan = 32; DmlNode dml = 33; + SampleNode sample = 34; } } @@ -727,6 +728,7 @@ message PhysicalPlanNode { UnnestExecNode unnest = 30; JsonScanExecNode json_scan = 31; CooperativeExecNode cooperative = 32; + SampleExecNode sample = 33; } } @@ -1291,3 +1293,19 @@ message CteWorkTableScanNode { string name = 1; datafusion_common.Schema schema = 2; } + +message SampleNode { + LogicalPlanNode input = 1; + double lower_bound = 2; + double upper_bound = 3; + bool with_replacement = 4; + uint64 seed = 5; +} + +message SampleExecNode { + PhysicalPlanNode input = 1; + double lower_bound = 2; + double upper_bound = 3; + bool with_replacement = 4; + uint64 seed = 5; +} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 92309ea6a5cb..2442233e162e 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -11153,6 +11153,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::Dml(v) => { struct_ser.serialize_field("dml", v)?; } + logical_plan_node::LogicalPlanType::Sample(v) => { + struct_ser.serialize_field("sample", v)?; + } } } struct_ser.end() @@ -11212,6 +11215,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "cte_work_table_scan", "cteWorkTableScan", "dml", + "sample", ]; #[allow(clippy::enum_variant_names)] @@ -11248,6 +11252,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { RecursiveQuery, CteWorkTableScan, Dml, + Sample, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11301,6 +11306,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "recursiveQuery" | "recursive_query" => Ok(GeneratedField::RecursiveQuery), "cteWorkTableScan" | "cte_work_table_scan" => Ok(GeneratedField::CteWorkTableScan), "dml" => Ok(GeneratedField::Dml), + "sample" => Ok(GeneratedField::Sample), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11545,6 +11551,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("dml")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Dml) +; + } + GeneratedField::Sample => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sample")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Sample) ; } } @@ -15894,6 +15907,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::Cooperative(v) => { struct_ser.serialize_field("cooperative", v)?; } + physical_plan_node::PhysicalPlanType::Sample(v) => { + struct_ser.serialize_field("sample", v)?; + } } } struct_ser.end() @@ -15953,6 +15969,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "json_scan", "jsonScan", "cooperative", + "sample", ]; #[allow(clippy::enum_variant_names)] @@ -15988,6 +16005,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { Unnest, JsonScan, Cooperative, + Sample, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16040,6 +16058,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "unnest" => Ok(GeneratedField::Unnest), "jsonScan" | "json_scan" => Ok(GeneratedField::JsonScan), "cooperative" => Ok(GeneratedField::Cooperative), + "sample" => Ok(GeneratedField::Sample), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16277,6 +16296,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("cooperative")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Cooperative) +; + } + GeneratedField::Sample => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sample")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Sample) ; } } @@ -18635,6 +18661,346 @@ impl<'de> serde::Deserialize<'de> for RollupNode { deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SampleExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.lower_bound != 0. { + len += 1; + } + if self.upper_bound != 0. { + len += 1; + } + if self.with_replacement { + len += 1; + } + if self.seed != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SampleExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.lower_bound != 0. { + struct_ser.serialize_field("lowerBound", &self.lower_bound)?; + } + if self.upper_bound != 0. { + struct_ser.serialize_field("upperBound", &self.upper_bound)?; + } + if self.with_replacement { + struct_ser.serialize_field("withReplacement", &self.with_replacement)?; + } + if self.seed != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("seed", ToString::to_string(&self.seed).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SampleExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "lower_bound", + "lowerBound", + "upper_bound", + "upperBound", + "with_replacement", + "withReplacement", + "seed", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + LowerBound, + UpperBound, + WithReplacement, + Seed, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "lowerBound" | "lower_bound" => Ok(GeneratedField::LowerBound), + "upperBound" | "upper_bound" => Ok(GeneratedField::UpperBound), + "withReplacement" | "with_replacement" => Ok(GeneratedField::WithReplacement), + "seed" => Ok(GeneratedField::Seed), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SampleExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SampleExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut lower_bound__ = None; + let mut upper_bound__ = None; + let mut with_replacement__ = None; + let mut seed__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::LowerBound => { + if lower_bound__.is_some() { + return Err(serde::de::Error::duplicate_field("lowerBound")); + } + lower_bound__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::UpperBound => { + if upper_bound__.is_some() { + return Err(serde::de::Error::duplicate_field("upperBound")); + } + upper_bound__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WithReplacement => { + if with_replacement__.is_some() { + return Err(serde::de::Error::duplicate_field("withReplacement")); + } + with_replacement__ = Some(map_.next_value()?); + } + GeneratedField::Seed => { + if seed__.is_some() { + return Err(serde::de::Error::duplicate_field("seed")); + } + seed__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(SampleExecNode { + input: input__, + lower_bound: lower_bound__.unwrap_or_default(), + upper_bound: upper_bound__.unwrap_or_default(), + with_replacement: with_replacement__.unwrap_or_default(), + seed: seed__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SampleExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SampleNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.lower_bound != 0. { + len += 1; + } + if self.upper_bound != 0. { + len += 1; + } + if self.with_replacement { + len += 1; + } + if self.seed != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SampleNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.lower_bound != 0. { + struct_ser.serialize_field("lowerBound", &self.lower_bound)?; + } + if self.upper_bound != 0. { + struct_ser.serialize_field("upperBound", &self.upper_bound)?; + } + if self.with_replacement { + struct_ser.serialize_field("withReplacement", &self.with_replacement)?; + } + if self.seed != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("seed", ToString::to_string(&self.seed).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SampleNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "lower_bound", + "lowerBound", + "upper_bound", + "upperBound", + "with_replacement", + "withReplacement", + "seed", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + LowerBound, + UpperBound, + WithReplacement, + Seed, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "lowerBound" | "lower_bound" => Ok(GeneratedField::LowerBound), + "upperBound" | "upper_bound" => Ok(GeneratedField::UpperBound), + "withReplacement" | "with_replacement" => Ok(GeneratedField::WithReplacement), + "seed" => Ok(GeneratedField::Seed), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SampleNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SampleNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut lower_bound__ = None; + let mut upper_bound__ = None; + let mut with_replacement__ = None; + let mut seed__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::LowerBound => { + if lower_bound__.is_some() { + return Err(serde::de::Error::duplicate_field("lowerBound")); + } + lower_bound__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::UpperBound => { + if upper_bound__.is_some() { + return Err(serde::de::Error::duplicate_field("upperBound")); + } + upper_bound__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WithReplacement => { + if with_replacement__.is_some() { + return Err(serde::de::Error::duplicate_field("withReplacement")); + } + with_replacement__ = Some(map_.next_value()?); + } + GeneratedField::Seed => { + if seed__.is_some() { + return Err(serde::de::Error::duplicate_field("seed")); + } + seed__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(SampleNode { + input: input__, + lower_bound: lower_bound__.unwrap_or_default(), + upper_bound: upper_bound__.unwrap_or_default(), + with_replacement: with_replacement__.unwrap_or_default(), + seed: seed__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SampleNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarUdfExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b0fc0ce60436..064f7a181686 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -5,7 +5,7 @@ pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34" )] pub logical_plan_type: ::core::option::Option, } @@ -77,6 +77,8 @@ pub mod logical_plan_node { CteWorkTableScan(super::CteWorkTableScanNode), #[prost(message, tag = "33")] Dml(::prost::alloc::boxed::Box), + #[prost(message, tag = "34")] + Sample(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1048,7 +1050,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33" )] pub physical_plan_type: ::core::option::Option, } @@ -1120,6 +1122,8 @@ pub mod physical_plan_node { JsonScan(super::JsonScanExecNode), #[prost(message, tag = "32")] Cooperative(::prost::alloc::boxed::Box), + #[prost(message, tag = "33")] + Sample(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1941,6 +1945,32 @@ pub struct CteWorkTableScanNode { #[prost(message, optional, tag = "2")] pub schema: ::core::option::Option, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SampleNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(double, tag = "2")] + pub lower_bound: f64, + #[prost(double, tag = "3")] + pub upper_bound: f64, + #[prost(bool, tag = "4")] + pub with_replacement: bool, + #[prost(uint64, tag = "5")] + pub seed: u64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SampleExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(double, tag = "2")] + pub lower_bound: f64, + #[prost(double, tag = "3")] + pub upper_bound: f64, + #[prost(bool, tag = "4")] + pub with_replacement: bool, + #[prost(uint64, tag = "5")] + pub seed: u64, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 1acf1ee27bfe..bacfe482bff7 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -71,8 +71,8 @@ use datafusion_expr::{ Statement, WindowUDF, }; use datafusion_expr::{ - AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, SkipType, - TableSource, Unnest, + AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, Sample, + SkipType, TableSource, Unnest, }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -994,6 +994,16 @@ impl AsLogicalPlan for LogicalPlanNode { Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?), ), )), + LogicalPlanType::Sample(sample) => { + let input = into_logical_plan!(sample.input, ctx, extension_codec)?; + Ok(LogicalPlan::Sample(Sample { + input: Arc::new(input), + lower_bound: sample.lower_bound, + upper_bound: sample.upper_bound, + with_replacement: sample.with_replacement, + seed: sample.seed, + })) + } } } @@ -1806,6 +1816,23 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::Sample(sample) => { + let input = LogicalPlanNode::try_from_logical_plan( + sample.input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Sample(Box::new( + protobuf::SampleNode { + input: Some(Box::new(input)), + lower_bound: sample.lower_bound, + upper_bound: sample.upper_bound, + with_replacement: sample.with_replacement, + seed: sample.seed, + }, + ))), + }) + } } } } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 242b36786d07..283de36102fc 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -84,7 +84,7 @@ use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, + ExecutionPlan, InputOrderMode, PhysicalExpr, SampleExec, WindowExpr, }; use datafusion_common::config::TableParquetOptions; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; @@ -331,6 +331,12 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { runtime, extension_codec, ), + PhysicalPlanType::Sample(sample) => self.try_into_sample_physical_plan( + sample, + registry, + runtime, + extension_codec, + ), } } @@ -527,6 +533,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ); } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_sample_exec( + exec, + extension_codec, + ); + } + let mut buf: Vec = vec![]; match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { @@ -556,6 +569,24 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } impl protobuf::PhysicalPlanNode { + fn try_into_sample_physical_plan( + &self, + sample: &protobuf::SampleExecNode, + registry: &dyn FunctionRegistry, + runtime: &RuntimeEnv, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input = + into_physical_plan(&sample.input, registry, runtime, extension_codec)?; + Ok(Arc::new(SampleExec::try_new( + input, + sample.lower_bound, + sample.upper_bound, + sample.with_replacement, + sample.seed, + )?)) + } + fn try_into_explain_physical_plan( &self, explain: &protobuf::ExplainExecNode, @@ -2805,6 +2836,27 @@ impl protobuf::PhysicalPlanNode { ))), }) } + + fn try_from_sample_exec( + exec: &SampleExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Sample(Box::new( + protobuf::SampleExecNode { + input: Some(Box::new(input)), + lower_bound: exec.lower_bound(), + upper_bound: exec.upper_bound(), + with_replacement: exec.with_replacement(), + seed: exec.seed(), + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 1c5a8ff4d252..6d5a491864de 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::str::FromStr; use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -26,7 +27,10 @@ use datafusion_common::{ use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::{Subquery, SubqueryAlias}; -use sqlparser::ast::{FunctionArg, FunctionArgExpr, Spanned, TableFactor}; +use sqlparser::ast::{ + FunctionArg, FunctionArgExpr, Spanned, TableFactor, TableSampleKind, + TableSampleMethod, TableSampleUnit, +}; mod join; @@ -40,7 +44,11 @@ impl SqlToRel<'_, S> { let relation_span = relation.span(); let (plan, alias) = match relation { TableFactor::Table { - name, alias, args, .. + name, + alias, + args, + sample, + .. } => { if let Some(func_args) = args { let tbl_func_name = @@ -64,7 +72,7 @@ impl SqlToRel<'_, S> { let provider = self .context_provider .get_table_function_source(&tbl_func_name, args)?; - let plan = LogicalPlanBuilder::scan( + let mut plan = LogicalPlanBuilder::scan( TableReference::Bare { table: format!("{tbl_func_name}()").into(), }, @@ -72,34 +80,36 @@ impl SqlToRel<'_, S> { None, )? .build()?; + if let Some(sample) = sample { + plan = self.table_sample(sample, plan, planner_context)?; + } (plan, alias) } else { // Normalize name and alias let table_ref = self.object_name_to_table_reference(name)?; let table_name = table_ref.to_string(); let cte = planner_context.get_cte(&table_name); - ( - match ( - cte, - self.context_provider.get_table_source(table_ref.clone()), - ) { - (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Ok(provider)) => LogicalPlanBuilder::scan( - table_ref.clone(), - provider, - None, - )? - .build(), - (None, Err(e)) => { - let e = e.with_diagnostic(Diagnostic::new_error( - format!("table '{table_ref}' not found"), - Span::try_from_sqlparser_span(relation_span), - )); - Err(e) - } - }?, - alias, - ) + let mut plan = match ( + cte, + self.context_provider.get_table_source(table_ref.clone()), + ) { + (Some(cte_plan), _) => Ok(cte_plan.clone()), + (_, Ok(provider)) => { + LogicalPlanBuilder::scan(table_ref.clone(), provider, None)? + .build() + } + (None, Err(e)) => { + let e = e.with_diagnostic(Diagnostic::new_error( + format!("table '{table_ref}' not found"), + Span::try_from_sqlparser_span(relation_span), + )); + Err(e) + } + }?; + if let Some(sample) = sample { + plan = self.table_sample(sample, plan, planner_context)?; + } + (plan, alias) } } TableFactor::Derived { @@ -224,6 +234,124 @@ impl SqlToRel<'_, S> { })), } } + + fn table_sample( + &self, + sample: TableSampleKind, + input: LogicalPlan, + _planner_context: &mut PlannerContext, + ) -> Result { + let sample = match sample { + TableSampleKind::BeforeTableAlias(sample) => sample, + TableSampleKind::AfterTableAlias(sample) => sample, + }; + if let Some(name) = &sample.name { + if *name != TableSampleMethod::Bernoulli && *name != TableSampleMethod::Row { + // Postgres-style sample. Not supported because DataFusion does not have a concept of pages like PostgreSQL. + return not_impl_err!("{} is not supported yet", name); + } + } + if sample.offset.is_some() { + // Clickhouse-style sample. Not supported because it requires knowing the total data size. + return not_impl_err!("Offset sample is not supported yet"); + } + + let seed = sample + .seed + .map(|seed| { + let Ok(seed) = seed.value.to_string().parse::() else { + return plan_err!("seed must be a number: {}", seed.value); + }; + Ok(seed) + }) + .transpose()?; + + if let Some(bucket) = sample.bucket { + if bucket.on.is_some() { + // Hive-style sample, only used when the Hive table is defined with CLUSTERED BY + return not_impl_err!("Bucket sample with ON is not supported yet"); + } + + let Ok(bucket_num) = bucket.bucket.to_string().parse::() else { + return plan_err!("bucket must be a number"); + }; + + let Ok(total_num) = bucket.total.to_string().parse::() else { + return plan_err!("total must be a number"); + }; + let logical_plan = LogicalPlanBuilder::from(input) + .sample(bucket_num as f64 / total_num as f64, None, seed)? + .build()?; + return Ok(logical_plan); + } + if let Some(quantity) = sample.quantity { + match quantity.unit { + Some(TableSampleUnit::Rows) => { + let value = evaluate_number::(&quantity.value); + if value.is_none() { + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); + } + let value = value.unwrap(); + if value < 0 { + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); + } + let logical_plan = LogicalPlanBuilder::from(input) + .limit(0, Some(value as usize))? + .build()?; + return Ok(logical_plan); + } + Some(TableSampleUnit::Percent) => { + let value = evaluate_number::(&quantity.value); + if value.is_none() { + return plan_err!( + "quantity must be a number: {:?}", + quantity.value + ); + } + let value = value.unwrap() / 100.0; + let logical_plan = LogicalPlanBuilder::from(input) + .sample(value, None, seed)? + .build()?; + return Ok(logical_plan); + } + None => { + // Clickhouse-style sample + let value = evaluate_number::(&quantity.value); + if value.is_none() { + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); + } + let value = value.unwrap(); + if value < 0.0 { + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); + } + if value >= 1.0 { + let logical_plan = LogicalPlanBuilder::from(input) + .limit(0, Some(value as usize))? + .build()?; + return Ok(logical_plan); + } else { + let logical_plan = LogicalPlanBuilder::from(input) + .sample(value, None, seed)? + .build()?; + return Ok(logical_plan); + } + } + } + } + Ok(input) + } } fn optimize_subquery_sort(plan: LogicalPlan) -> Result> { @@ -252,3 +380,41 @@ fn optimize_subquery_sort(plan: LogicalPlan) -> Result> }); new_plan } + +fn evaluate_number< + T: FromStr + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div, +>( + expr: &sqlparser::ast::Expr, +) -> Option { + match expr { + sqlparser::ast::Expr::BinaryOp { left, op, right } => { + let left = evaluate_number::(left); + let right = evaluate_number::(right); + match (left, right) { + (Some(left), Some(right)) => match op { + sqlparser::ast::BinaryOperator::Plus => Some(left + right), + sqlparser::ast::BinaryOperator::Minus => Some(left - right), + sqlparser::ast::BinaryOperator::Multiply => Some(left * right), + sqlparser::ast::BinaryOperator::Divide => Some(left / right), + _ => None, + }, + _ => None, + } + } + sqlparser::ast::Expr::Value(value) => match &value.value { + sqlparser::ast::Value::Number(value, _) => { + let value = value.to_string(); + let Ok(value) = value.parse::() else { + return None; + }; + Some(value) + } + _ => None, + }, + _ => None, + } +} diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index d9f9767ba9e4..10a66ad2746a 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -113,6 +113,7 @@ impl Unparser<'_> { | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) + | LogicalPlan::Sample(_) | LogicalPlan::Distinct(_) => self.select_to_sql_statement(&plan), LogicalPlan::Dml(_) => self.dml_to_sql(&plan), LogicalPlan::Extension(extension) => { diff --git a/datafusion/sqllogictest/test_files/sample.slt b/datafusion/sqllogictest/test_files/sample.slt new file mode 100644 index 000000000000..89741d4124f5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/sample.slt @@ -0,0 +1,74 @@ +# SAMPLE function tests with REPEATABLE seed for stability + +# Test basic SAMPLE with REPEATABLE seed +statement ok +CREATE TABLE sample_test (id INT) AS SELECT * FROM generate_series(1, 40); + +query TT +EXPLAIN SELECT * FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42); +---- +logical_plan +01)Sample: lower_bound=0, upper_bound=0.5, with_replacement=false, seed=42 +02)--TableScan: sample_test projection=[id] +physical_plan +01)SampleExec: lower_bound=0, upper_bound=0.5, with_replacement=false, seed=42 +02)--DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0] + +query TT +EXPLAIN SELECT * FROM sample_test SAMPLE (10 ROWS); +---- +logical_plan +01)Limit: skip=0, fetch=10 +02)--TableScan: sample_test projection=[id], fetch=10 +physical_plan +01)CoalescePartitionsExec: fetch=10 +02)--DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0], fetch=10 + +query TT +EXPLAIN SELECT * FROM sample_test SAMPLE 0.5 REPEATABLE (42); +---- +logical_plan +01)Sample: lower_bound=0, upper_bound=0.5, with_replacement=false, seed=42 +02)--TableScan: sample_test projection=[id] +physical_plan +01)SampleExec: lower_bound=0, upper_bound=0.5, with_replacement=false, seed=42 +02)--DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0] + +query TT +EXPLAIN SELECT * FROM sample_test SAMPLE 10; +---- +logical_plan +01)Limit: skip=0, fetch=10 +02)--TableScan: sample_test projection=[id], fetch=10 +physical_plan +01)CoalescePartitionsExec: fetch=10 +02)--DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0], fetch=10 + + +# Test SAMPLE with 20% ratio and REPEATABLE seed +query I +SELECT * FROM sample_test SAMPLE (20 PERCENT) REPEATABLE (42); +---- +5 +9 +10 +14 +17 +19 +24 +39 + +query I +SELECT COUNT(DISTINCT id) FROM (SELECT id FROM sample_test SAMPLE (100 PERCENT)); +---- +40 + +# Test SAMPLE with 0% ratio and REPEATABLE seed (should return no rows) +query I +SELECT COUNT(DISTINCT id) FROM (SELECT id FROM sample_test SAMPLE (0 PERCENT)); +---- +0 + +# Clean up +statement ok +DROP TABLE sample_test diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs index c3599a2635ff..8ac85143290a 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -74,5 +74,6 @@ pub fn to_substrait_rel( LogicalPlan::RecursiveQuery(plan) => { not_impl_err!("Unsupported plan type: {plan:?}")? } + LogicalPlan::Sample(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, } }