From 8ceba5986f2a0f83df22f827fcd5c917761b2b41 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 23 Jun 2025 11:34:29 +0800 Subject: [PATCH 01/10] feat: table sample --- Cargo.lock | 2 + Cargo.toml | 1 + datafusion/core/Cargo.toml | 2 +- datafusion/core/src/dataframe/mod.rs | 31 ++ datafusion/core/src/physical_planner.rs | 14 +- datafusion/expr/Cargo.toml | 1 + datafusion/expr/src/logical_plan/builder.rs | 14 +- datafusion/expr/src/logical_plan/display.rs | 16 + datafusion/expr/src/logical_plan/mod.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 58 ++- datafusion/expr/src/logical_plan/tree_node.rs | 6 + .../optimizer/src/common_subexpr_eliminate.rs | 1 + .../optimizer/src/optimize_projections/mod.rs | 10 + datafusion/physical-plan/Cargo.toml | 3 +- datafusion/physical-plan/src/lib.rs | 2 + datafusion/physical-plan/src/sample.rs | 439 ++++++++++++++++++ datafusion/proto/proto/datafusion.proto | 18 + datafusion/proto/src/generated/pbjson.rs | 366 +++++++++++++++ datafusion/proto/src/generated/prost.rs | 34 +- datafusion/proto/src/logical_plan/mod.rs | 29 +- datafusion/proto/src/physical_plan/mod.rs | 45 +- datafusion/sql/src/relation/mod.rs | 184 ++++++-- datafusion/sql/src/unparser/plan.rs | 1 + datafusion/sqllogictest/test_files/sample.slt | 226 +++++++++ .../src/logical_plan/producer/rel/mod.rs | 1 + 25 files changed, 1468 insertions(+), 38 deletions(-) create mode 100644 datafusion/physical-plan/src/sample.rs create mode 100644 datafusion/sqllogictest/test_files/sample.slt diff --git a/Cargo.lock b/Cargo.lock index e2e593e62528..929e44ae2673 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2226,6 +2226,7 @@ dependencies = [ "indexmap 2.9.0", "insta", "paste", + "rand 0.9.1", "recursive", "serde_json", "sqlparser", @@ -2507,6 +2508,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "rand 0.9.1", + "rand_distr", "rstest", "rstest_reuse", "tempfile", diff --git a/Cargo.toml b/Cargo.toml index f2cd6f72c7e6..e3bae81878a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -163,6 +163,7 @@ pbjson-types = "0.7" insta = { version = "1.43.1", features = ["glob", "filters"] } prost = "0.13.1" rand = "0.9" +rand_distr = "0.5" recursive = "0.1.1" regex = "1.8" rstest = "0.25.0" diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 9747f4424060..5241dc253fd7 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -157,7 +157,7 @@ env_logger = { workspace = true } insta = { workspace = true } paste = "^1.0" rand = { workspace = true, features = ["small_rng"] } -rand_distr = "0.5" +rand_distr = { workspace = true } regex = { workspace = true } rstest = { workspace = true } serde_json = { workspace = true } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 02a18f22c916..2fd6de703f06 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2262,6 +2262,37 @@ impl DataFrame { let df = ctx.read_batch(batch)?; Ok(df) } + + /// Sample a fraction of the rows from the DataFrame. + /// + /// # Arguments + /// * `fraction` - The fraction of rows to sample. + /// * `with_replacement` - Whether to sample with replacement. + /// * `seed` - The seed for the random number generator. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let df = dataframe!( + /// "id" => [1, 2, 3], + /// "name" => ["foo", "bar", "baz"] + /// )?; + /// let df = df.sample(0.5, Some(true), Some(42))?; + /// df.show().await?; + /// # Ok(()) + /// ``` + /// + pub fn sample(self, fraction: f64, with_replacement: Option, seed: Option) -> Result { + let plan = LogicalPlanBuilder::from(self.plan).sample(fraction, with_replacement, seed)?.build()?; + Ok(DataFrame { + session_state: self.session_state, + plan, + projection_requires_validation: self.projection_requires_validation, + }) + } } /// Macro for creating DataFrame. diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 14188f6bf0c9..f2863f8d1a50 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -79,7 +79,7 @@ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, - Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, + Filter, JoinType, RecursiveQuery, Sample, SkipType, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; @@ -93,6 +93,7 @@ use datafusion_physical_plan::execution_plan::InvariantLevel; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::recursive_query::RecursiveQueryExec; use datafusion_physical_plan::unnest::ListUnnest; +use datafusion_physical_plan::SampleExec; use sqlparser::ast::NullTreatment; use async_trait::async_trait; @@ -869,6 +870,17 @@ impl DefaultPhysicalPlanner { Arc::new(GlobalLimitExec::new(input, skip, fetch)) } + LogicalPlan::Sample(Sample { + lower_bound, + upper_bound, + seed, + with_replacement, + .. + }) => { + let input = children.one()?; + let sample = SampleExec::try_new(input, *lower_bound, *upper_bound, *with_replacement, *seed)?; + Arc::new(sample) + } LogicalPlan::Unnest(Unnest { list_type_columns, struct_type_columns, diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index d77c59ff64e1..a7601a579f00 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -51,6 +51,7 @@ datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } indexmap = { workspace = true } paste = "^1.0" +rand = { workspace = true } recursive = { workspace = true, optional = true } serde_json = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 93dd6c2b89fc..1eb313a2bcee 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -30,6 +30,7 @@ use crate::expr_rewriter::{ normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, rewrite_sort_cols_by_aggs, }; +use crate::logical_plan::plan::Sample; use crate::logical_plan::{ Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, @@ -43,8 +44,7 @@ use crate::utils::{ group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, - Statement, TableProviderFilterPushDown, TableSource, WriteOp, + and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, WriteOp }; use super::dml::InsertOp; @@ -1474,6 +1474,16 @@ impl LogicalPlanBuilder { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) } + + pub fn sample(self, fraction: f64, with_replacement: Option, seed: Option) -> Result { + Ok(Self::new(LogicalPlan::Sample(Sample { + input: self.plan, + lower_bound: 0.0, + upper_bound: fraction, + with_replacement: with_replacement.unwrap_or(false), + seed: seed.unwrap_or_else(|| rand::random()), + }))) + } } impl From for LogicalPlanBuilder { diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index f1e455f46db3..055aa338b893 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -20,6 +20,7 @@ use std::collections::HashMap; use std::fmt; +use crate::logical_plan::plan::Sample; use crate::{ expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr, Filter, Join, Limit, LogicalPlan, Partitioning, Projection, RecursiveQuery, @@ -650,6 +651,21 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "StructColumn": expr_vec_fmt!(struct_type_columns), }) } + LogicalPlan::Sample(Sample { + input: _, + lower_bound, + upper_bound, + with_replacement, + seed, + }) => { + json!({ + "Node Type": "Sample", + "Lower Bound": lower_bound, + "Upper Bound": upper_bound, + "With Replacement": with_replacement, + "Seed": seed, + }) + } } } } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a55f4d97b212..fa276293fa79 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -41,7 +41,7 @@ pub use plan::{ DistinctOn, EmptyRelation, Explain, ExplainFormat, Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, Sample, }; pub use statement::{ Deallocate, Execute, Prepare, SetVariable, Statement, TransactionAccessMode, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 876c14f1000f..4fa6bfae3962 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -287,6 +287,8 @@ pub enum LogicalPlan { Unnest(Unnest), /// A variadic query (e.g. "Recursive CTEs") RecursiveQuery(RecursiveQuery), + /// Sample the input table. This is used to implement SQL `SAMPLE` + Sample(Sample), } impl Default for LogicalPlan { @@ -347,6 +349,7 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, + LogicalPlan::Sample(Sample { input, .. }) => input.schema(), LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { // we take the schema of the static term as the schema of the entire recursive query static_term.schema() @@ -474,6 +477,7 @@ impl LogicalPlan { .. }) => vec![static_term, recursive_term], LogicalPlan::Statement(stmt) => stmt.inputs(), + LogicalPlan::Sample(Sample { input, .. }) => vec![input], // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } @@ -539,7 +543,8 @@ impl LogicalPlan { | LogicalPlan::Sort(Sort { input, .. }) | LogicalPlan::Limit(Limit { input, .. }) | LogicalPlan::Repartition(Repartition { input, .. }) - | LogicalPlan::Window(Window { input, .. }) => input.head_output_expr(), + | LogicalPlan::Window(Window { input, .. }) + | LogicalPlan::Sample(Sample { input, .. }) => input.head_output_expr(), LogicalPlan::Join(Join { left, right, @@ -739,6 +744,7 @@ impl LogicalPlan { LogicalPlan::EmptyRelation(_) => Ok(self), LogicalPlan::Statement(_) => Ok(self), LogicalPlan::DescribeTable(_) => Ok(self), + LogicalPlan::Sample(Sample {..}) => Ok(self), LogicalPlan::Unnest(Unnest { input, exec_columns, @@ -892,6 +898,18 @@ impl LogicalPlan { fetch: *fetch, })) } + LogicalPlan::Sample(Sample { with_replacement, seed, lower_bound, upper_bound, .. }) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + + Ok(LogicalPlan::Sample(Sample { + input: Arc::new(input), + lower_bound: *lower_bound, + upper_bound: *upper_bound, + with_replacement: *with_replacement, + seed: *seed, + })) + } LogicalPlan::Join(Join { join_type, join_constraint, @@ -1362,6 +1380,7 @@ impl LogicalPlan { Ok(FetchType::Literal(s)) => s, _ => None, }, + LogicalPlan::Sample(Sample { input, .. }) => input.max_rows(), LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), ) => input.max_rows(), @@ -1952,6 +1971,9 @@ impl LogicalPlan { ) } }, + LogicalPlan::Sample(Sample { lower_bound, upper_bound, with_replacement, seed, .. }) => { + write!(f, "Sample: lower_bound={lower_bound}, upper_bound={upper_bound}, with_replacement={with_replacement}, seed={seed}") + } LogicalPlan::Limit(limit) => { // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. let skip_str = match limit.get_skip_type() { @@ -3995,6 +4017,40 @@ impl PartialOrd for Unnest { } } +/// Sample the input table. This is used to implement SQL `SAMPLE` +#[derive(Debug, Clone, PartialOrd)] +pub struct Sample { + /// The input table + pub input: Arc, + pub lower_bound: f64, + pub upper_bound: f64, + pub with_replacement: bool, + pub seed: u64, +} + +impl PartialEq for Sample { + fn eq(&self, other: &Self) -> bool { + self.input == other.input + && self.lower_bound.to_bits() == other.lower_bound.to_bits() + && self.upper_bound.to_bits() == other.upper_bound.to_bits() + && self.with_replacement == other.with_replacement + && self.seed == other.seed + } +} + +impl Eq for Sample {} + +impl Hash for Sample { + fn hash(&self, state: &mut H) { + self.input.hash(state); + self.lower_bound.to_bits().hash(state); + self.upper_bound.to_bits().hash(state); + self.with_replacement.hash(state); + self.seed.hash(state); + } +} + + #[cfg(test)] mod tests { diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 527248ad39c2..b6f4708bc906 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -37,6 +37,7 @@ //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions +use crate::logical_plan::plan::Sample; use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, @@ -148,6 +149,9 @@ impl TreeNode for LogicalPlan { LogicalPlan::Limit(Limit { skip, fetch, input }) => input .map_elements(f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), + LogicalPlan::Sample(Sample { input, lower_bound, upper_bound, with_replacement, seed }) => input + .map_elements(f)? + .update_data(|input| LogicalPlan::Sample(Sample { input, lower_bound, upper_bound, with_replacement, seed })), LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, @@ -471,6 +475,7 @@ impl LogicalPlan { | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) + | LogicalPlan::Sample(_) | LogicalPlan::DescribeTable(_) => Ok(TreeNodeRecursion::Continue), } } @@ -651,6 +656,7 @@ impl LogicalPlan { | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) + | LogicalPlan::Sample(_) | LogicalPlan::DescribeTable(_) => Transformed::no(self), }) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 6a49e5d22087..2679d7f67df6 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -554,6 +554,7 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Limit(_) + | LogicalPlan::Sample(_) | LogicalPlan::Ddl(_) | LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 33af52824a29..b2e48638259e 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -304,6 +304,16 @@ fn optimize_projections( .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) .collect::>()? } + LogicalPlan::Sample(_) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. + plan.inputs() + .into_iter() + .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) + .collect::>()? + } LogicalPlan::Copy(_) | LogicalPlan::Ddl(_) | LogicalPlan::Dml(_) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 0fc68ae49775..d9fe3a826c3e 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -62,6 +62,8 @@ itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } parking_lot = { workspace = true } pin-project-lite = "^0.2.7" +rand = { workspace = true } +rand_distr = { workspace = true } tokio = { workspace = true } [dev-dependencies] @@ -69,7 +71,6 @@ criterion = { workspace = true, features = ["async_futures"] } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window = { workspace = true } insta = { workspace = true } -rand = { workspace = true } rstest = { workspace = true } rstest_reuse = "0.7.0" tempfile = "3.19.1" diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 9e703ced1fc2..88ec47d852c8 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -47,6 +47,7 @@ pub use crate::execution_plan::{ }; pub use crate::metrics::Metric; pub use crate::ordering::InputOrderMode; +pub use crate::sample::SampleExec; pub use crate::stream::EmptyRecordBatchStream; pub use crate::topk::TopK; pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; @@ -77,6 +78,7 @@ pub mod projection; pub mod recursive_query; pub mod repartition; pub mod sorts; +pub mod sample; pub mod spill; pub mod stream; pub mod streaming; diff --git a/datafusion/physical-plan/src/sample.rs b/datafusion/physical-plan/src/sample.rs new file mode 100644 index 000000000000..2c2cdc150d7b --- /dev/null +++ b/datafusion/physical-plan/src/sample.rs @@ -0,0 +1,439 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the SAMPLE operator + +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use rand_distr::{Distribution, Poisson}; + +use super::{ + DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, Statistics, +}; +use crate::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + DisplayFormatType, ExecutionPlan, +}; + +use arrow::array::UInt32Array; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use arrow::compute; +use datafusion_common::{internal_err, plan_datafusion_err, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::EquivalenceProperties; + +use futures::stream::{Stream, StreamExt}; +use rand::{Rng, SeedableRng}; +use rand::rngs::StdRng; + +trait Sampler: Send + Sync { + fn sample(&mut self, batch: &RecordBatch) -> Result; +} + +struct BernoulliSampler { + lower_bound: f64, + upper_bound: f64, + rng: StdRng, +} + +impl BernoulliSampler { + fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self { + Self { lower_bound, upper_bound, rng: StdRng::seed_from_u64(seed) } + } +} + +impl Sampler for BernoulliSampler { + fn sample(&mut self, batch: &RecordBatch) -> Result { + + if self.upper_bound <= self.lower_bound { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let mut indices = Vec::new(); + + for i in 0..batch.num_rows() { + let rnd: f64 = self.rng.random(); + + if rnd >= self.lower_bound && rnd < self.upper_bound { + indices.push(i as u32); + } + } + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + let indices = UInt32Array::from(indices); + compute::take_record_batch(batch, &indices).map_err(|e| e.into()) + } +} + +struct PoissonSampler { + ratio: f64, + poisson: Poisson, + rng: StdRng, +} + +impl PoissonSampler { + fn try_new(ratio: f64, seed: u64) -> Result { + let poisson = Poisson::new(ratio).map_err(|e| plan_datafusion_err!("{}", e))?; + Ok(Self { ratio, poisson, rng: StdRng::seed_from_u64(seed) }) + } +} + +impl Sampler for PoissonSampler { + fn sample(&mut self, batch: &RecordBatch) -> Result { + if self.ratio <= 0.0 { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let mut indices = Vec::new(); + + for i in 0..batch.num_rows() { + let k = self.poisson.sample(&mut self.rng) as i32; + for _ in 0..k { + indices.push(i as u32); + } + } + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let indices = UInt32Array::from(indices); + compute::take_record_batch(batch, &indices).map_err(|e| e.into()) + } +} + +/// SampleExec samples rows from its input based on a sampling method. +/// This is used to implement SQL `SAMPLE` clause. +#[derive(Debug, Clone)] +pub struct SampleExec { + /// The input plan + input: Arc, + /// The lower bound of the sampling ratio + lower_bound: f64, + /// The upper bound of the sampling ratio + upper_bound: f64, + /// Whether to sample with replacement + with_replacement: bool, + /// Random seed for reproducible sampling + seed: u64, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Properties equivalence properties, partitioning, etc. + cache: PlanProperties, +} + +impl SampleExec { + /// Create a new SampleExec with a custom sampling method + pub fn try_new( + input: Arc, + lower_bound: f64, + upper_bound: f64, + with_replacement: bool, + seed: u64, + ) -> Result { + if lower_bound < 0.0 || upper_bound > 1.0 || lower_bound > upper_bound { + return internal_err!( + "Sampling bounds must be between 0.0 and 1.0, and lower_bound <= upper_bound, got [{}, {}]", + lower_bound, upper_bound + ); + } + + let cache = Self::compute_properties(&input); + + Ok(Self { + input, + lower_bound, + upper_bound, + with_replacement, + seed, + metrics: ExecutionPlanMetricsSet::new(), + cache, + }) + } + + fn create_sampler(&self, partition: usize) -> Result> { + if self.with_replacement { + Ok(Box::new(PoissonSampler::try_new(self.upper_bound - self.lower_bound, self.seed + partition as u64)?)) + } else { + Ok(Box::new(BernoulliSampler::new(self.lower_bound, self.upper_bound, self.seed + partition as u64))) + } + } + + /// Whether to sample with replacement + pub fn with_replacement(&self) -> bool { + self.with_replacement + } + + /// The lower bound of the sampling ratio + pub fn lower_bound(&self) -> f64 { + self.lower_bound + } + + /// The upper bound of the sampling ratio + pub fn upper_bound(&self) -> f64 { + self.upper_bound + } + + /// The random seed + pub fn seed(&self) -> u64 { + self.seed + } + + /// The input plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties(input: &Arc) -> PlanProperties { + PlanProperties::new( + EquivalenceProperties::new(input.schema()), + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + ) + } +} + +impl DisplayAs for SampleExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}", + self.lower_bound, self.upper_bound, self.with_replacement, self.seed + ) + } + DisplayFormatType::TreeRender => { + write!( + f, + "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}", + self.lower_bound, self.upper_bound, self.with_replacement, self.seed + ) + } + } + } +} + +impl ExecutionPlan for SampleExec { + fn name(&self) -> &'static str { + "SampleExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn maintains_input_order(&self) -> Vec { + vec![false] // Sampling does not maintain input order + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + // Get the ratio from the current method + Ok(Arc::new(SampleExec::try_new( + children[0].clone(), + self.lower_bound, + self.upper_bound, + self.with_replacement, + self.seed, + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input_stream = self.input.execute(partition, context)?; + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + Ok(Box::pin(SampleExecStream { + input: input_stream, + sampler: self.create_sampler(partition)?, + baseline_metrics, + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stats = self.input.partition_statistics(partition)?; + + // Apply sampling ratio to statistics + let mut stats = input_stats; + let ratio = self.upper_bound - self.lower_bound; + + stats.num_rows = stats.num_rows.map(|nr| (nr as f64 * ratio) as usize).to_inexact(); + stats.total_byte_size = stats.total_byte_size.map(|tb| (tb as f64 * ratio) as usize).to_inexact(); + + Ok(stats) + } +} + +/// Stream for the SampleExec operator +struct SampleExecStream { + /// The input stream + input: SendableRecordBatchStream, + /// The sampling method + sampler: Box, + /// Runtime metrics recording + baseline_metrics: BaselineMetrics, +} + +impl Stream for SampleExecStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.input.poll_next_unpin(cx); + let baseline_metrics = &mut self.baseline_metrics; + + match poll { + Poll::Ready(Some(Ok(batch))) => { + let start = baseline_metrics.elapsed_compute().clone(); + let result = self.sampler.sample(&batch); + let _timer = start.timer(); + Poll::Ready(Some(result)) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for SampleExecStream { + fn schema(&self) -> SchemaRef { + self.input.schema() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::{Field, Schema}; + use datafusion_execution::TaskContext; + use futures::TryStreamExt; + use std::sync::Arc; + + #[tokio::test] + async fn test_sample_exec_bernoulli() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", arrow::datatypes::DataType::Int32, false), + Field::new("name", arrow::datatypes::DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), + Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])), + ], + )?; + + let input = Arc::new(crate::test::TestMemoryExec::try_new( + &[vec![batch]], + schema.clone(), + None, + )?); + + let sample_exec = SampleExec::try_new(input, 0.6, 1.0, false, 42)?; + + let context = Arc::new(TaskContext::default()); + let stream = sample_exec.execute(0, context)?; + + let batches = stream.try_collect::>().await?; + assert!(!batches.is_empty()); + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + // With 60% sampling ratio and 5 input rows, we expect around 3 rows + assert!(total_rows >= 2 && total_rows <= 4); + + Ok(()) + } + + #[tokio::test] + async fn test_sample_exec_poisson() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", arrow::datatypes::DataType::Int32, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + )?; + + let input = Arc::new(crate::test::TestMemoryExec::try_new( + &[vec![batch]], + schema.clone(), + None, + )?); + + let sample_exec = SampleExec::try_new(input, 0.0, 0.5, true, 42)?; + + let context = Arc::new(TaskContext::default()); + let stream = sample_exec.execute(0, context)?; + + let batches = stream.try_collect::>().await?; + assert!(!batches.is_empty()); + + Ok(()) + } + + #[test] + fn test_sampler_trait() { + let mut bernoulli_sampler = BernoulliSampler::new(0.0, 0.5, 42); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", arrow::datatypes::DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + ).unwrap(); + + let result = bernoulli_sampler.sample(&batch); + assert!(result.is_ok()); + + let mut poisson_sampler = PoissonSampler::try_new(0.5, 42).unwrap(); + let result = poisson_sampler.sample(&batch); + assert!(result.is_ok()); + } +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1e1f91e07e29..b5487f22ab46 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -62,6 +62,7 @@ message LogicalPlanNode { RecursiveQueryNode recursive_query = 31; CteWorkTableScanNode cte_work_table_scan = 32; DmlNode dml = 33; + SampleNode sample = 34; } } @@ -727,6 +728,7 @@ message PhysicalPlanNode { UnnestExecNode unnest = 30; JsonScanExecNode json_scan = 31; YieldStreamExecNode yield_stream = 32; + SampleExecNode sample = 33; } } @@ -1292,3 +1294,19 @@ message CteWorkTableScanNode { string name = 1; datafusion_common.Schema schema = 2; } + +message SampleNode { + LogicalPlanNode input = 1; + double lower_bound = 2; + double upper_bound = 3; + bool with_replacement = 4; + uint64 seed = 5; +} + +message SampleExecNode { + PhysicalPlanNode input = 1; + double lower_bound = 2; + double upper_bound = 3; + bool with_replacement = 4; + uint64 seed = 5; +} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 02a1cc70eeb9..6144f97ae7a5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -11062,6 +11062,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::Dml(v) => { struct_ser.serialize_field("dml", v)?; } + logical_plan_node::LogicalPlanType::Sample(v) => { + struct_ser.serialize_field("sample", v)?; + } } } struct_ser.end() @@ -11121,6 +11124,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "cte_work_table_scan", "cteWorkTableScan", "dml", + "sample", ]; #[allow(clippy::enum_variant_names)] @@ -11157,6 +11161,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { RecursiveQuery, CteWorkTableScan, Dml, + Sample, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11210,6 +11215,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "recursiveQuery" | "recursive_query" => Ok(GeneratedField::RecursiveQuery), "cteWorkTableScan" | "cte_work_table_scan" => Ok(GeneratedField::CteWorkTableScan), "dml" => Ok(GeneratedField::Dml), + "sample" => Ok(GeneratedField::Sample), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11454,6 +11460,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("dml")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Dml) +; + } + GeneratedField::Sample => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sample")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Sample) ; } } @@ -15803,6 +15816,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::YieldStream(v) => { struct_ser.serialize_field("yieldStream", v)?; } + physical_plan_node::PhysicalPlanType::Sample(v) => { + struct_ser.serialize_field("sample", v)?; + } } } struct_ser.end() @@ -15863,6 +15879,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "jsonScan", "yield_stream", "yieldStream", + "sample", ]; #[allow(clippy::enum_variant_names)] @@ -15898,6 +15915,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { Unnest, JsonScan, YieldStream, + Sample, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15950,6 +15968,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "unnest" => Ok(GeneratedField::Unnest), "jsonScan" | "json_scan" => Ok(GeneratedField::JsonScan), "yieldStream" | "yield_stream" => Ok(GeneratedField::YieldStream), + "sample" => Ok(GeneratedField::Sample), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16187,6 +16206,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("yieldStream")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::YieldStream) +; + } + GeneratedField::Sample => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sample")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Sample) ; } } @@ -18545,6 +18571,346 @@ impl<'de> serde::Deserialize<'de> for RollupNode { deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SampleExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.lower_bound != 0. { + len += 1; + } + if self.upper_bound != 0. { + len += 1; + } + if self.with_replacement { + len += 1; + } + if self.seed != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SampleExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.lower_bound != 0. { + struct_ser.serialize_field("lowerBound", &self.lower_bound)?; + } + if self.upper_bound != 0. { + struct_ser.serialize_field("upperBound", &self.upper_bound)?; + } + if self.with_replacement { + struct_ser.serialize_field("withReplacement", &self.with_replacement)?; + } + if self.seed != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("seed", ToString::to_string(&self.seed).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SampleExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "lower_bound", + "lowerBound", + "upper_bound", + "upperBound", + "with_replacement", + "withReplacement", + "seed", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + LowerBound, + UpperBound, + WithReplacement, + Seed, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "lowerBound" | "lower_bound" => Ok(GeneratedField::LowerBound), + "upperBound" | "upper_bound" => Ok(GeneratedField::UpperBound), + "withReplacement" | "with_replacement" => Ok(GeneratedField::WithReplacement), + "seed" => Ok(GeneratedField::Seed), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SampleExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SampleExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut lower_bound__ = None; + let mut upper_bound__ = None; + let mut with_replacement__ = None; + let mut seed__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::LowerBound => { + if lower_bound__.is_some() { + return Err(serde::de::Error::duplicate_field("lowerBound")); + } + lower_bound__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::UpperBound => { + if upper_bound__.is_some() { + return Err(serde::de::Error::duplicate_field("upperBound")); + } + upper_bound__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WithReplacement => { + if with_replacement__.is_some() { + return Err(serde::de::Error::duplicate_field("withReplacement")); + } + with_replacement__ = Some(map_.next_value()?); + } + GeneratedField::Seed => { + if seed__.is_some() { + return Err(serde::de::Error::duplicate_field("seed")); + } + seed__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(SampleExecNode { + input: input__, + lower_bound: lower_bound__.unwrap_or_default(), + upper_bound: upper_bound__.unwrap_or_default(), + with_replacement: with_replacement__.unwrap_or_default(), + seed: seed__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SampleExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SampleNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.lower_bound != 0. { + len += 1; + } + if self.upper_bound != 0. { + len += 1; + } + if self.with_replacement { + len += 1; + } + if self.seed != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SampleNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.lower_bound != 0. { + struct_ser.serialize_field("lowerBound", &self.lower_bound)?; + } + if self.upper_bound != 0. { + struct_ser.serialize_field("upperBound", &self.upper_bound)?; + } + if self.with_replacement { + struct_ser.serialize_field("withReplacement", &self.with_replacement)?; + } + if self.seed != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("seed", ToString::to_string(&self.seed).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SampleNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "lower_bound", + "lowerBound", + "upper_bound", + "upperBound", + "with_replacement", + "withReplacement", + "seed", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + LowerBound, + UpperBound, + WithReplacement, + Seed, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "lowerBound" | "lower_bound" => Ok(GeneratedField::LowerBound), + "upperBound" | "upper_bound" => Ok(GeneratedField::UpperBound), + "withReplacement" | "with_replacement" => Ok(GeneratedField::WithReplacement), + "seed" => Ok(GeneratedField::Seed), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SampleNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SampleNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut lower_bound__ = None; + let mut upper_bound__ = None; + let mut with_replacement__ = None; + let mut seed__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::LowerBound => { + if lower_bound__.is_some() { + return Err(serde::de::Error::duplicate_field("lowerBound")); + } + lower_bound__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::UpperBound => { + if upper_bound__.is_some() { + return Err(serde::de::Error::duplicate_field("upperBound")); + } + upper_bound__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WithReplacement => { + if with_replacement__.is_some() { + return Err(serde::de::Error::duplicate_field("withReplacement")); + } + with_replacement__ = Some(map_.next_value()?); + } + GeneratedField::Seed => { + if seed__.is_some() { + return Err(serde::de::Error::duplicate_field("seed")); + } + seed__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(SampleNode { + input: input__, + lower_bound: lower_bound__.unwrap_or_default(), + upper_bound: upper_bound__.unwrap_or_default(), + with_replacement: with_replacement__.unwrap_or_default(), + seed: seed__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SampleNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarUdfExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c1f8fa61f3b3..7a7f592e7010 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -5,7 +5,7 @@ pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34" )] pub logical_plan_type: ::core::option::Option, } @@ -77,6 +77,8 @@ pub mod logical_plan_node { CteWorkTableScan(super::CteWorkTableScanNode), #[prost(message, tag = "33")] Dml(::prost::alloc::boxed::Box), + #[prost(message, tag = "34")] + Sample(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1048,7 +1050,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33" )] pub physical_plan_type: ::core::option::Option, } @@ -1120,6 +1122,8 @@ pub mod physical_plan_node { JsonScan(super::JsonScanExecNode), #[prost(message, tag = "32")] YieldStream(::prost::alloc::boxed::Box), + #[prost(message, tag = "33")] + Sample(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1943,6 +1947,32 @@ pub struct CteWorkTableScanNode { #[prost(message, optional, tag = "2")] pub schema: ::core::option::Option, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SampleNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(double, tag = "2")] + pub lower_bound: f64, + #[prost(double, tag = "3")] + pub upper_bound: f64, + #[prost(bool, tag = "4")] + pub with_replacement: bool, + #[prost(uint64, tag = "5")] + pub seed: u64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SampleExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(double, tag = "2")] + pub lower_bound: f64, + #[prost(double, tag = "3")] + pub upper_bound: f64, + #[prost(bool, tag = "4")] + pub with_replacement: bool, + #[prost(uint64, tag = "5")] + pub seed: u64, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 1acf1ee27bfe..b7bcdcbd0408 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -71,8 +71,7 @@ use datafusion_expr::{ Statement, WindowUDF, }; use datafusion_expr::{ - AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, SkipType, - TableSource, Unnest, + AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, Sample, SkipType, TableSource, Unnest }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -994,6 +993,16 @@ impl AsLogicalPlan for LogicalPlanNode { Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?), ), )), + LogicalPlanType::Sample(sample) => { + let input = into_logical_plan!(sample.input, ctx, extension_codec)?; + Ok(LogicalPlan::Sample(Sample { + input: Arc::new(input), + lower_bound: sample.lower_bound, + upper_bound: sample.upper_bound, + with_replacement: sample.with_replacement, + seed: sample.seed, + })) + } } } @@ -1806,6 +1815,22 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::Sample(sample) => { + let input = LogicalPlanNode::try_from_logical_plan( + sample.input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Sample(Box::new(protobuf::SampleNode { + input: Some(Box::new(input)), + lower_bound: sample.lower_bound, + upper_bound: sample.upper_bound, + with_replacement: sample.with_replacement, + seed: sample.seed, + }, + ))), + }) + } } } } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 3d541f54fe10..80ab98b98048 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -84,7 +84,7 @@ use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::yield_stream::YieldStreamExec; use datafusion::physical_plan::{ - ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, + ExecutionPlan, InputOrderMode, PhysicalExpr, SampleExec, WindowExpr }; use datafusion_common::config::TableParquetOptions; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; @@ -331,6 +331,12 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { runtime, extension_codec, ), + PhysicalPlanType::Sample(sample) => self.try_into_sample_physical_plan( + sample, + registry, + runtime, + extension_codec, + ), } } @@ -527,6 +533,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ); } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_sample_exec( + exec, + extension_codec, + ); + } + let mut buf: Vec = vec![]; match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { @@ -556,6 +569,17 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } impl protobuf::PhysicalPlanNode { + fn try_into_sample_physical_plan( + &self, + sample: &protobuf::SampleExecNode, + registry: &dyn FunctionRegistry, + runtime: &RuntimeEnv, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input = into_physical_plan(&sample.input, registry, runtime, extension_codec)?; + Ok(Arc::new(SampleExec::try_new(input, sample.lower_bound, sample.upper_bound, sample.with_replacement, sample.seed)?)) + } + fn try_into_explain_physical_plan( &self, explain: &protobuf::ExplainExecNode, @@ -2809,6 +2833,25 @@ impl protobuf::PhysicalPlanNode { ))), }) } + + fn try_from_sample_exec( + exec: &SampleExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Sample(Box::new(protobuf::SampleExecNode { + input: Some(Box::new(input)), + lower_bound: exec.lower_bound(), + upper_bound: exec.upper_bound(), + with_replacement: exec.with_replacement(), + seed: exec.seed(), + }))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 88a32a218341..a57b6784bc57 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -15,18 +15,22 @@ // specific language governing permissions and limitations // under the License. +use std::str::FromStr; use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - not_impl_err, plan_err, DFSchema, Diagnostic, Result, Span, Spans, TableReference, + not_impl_err, plan_err, DFSchema, Diagnostic, Result, Span, Spans, TableReference }; use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::{Subquery, SubqueryAlias}; -use sqlparser::ast::{FunctionArg, FunctionArgExpr, Spanned, TableFactor}; +use sqlparser::ast::{ + FunctionArg, FunctionArgExpr, Spanned, TableFactor, TableSampleKind, + TableSampleUnit, +}; mod join; @@ -40,7 +44,7 @@ impl SqlToRel<'_, S> { let relation_span = relation.span(); let (plan, alias) = match relation { TableFactor::Table { - name, alias, args, .. + name, alias, args, sample, .. } => { if let Some(func_args) = args { let tbl_func_name = @@ -64,7 +68,7 @@ impl SqlToRel<'_, S> { let provider = self .context_provider .get_table_function_source(&tbl_func_name, args)?; - let plan = LogicalPlanBuilder::scan( + let mut plan = LogicalPlanBuilder::scan( TableReference::Bare { table: "tmp_table".into(), }, @@ -72,34 +76,38 @@ impl SqlToRel<'_, S> { None, )? .build()?; + if let Some(sample) = sample { + plan = self.table_sample(sample, plan, planner_context)?; + } (plan, alias) } else { // Normalize name and alias let table_ref = self.object_name_to_table_reference(name)?; let table_name = table_ref.to_string(); let cte = planner_context.get_cte(&table_name); - ( - match ( - cte, - self.context_provider.get_table_source(table_ref.clone()), - ) { - (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Ok(provider)) => LogicalPlanBuilder::scan( - table_ref.clone(), - provider, - None, - )? - .build(), - (None, Err(e)) => { - let e = e.with_diagnostic(Diagnostic::new_error( - format!("table '{table_ref}' not found"), - Span::try_from_sqlparser_span(relation_span), - )); - Err(e) - } - }?, - alias, - ) + let mut plan = match ( + cte, + self.context_provider.get_table_source(table_ref.clone()), + ) { + (Some(cte_plan), _) => Ok(cte_plan.clone()), + (_, Ok(provider)) => LogicalPlanBuilder::scan( + table_ref.clone(), + provider, + None, + )? + .build(), + (None, Err(e)) => { + let e = e.with_diagnostic(Diagnostic::new_error( + format!("table '{table_ref}' not found"), + Span::try_from_sqlparser_span(relation_span), + )); + Err(e) + } + }?; + if let Some(sample) = sample { + plan = self.table_sample(sample, plan, planner_context)?; + } + (plan, alias) } } TableFactor::Derived { @@ -224,6 +232,95 @@ impl SqlToRel<'_, S> { })), } } + + fn table_sample( + &self, + sample: TableSampleKind, + input: LogicalPlan, + _planner_context: &mut PlannerContext, + ) -> Result { + let sample = match sample { + TableSampleKind::BeforeTableAlias(sample) => sample, + TableSampleKind::AfterTableAlias(sample) => sample, + }; + if sample.name.is_some() { + // Postgres-style sample. Not supported because DataFusion does not have a concept of pages like PostgreSQL. + return not_impl_err!("{} is not supported yet", sample.name.unwrap()); + } + if sample.offset.is_some() { + // Clickhouse-style sample. Not supported because it requires knowing the total data size. + return not_impl_err!("Offset sample is not supported yet"); + } + + let seed = sample.seed.map(|seed| { + let Ok(seed) = seed.to_string().parse::() else { + return plan_err!("seed must be a number"); + }; + Ok(seed) + }).transpose()?; + + if let Some(bucket) = sample.bucket { + if bucket.on.is_some() { + // Hive-style sample, only used when the Hive table is defined with CLUSTERED BY + return not_impl_err!("Bucket sample with ON is not supported yet"); + } + + let Ok(bucket_num) = bucket.bucket.to_string().parse::() else { + return plan_err!("bucket must be a number"); + }; + + let Ok(total_num) = bucket.total.to_string().parse::() else { + return plan_err!("total must be a number"); + }; + let logical_plan = LogicalPlanBuilder::from(input).sample(bucket_num as f64 / total_num as f64, None, seed)?.build()?; + return Ok(logical_plan); + } + if let Some(quantity) = sample.quantity { + match quantity.unit { + Some(TableSampleUnit::Rows) => { + let value = evaluate_number::(&quantity.value); + if value.is_none() { + return plan_err!("quantity must be a non-negative number: {:?}", quantity.value); + } + let value = value.unwrap(); + if value < 0 { + return plan_err!("quantity must be a non-negative number: {:?}", quantity.value); + } + let logical_plan = LogicalPlanBuilder::from(input).limit(0, Some(value as usize))?.build()?; + return Ok(logical_plan); + } + Some(TableSampleUnit::Percent) => { + let value = evaluate_number::(&quantity.value); + if value.is_none() { + return plan_err!("quantity must be a number: {:?}", quantity.value); + } + let value = value.unwrap() / 100.0; + let logical_plan = LogicalPlanBuilder::from(input).sample(value, None, seed)?.build()?; + return Ok(logical_plan); + } + None => { + // Clickhouse-style sample + let value = evaluate_number::(&quantity.value); + if value.is_none() { + return plan_err!("quantity must be a non-negative number: {:?}", quantity.value); + } + let value = value.unwrap(); + if value < 0.0 { + return plan_err!("quantity must be a non-negative number: {:?}", quantity.value); + } + if value >= 1.0 { + let logical_plan = LogicalPlanBuilder::from(input).limit(0, Some(value as usize))?.build()?; + return Ok(logical_plan); + } else { + let logical_plan = LogicalPlanBuilder::from(input).sample(value, None, seed)?.build()?; + return Ok(logical_plan); + } + + } + } + } + Ok(input) + } } fn optimize_subquery_sort(plan: LogicalPlan) -> Result> { @@ -252,3 +349,38 @@ fn optimize_subquery_sort(plan: LogicalPlan) -> Result> }); new_plan } + + +fn evaluate_number + std::ops::Sub + std::ops::Mul + std::ops::Div>(expr: &sqlparser::ast::Expr) -> Option { + match expr { + sqlparser::ast::Expr::BinaryOp { left, op, right } => { + let left = evaluate_number::(&left); + let right = evaluate_number::(&right); + match (left, right) { + (Some(left), Some(right)) => { + match op { + sqlparser::ast::BinaryOperator::Plus => Some(left + right), + sqlparser::ast::BinaryOperator::Minus => Some(left - right), + sqlparser::ast::BinaryOperator::Multiply => Some(left * right), + sqlparser::ast::BinaryOperator::Divide => Some(left / right), + _ => None, + } + } + _ => None, + } + } + sqlparser::ast::Expr::Value(value) => { + match &value.value { + sqlparser::ast::Value::Number(value, _) => { + let value = format!("{value}"); + let Ok(value) = value.parse::() else { + return None; + }; + Some(value) + } + _ => None, + } + } + _ => None, + } +} diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index d9f9767ba9e4..10a66ad2746a 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -113,6 +113,7 @@ impl Unparser<'_> { | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) + | LogicalPlan::Sample(_) | LogicalPlan::Distinct(_) => self.select_to_sql_statement(&plan), LogicalPlan::Dml(_) => self.dml_to_sql(&plan), LogicalPlan::Extension(extension) => { diff --git a/datafusion/sqllogictest/test_files/sample.slt b/datafusion/sqllogictest/test_files/sample.slt new file mode 100644 index 000000000000..11c9f3bfee07 --- /dev/null +++ b/datafusion/sqllogictest/test_files/sample.slt @@ -0,0 +1,226 @@ +# SAMPLE function tests with REPEATABLE seed for stability + +# Test basic SAMPLE with REPEATABLE seed +statement ok +CREATE TABLE sample_test (id INT, name VARCHAR, value DOUBLE) + +statement ok +INSERT INTO sample_test VALUES + (1, 'Alice', 10.5), + (2, 'Bob', 20.3), + (3, 'Charlie', 15.7), + (4, 'David', 25.1), + (5, 'Eve', 30.2), + (6, 'Frank', 12.8), + (7, 'Grace', 18.9), + (8, 'Henry', 22.4), + (9, 'Ivy', 28.6), + (10, 'Jack', 35.0) + +# Test SAMPLE with 50% ratio and REPEATABLE seed +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) +---- +5 + +# Test SAMPLE with 30% ratio and REPEATABLE seed +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (30 PERCENT) REPEATABLE (42) +---- +3 + +# Test SAMPLE with 70% ratio and REPEATABLE seed +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (70 PERCENT) REPEATABLE (42) +---- +7 + +# Test SAMPLE with 100% ratio and REPEATABLE seed (should return all rows) +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (100 PERCENT) REPEATABLE (42) +---- +10 + +# Test SAMPLE with 0% ratio and REPEATABLE seed (should return no rows) +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (0 PERCENT) REPEATABLE (42) +---- +0 + +# Test SAMPLE with specific columns and REPEATABLE seed +query TT +SELECT id, name FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) ORDER BY id +---- +1 Alice +3 Charlie +5 Eve +7 Grace +9 Ivy + +# Test SAMPLE with WHERE clause and REPEATABLE seed +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) WHERE value > 20 +---- +2 + +# Test SAMPLE with ORDER BY and REPEATABLE seed +query TT +SELECT id, value FROM sample_test SAMPLE (40 PERCENT) REPEATABLE (42) ORDER BY value DESC +---- +5 30.2 +9 28.6 +4 25.1 +8 22.4 + +# Test SAMPLE with GROUP BY and REPEATABLE seed +query TT +SELECT COUNT(*) as count, AVG(value) as avg_value FROM sample_test SAMPLE (60 PERCENT) REPEATABLE (42) +---- +6 21.2 + +# Test SAMPLE with different seed values (should give different results) +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (123) +---- +4 + +# Test SAMPLE with same seed value (should give same results) +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) +---- +5 + +# Test SAMPLE with LIMIT and REPEATABLE seed +query TT +SELECT id, name FROM sample_test SAMPLE (80 PERCENT) REPEATABLE (42) ORDER BY id LIMIT 3 +---- +1 Alice +2 Bob +3 Charlie + +# Test SAMPLE with subquery and REPEATABLE seed +query TT +SELECT COUNT(*) as count FROM ( + SELECT * FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) +) t +---- +5 + +# Test SAMPLE with JOIN and REPEATABLE seed +statement ok +CREATE TABLE sample_test2 (id INT, category VARCHAR) + +statement ok +INSERT INTO sample_test2 VALUES + (1, 'A'), + (2, 'B'), + (3, 'A'), + (4, 'B'), + (5, 'A') + +query TT +SELECT COUNT(*) as count FROM sample_test s1 +JOIN sample_test2 s2 ON s1.id = s2.id +SAMPLE (50 PERCENT) REPEATABLE (42) +---- +3 + +# Test SAMPLE with UNION and REPEATABLE seed +query TT +SELECT COUNT(*) as count FROM ( + SELECT id FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) + UNION + SELECT id FROM sample_test SAMPLE (30 PERCENT) REPEATABLE (42) +) t +---- +7 + +# Test SAMPLE with CTE and REPEATABLE seed +query TT +WITH sampled_data AS ( + SELECT * FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) +) +SELECT COUNT(*) as count FROM sampled_data +---- +5 + +# Test SAMPLE with window function (should work correctly) +query TT +SELECT id, name, value, + ROW_NUMBER() OVER (ORDER BY value) as rn +FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) +ORDER BY id +---- +1 Alice 10.5 1 +3 Charlie 15.7 2 +5 Eve 30.2 3 +7 Grace 18.9 4 +9 Ivy 28.6 5 + +# Test SAMPLE with aggregation and REPEATABLE seed +query TT +SELECT MIN(value) as min_val, MAX(value) as max_val, AVG(value) as avg_val +FROM sample_test SAMPLE (60 PERCENT) REPEATABLE (42) +---- +10.5 30.2 20.7 + +# Test SAMPLE with different percentage values +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (25 PERCENT) REPEATABLE (42) +---- +3 + +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (75 PERCENT) REPEATABLE (42) +---- +8 + +# Test SAMPLE with decimal percentage +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (33.33 PERCENT) REPEATABLE (42) +---- +3 + +# Test SAMPLE with large percentage +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (90 PERCENT) REPEATABLE (42) +---- +9 + +# Test SAMPLE with small percentage +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (10 PERCENT) REPEATABLE (42) +---- +1 + +# Test SAMPLE with REPEATABLE seed in different contexts +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (999) +---- +5 + +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (1) +---- +4 + +# Test SAMPLE with REPEATABLE seed and complex WHERE conditions +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) +WHERE value BETWEEN 15 AND 25 AND name LIKE '%e%' +---- +1 + +# Test SAMPLE with REPEATABLE seed and multiple conditions +query TT +SELECT COUNT(*) as count FROM sample_test SAMPLE (60 PERCENT) REPEATABLE (42) +WHERE id % 2 = 0 OR value > 20 +---- +4 + +# Clean up +statement ok +DROP TABLE sample_test + +statement ok +DROP TABLE sample_test2 diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs index c3599a2635ff..8ac85143290a 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -74,5 +74,6 @@ pub fn to_substrait_rel( LogicalPlan::RecursiveQuery(plan) => { not_impl_err!("Unsupported plan type: {plan:?}")? } + LogicalPlan::Sample(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, } } From 2af8aee32fbb39757620a6dbb7decaa73a0bdfbd Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 23 Jun 2025 11:51:22 +0800 Subject: [PATCH 02/10] update --- datafusion/core/tests/dataframe/mod.rs | 67 ++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 8d60dbea3d01..2bb0e9f84c59 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -6137,3 +6137,70 @@ async fn test_dataframe_macro() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_dataframe_sample() -> Result<()> { + let df = dataframe!( + "a" => (1..40).collect::>(), + )?; + + // Test sampling 20% of rows with replacement + let df_sampled = df.clone().sample(0.2, Some(true), Some(42))?; + assert_batches_eq!( + &[ + "+----+", + "| a |", + "+----+", + "| 8 |", + "| 10 |", + "| 19 |", + "| 29 |", + "| 29 |", + "| 36 |", + "+----+", + ], + &df_sampled.collect().await? + ); + + // Test sampling 20% of rows without replacement + let df_sampled = df.clone().sample(0.2, Some(false), Some(42))?; + assert_batches_eq!( + &[ + "+----+", + "| a |", + "+----+", + "| 5 |", + "| 9 |", + "| 10 |", + "| 14 |", + "| 17 |", + "| 19 |", + "| 24 |", + "| 39 |", + "+----+", + ], + &df_sampled.collect().await? + ); + + // Test sampling with None parameters (should use defaults) + let df_sampled_default = df.clone().sample(0.2, None, Some(42))?; + assert_batches_eq!( + &[ + "+----+", + "| a |", + "+----+", + "| 5 |", + "| 9 |", + "| 10 |", + "| 14 |", + "| 17 |", + "| 19 |", + "| 24 |", + "| 39 |", + "+----+", + ], + &df_sampled_default.collect().await? + ); + + Ok(()) +} From 987eb197a97f3cb7cf879a1053e6a6b2791e224c Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 23 Jun 2025 12:05:30 +0800 Subject: [PATCH 03/10] update test --- datafusion/sql/src/relation/mod.rs | 4 +-- datafusion/sqllogictest/test_files/sample.slt | 26 +++++++------------ 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index a57b6784bc57..6e1708167217 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -253,8 +253,8 @@ impl SqlToRel<'_, S> { } let seed = sample.seed.map(|seed| { - let Ok(seed) = seed.to_string().parse::() else { - return plan_err!("seed must be a number"); + let Ok(seed) = seed.value.to_string().parse::() else { + return plan_err!("seed must be a number: {}", seed.value); }; Ok(seed) }).transpose()?; diff --git a/datafusion/sqllogictest/test_files/sample.slt b/datafusion/sqllogictest/test_files/sample.slt index 11c9f3bfee07..14d9e5657d70 100644 --- a/datafusion/sqllogictest/test_files/sample.slt +++ b/datafusion/sqllogictest/test_files/sample.slt @@ -2,20 +2,17 @@ # Test basic SAMPLE with REPEATABLE seed statement ok -CREATE TABLE sample_test (id INT, name VARCHAR, value DOUBLE) +CREATE TABLE sample_test (id INT) AS SELECT * FROM generate_series(1, 40); -statement ok -INSERT INTO sample_test VALUES - (1, 'Alice', 10.5), - (2, 'Bob', 20.3), - (3, 'Charlie', 15.7), - (4, 'David', 25.1), - (5, 'Eve', 30.2), - (6, 'Frank', 12.8), - (7, 'Grace', 18.9), - (8, 'Henry', 22.4), - (9, 'Ivy', 28.6), - (10, 'Jack', 35.0) +query TT +EXPLAIN SELECT * FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42); +---- +logical_plan +01)Sample: lower_bound=0, upper_bound=0.5, with_replacement=false, seed=42 +02)--TableScan: sample_test projection=[id] +physical_plan +01)SampleExec: lower_bound=0, upper_bound=0.5, with_replacement=false, seed=42 +02)--DataSourceExec: partitions=1, partition_sizes=[1] # Test SAMPLE with 50% ratio and REPEATABLE seed query TT @@ -221,6 +218,3 @@ WHERE id % 2 = 0 OR value > 20 # Clean up statement ok DROP TABLE sample_test - -statement ok -DROP TABLE sample_test2 From b473d4b7f256f0128c7ab8bde7d108179b461dde Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 23 Jun 2025 15:26:34 +0800 Subject: [PATCH 04/10] update slt --- datafusion/sqllogictest/test_files/sample.slt | 222 +++--------------- 1 file changed, 38 insertions(+), 184 deletions(-) diff --git a/datafusion/sqllogictest/test_files/sample.slt b/datafusion/sqllogictest/test_files/sample.slt index 14d9e5657d70..89741d4124f5 100644 --- a/datafusion/sqllogictest/test_files/sample.slt +++ b/datafusion/sqllogictest/test_files/sample.slt @@ -12,208 +12,62 @@ logical_plan 02)--TableScan: sample_test projection=[id] physical_plan 01)SampleExec: lower_bound=0, upper_bound=0.5, with_replacement=false, seed=42 -02)--DataSourceExec: partitions=1, partition_sizes=[1] +02)--DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0] -# Test SAMPLE with 50% ratio and REPEATABLE seed query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) +EXPLAIN SELECT * FROM sample_test SAMPLE (10 ROWS); ---- -5 - -# Test SAMPLE with 30% ratio and REPEATABLE seed -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (30 PERCENT) REPEATABLE (42) ----- -3 - -# Test SAMPLE with 70% ratio and REPEATABLE seed -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (70 PERCENT) REPEATABLE (42) ----- -7 - -# Test SAMPLE with 100% ratio and REPEATABLE seed (should return all rows) -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (100 PERCENT) REPEATABLE (42) ----- -10 - -# Test SAMPLE with 0% ratio and REPEATABLE seed (should return no rows) -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (0 PERCENT) REPEATABLE (42) ----- -0 - -# Test SAMPLE with specific columns and REPEATABLE seed -query TT -SELECT id, name FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) ORDER BY id ----- -1 Alice -3 Charlie -5 Eve -7 Grace -9 Ivy - -# Test SAMPLE with WHERE clause and REPEATABLE seed -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) WHERE value > 20 ----- -2 - -# Test SAMPLE with ORDER BY and REPEATABLE seed -query TT -SELECT id, value FROM sample_test SAMPLE (40 PERCENT) REPEATABLE (42) ORDER BY value DESC ----- -5 30.2 -9 28.6 -4 25.1 -8 22.4 - -# Test SAMPLE with GROUP BY and REPEATABLE seed -query TT -SELECT COUNT(*) as count, AVG(value) as avg_value FROM sample_test SAMPLE (60 PERCENT) REPEATABLE (42) ----- -6 21.2 - -# Test SAMPLE with different seed values (should give different results) -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (123) ----- -4 - -# Test SAMPLE with same seed value (should give same results) -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) ----- -5 - -# Test SAMPLE with LIMIT and REPEATABLE seed -query TT -SELECT id, name FROM sample_test SAMPLE (80 PERCENT) REPEATABLE (42) ORDER BY id LIMIT 3 ----- -1 Alice -2 Bob -3 Charlie +logical_plan +01)Limit: skip=0, fetch=10 +02)--TableScan: sample_test projection=[id], fetch=10 +physical_plan +01)CoalescePartitionsExec: fetch=10 +02)--DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0], fetch=10 -# Test SAMPLE with subquery and REPEATABLE seed query TT -SELECT COUNT(*) as count FROM ( - SELECT * FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) -) t +EXPLAIN SELECT * FROM sample_test SAMPLE 0.5 REPEATABLE (42); ---- -5 - -# Test SAMPLE with JOIN and REPEATABLE seed -statement ok -CREATE TABLE sample_test2 (id INT, category VARCHAR) - -statement ok -INSERT INTO sample_test2 VALUES - (1, 'A'), - (2, 'B'), - (3, 'A'), - (4, 'B'), - (5, 'A') +logical_plan +01)Sample: lower_bound=0, upper_bound=0.5, with_replacement=false, seed=42 +02)--TableScan: sample_test projection=[id] +physical_plan +01)SampleExec: lower_bound=0, upper_bound=0.5, with_replacement=false, seed=42 +02)--DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0] query TT -SELECT COUNT(*) as count FROM sample_test s1 -JOIN sample_test2 s2 ON s1.id = s2.id -SAMPLE (50 PERCENT) REPEATABLE (42) +EXPLAIN SELECT * FROM sample_test SAMPLE 10; ---- -3 +logical_plan +01)Limit: skip=0, fetch=10 +02)--TableScan: sample_test projection=[id], fetch=10 +physical_plan +01)CoalescePartitionsExec: fetch=10 +02)--DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0], fetch=10 -# Test SAMPLE with UNION and REPEATABLE seed -query TT -SELECT COUNT(*) as count FROM ( - SELECT id FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) - UNION - SELECT id FROM sample_test SAMPLE (30 PERCENT) REPEATABLE (42) -) t ----- -7 -# Test SAMPLE with CTE and REPEATABLE seed -query TT -WITH sampled_data AS ( - SELECT * FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) -) -SELECT COUNT(*) as count FROM sampled_data +# Test SAMPLE with 20% ratio and REPEATABLE seed +query I +SELECT * FROM sample_test SAMPLE (20 PERCENT) REPEATABLE (42); ---- 5 - -# Test SAMPLE with window function (should work correctly) -query TT -SELECT id, name, value, - ROW_NUMBER() OVER (ORDER BY value) as rn -FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) -ORDER BY id ----- -1 Alice 10.5 1 -3 Charlie 15.7 2 -5 Eve 30.2 3 -7 Grace 18.9 4 -9 Ivy 28.6 5 - -# Test SAMPLE with aggregation and REPEATABLE seed -query TT -SELECT MIN(value) as min_val, MAX(value) as max_val, AVG(value) as avg_val -FROM sample_test SAMPLE (60 PERCENT) REPEATABLE (42) ----- -10.5 30.2 20.7 - -# Test SAMPLE with different percentage values -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (25 PERCENT) REPEATABLE (42) ----- -3 - -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (75 PERCENT) REPEATABLE (42) ----- -8 - -# Test SAMPLE with decimal percentage -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (33.33 PERCENT) REPEATABLE (42) ----- -3 - -# Test SAMPLE with large percentage -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (90 PERCENT) REPEATABLE (42) ----- 9 +10 +14 +17 +19 +24 +39 -# Test SAMPLE with small percentage -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (10 PERCENT) REPEATABLE (42) ----- -1 - -# Test SAMPLE with REPEATABLE seed in different contexts -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (999) ----- -5 - -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (1) ----- -4 - -# Test SAMPLE with REPEATABLE seed and complex WHERE conditions -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (50 PERCENT) REPEATABLE (42) -WHERE value BETWEEN 15 AND 25 AND name LIKE '%e%' +query I +SELECT COUNT(DISTINCT id) FROM (SELECT id FROM sample_test SAMPLE (100 PERCENT)); ---- -1 +40 -# Test SAMPLE with REPEATABLE seed and multiple conditions -query TT -SELECT COUNT(*) as count FROM sample_test SAMPLE (60 PERCENT) REPEATABLE (42) -WHERE id % 2 = 0 OR value > 20 +# Test SAMPLE with 0% ratio and REPEATABLE seed (should return no rows) +query I +SELECT COUNT(DISTINCT id) FROM (SELECT id FROM sample_test SAMPLE (0 PERCENT)); ---- -4 +0 # Clean up statement ok From 0ef07807c09d59b4f8835217a942f8d99e325b2d Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 23 Jun 2025 15:42:33 +0800 Subject: [PATCH 05/10] format --- datafusion/core/src/dataframe/mod.rs | 13 +- datafusion/core/src/physical_planner.rs | 8 +- datafusion/core/tests/dataframe/mod.rs | 40 +----- datafusion/expr/src/logical_plan/builder.rs | 12 +- datafusion/expr/src/logical_plan/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 11 +- datafusion/expr/src/logical_plan/tree_node.rs | 18 ++- datafusion/physical-plan/src/lib.rs | 2 +- datafusion/physical-plan/src/sample.rs | 87 +++++++---- datafusion/proto/src/logical_plan/mod.rs | 6 +- datafusion/proto/src/physical_plan/mod.rs | 29 ++-- datafusion/sql/src/relation/mod.rs | 135 +++++++++++------- 12 files changed, 221 insertions(+), 144 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2fd6de703f06..856b775ddf2a 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2284,9 +2284,16 @@ impl DataFrame { /// df.show().await?; /// # Ok(()) /// ``` - /// - pub fn sample(self, fraction: f64, with_replacement: Option, seed: Option) -> Result { - let plan = LogicalPlanBuilder::from(self.plan).sample(fraction, with_replacement, seed)?.build()?; + /// + pub fn sample( + self, + fraction: f64, + with_replacement: Option, + seed: Option, + ) -> Result { + let plan = LogicalPlanBuilder::from(self.plan) + .sample(fraction, with_replacement, seed)? + .build()?; Ok(DataFrame { session_state: self.session_state, plan, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index f2863f8d1a50..c7490063a080 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -878,7 +878,13 @@ impl DefaultPhysicalPlanner { .. }) => { let input = children.one()?; - let sample = SampleExec::try_new(input, *lower_bound, *upper_bound, *with_replacement, *seed)?; + let sample = SampleExec::try_new( + input, + *lower_bound, + *upper_bound, + *with_replacement, + *seed, + )?; Arc::new(sample) } LogicalPlan::Unnest(Unnest { diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 2bb0e9f84c59..03ccb09092da 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -6148,16 +6148,8 @@ async fn test_dataframe_sample() -> Result<()> { let df_sampled = df.clone().sample(0.2, Some(true), Some(42))?; assert_batches_eq!( &[ - "+----+", - "| a |", - "+----+", - "| 8 |", - "| 10 |", - "| 19 |", - "| 29 |", - "| 29 |", - "| 36 |", - "+----+", + "+----+", "| a |", "+----+", "| 8 |", "| 10 |", "| 19 |", "| 29 |", + "| 29 |", "| 36 |", "+----+", ], &df_sampled.collect().await? ); @@ -6166,18 +6158,8 @@ async fn test_dataframe_sample() -> Result<()> { let df_sampled = df.clone().sample(0.2, Some(false), Some(42))?; assert_batches_eq!( &[ - "+----+", - "| a |", - "+----+", - "| 5 |", - "| 9 |", - "| 10 |", - "| 14 |", - "| 17 |", - "| 19 |", - "| 24 |", - "| 39 |", - "+----+", + "+----+", "| a |", "+----+", "| 5 |", "| 9 |", "| 10 |", "| 14 |", + "| 17 |", "| 19 |", "| 24 |", "| 39 |", "+----+", ], &df_sampled.collect().await? ); @@ -6186,18 +6168,8 @@ async fn test_dataframe_sample() -> Result<()> { let df_sampled_default = df.clone().sample(0.2, None, Some(42))?; assert_batches_eq!( &[ - "+----+", - "| a |", - "+----+", - "| 5 |", - "| 9 |", - "| 10 |", - "| 14 |", - "| 17 |", - "| 19 |", - "| 24 |", - "| 39 |", - "+----+", + "+----+", "| a |", "+----+", "| 5 |", "| 9 |", "| 10 |", "| 14 |", + "| 17 |", "| 19 |", "| 24 |", "| 39 |", "+----+", ], &df_sampled_default.collect().await? ); diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 1eb313a2bcee..2e02216ae526 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -44,7 +44,8 @@ use crate::utils::{ group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, WriteOp + and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, + Statement, TableProviderFilterPushDown, TableSource, WriteOp, }; use super::dml::InsertOp; @@ -1475,13 +1476,18 @@ impl LogicalPlanBuilder { .map(Self::new) } - pub fn sample(self, fraction: f64, with_replacement: Option, seed: Option) -> Result { + pub fn sample( + self, + fraction: f64, + with_replacement: Option, + seed: Option, + ) -> Result { Ok(Self::new(LogicalPlan::Sample(Sample { input: self.plan, lower_bound: 0.0, upper_bound: fraction, with_replacement: with_replacement.unwrap_or(false), - seed: seed.unwrap_or_else(|| rand::random()), + seed: seed.unwrap_or_else(rand::random), }))) } } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index fa276293fa79..4f1b08b5863e 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -40,8 +40,8 @@ pub use plan::{ projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, ExplainFormat, Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, - Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, Sample, + Projection, RecursiveQuery, Repartition, Sample, SkipType, Sort, StringifiedPlan, + Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ Deallocate, Execute, Prepare, SetVariable, Statement, TransactionAccessMode, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4fa6bfae3962..43c993e26e89 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -744,7 +744,7 @@ impl LogicalPlan { LogicalPlan::EmptyRelation(_) => Ok(self), LogicalPlan::Statement(_) => Ok(self), LogicalPlan::DescribeTable(_) => Ok(self), - LogicalPlan::Sample(Sample {..}) => Ok(self), + LogicalPlan::Sample(Sample { .. }) => Ok(self), LogicalPlan::Unnest(Unnest { input, exec_columns, @@ -898,7 +898,13 @@ impl LogicalPlan { fetch: *fetch, })) } - LogicalPlan::Sample(Sample { with_replacement, seed, lower_bound, upper_bound, .. }) => { + LogicalPlan::Sample(Sample { + with_replacement, + seed, + lower_bound, + upper_bound, + .. + }) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; @@ -4050,7 +4056,6 @@ impl Hash for Sample { } } - #[cfg(test)] mod tests { diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index b6f4708bc906..cb33c620ccc2 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -149,9 +149,21 @@ impl TreeNode for LogicalPlan { LogicalPlan::Limit(Limit { skip, fetch, input }) => input .map_elements(f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), - LogicalPlan::Sample(Sample { input, lower_bound, upper_bound, with_replacement, seed }) => input - .map_elements(f)? - .update_data(|input| LogicalPlan::Sample(Sample { input, lower_bound, upper_bound, with_replacement, seed })), + LogicalPlan::Sample(Sample { + input, + lower_bound, + upper_bound, + with_replacement, + seed, + }) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::Sample(Sample { + input, + lower_bound, + upper_bound, + with_replacement, + seed, + }) + }), LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index fb696028a785..6af67040dfce 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -79,8 +79,8 @@ pub mod placeholder_row; pub mod projection; pub mod recursive_query; pub mod repartition; -pub mod sorts; pub mod sample; +pub mod sorts; pub mod spill; pub mod stream; pub mod streaming; diff --git a/datafusion/physical-plan/src/sample.rs b/datafusion/physical-plan/src/sample.rs index 2c2cdc150d7b..ade60a2edcd2 100644 --- a/datafusion/physical-plan/src/sample.rs +++ b/datafusion/physical-plan/src/sample.rs @@ -17,11 +17,11 @@ //! Defines the SAMPLE operator +use rand_distr::{Distribution, Poisson}; use std::any::Any; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use rand_distr::{Distribution, Poisson}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, @@ -33,16 +33,16 @@ use crate::{ }; use arrow::array::UInt32Array; +use arrow::compute; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use arrow::compute; use datafusion_common::{internal_err, plan_datafusion_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; use futures::stream::{Stream, StreamExt}; -use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; trait Sampler: Send + Sync { fn sample(&mut self, batch: &RecordBatch) -> Result; @@ -56,13 +56,16 @@ struct BernoulliSampler { impl BernoulliSampler { fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self { - Self { lower_bound, upper_bound, rng: StdRng::seed_from_u64(seed) } + Self { + lower_bound, + upper_bound, + rng: StdRng::seed_from_u64(seed), + } } } impl Sampler for BernoulliSampler { fn sample(&mut self, batch: &RecordBatch) -> Result { - if self.upper_bound <= self.lower_bound { return Ok(RecordBatch::new_empty(batch.schema())); } @@ -94,7 +97,11 @@ struct PoissonSampler { impl PoissonSampler { fn try_new(ratio: f64, seed: u64) -> Result { let poisson = Poisson::new(ratio).map_err(|e| plan_datafusion_err!("{}", e))?; - Ok(Self { ratio, poisson, rng: StdRng::seed_from_u64(seed) }) + Ok(Self { + ratio, + poisson, + rng: StdRng::seed_from_u64(seed), + }) } } @@ -103,9 +110,9 @@ impl Sampler for PoissonSampler { if self.ratio <= 0.0 { return Ok(RecordBatch::new_empty(batch.schema())); } - + let mut indices = Vec::new(); - + for i in 0..batch.num_rows() { let k = self.poisson.sample(&mut self.rng) as i32; for _ in 0..k { @@ -173,9 +180,16 @@ impl SampleExec { fn create_sampler(&self, partition: usize) -> Result> { if self.with_replacement { - Ok(Box::new(PoissonSampler::try_new(self.upper_bound - self.lower_bound, self.seed + partition as u64)?)) + Ok(Box::new(PoissonSampler::try_new( + self.upper_bound - self.lower_bound, + self.seed + partition as u64, + )?)) } else { - Ok(Box::new(BernoulliSampler::new(self.lower_bound, self.upper_bound, self.seed + partition as u64))) + Ok(Box::new(BernoulliSampler::new( + self.lower_bound, + self.upper_bound, + self.seed + partition as u64, + ))) } } @@ -267,7 +281,7 @@ impl ExecutionPlan for SampleExec { ) -> Result> { // Get the ratio from the current method Ok(Arc::new(SampleExec::try_new( - children[0].clone(), + Arc::clone(&children[0]), self.lower_bound, self.upper_bound, self.with_replacement, @@ -296,14 +310,20 @@ impl ExecutionPlan for SampleExec { fn partition_statistics(&self, partition: Option) -> Result { let input_stats = self.input.partition_statistics(partition)?; - + // Apply sampling ratio to statistics let mut stats = input_stats; let ratio = self.upper_bound - self.lower_bound; - - stats.num_rows = stats.num_rows.map(|nr| (nr as f64 * ratio) as usize).to_inexact(); - stats.total_byte_size = stats.total_byte_size.map(|tb| (tb as f64 * ratio) as usize).to_inexact(); - + + stats.num_rows = stats + .num_rows + .map(|nr| (nr as f64 * ratio) as usize) + .to_inexact(); + stats.total_byte_size = stats + .total_byte_size + .map(|tb| (tb as f64 * ratio) as usize) + .to_inexact(); + Ok(stats) } } @@ -379,25 +399,27 @@ mod tests { )?); let sample_exec = SampleExec::try_new(input, 0.6, 1.0, false, 42)?; - + let context = Arc::new(TaskContext::default()); let stream = sample_exec.execute(0, context)?; - + let batches = stream.try_collect::>().await?; assert!(!batches.is_empty()); - + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); // With 60% sampling ratio and 5 input rows, we expect around 3 rows assert!(total_rows >= 2 && total_rows <= 4); - + Ok(()) } #[tokio::test] async fn test_sample_exec_poisson() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("id", arrow::datatypes::DataType::Int32, false), - ])); + let schema = Arc::new(Schema::new(vec![Field::new( + "id", + arrow::datatypes::DataType::Int32, + false, + )])); let batch = RecordBatch::try_new( schema.clone(), @@ -411,24 +433,29 @@ mod tests { )?); let sample_exec = SampleExec::try_new(input, 0.0, 0.5, true, 42)?; - + let context = Arc::new(TaskContext::default()); let stream = sample_exec.execute(0, context)?; - + let batches = stream.try_collect::>().await?; assert!(!batches.is_empty()); - + Ok(()) } - + #[test] fn test_sampler_trait() { let mut bernoulli_sampler = BernoulliSampler::new(0.0, 0.5, 42); let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![Field::new("id", arrow::datatypes::DataType::Int32, false)])), + Arc::new(Schema::new(vec![Field::new( + "id", + arrow::datatypes::DataType::Int32, + false, + )])), vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], - ).unwrap(); - + ) + .unwrap(); + let result = bernoulli_sampler.sample(&batch); assert!(result.is_ok()); diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index b7bcdcbd0408..bacfe482bff7 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -71,7 +71,8 @@ use datafusion_expr::{ Statement, WindowUDF, }; use datafusion_expr::{ - AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, Sample, SkipType, TableSource, Unnest + AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, Sample, + SkipType, TableSource, Unnest, }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -1821,7 +1822,8 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec, )?; Ok(LogicalPlanNode { - logical_plan_type: Some(LogicalPlanType::Sample(Box::new(protobuf::SampleNode { + logical_plan_type: Some(LogicalPlanType::Sample(Box::new( + protobuf::SampleNode { input: Some(Box::new(input)), lower_bound: sample.lower_bound, upper_bound: sample.upper_bound, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 7ef26df329ad..283de36102fc 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -84,7 +84,7 @@ use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - ExecutionPlan, InputOrderMode, PhysicalExpr, SampleExec, WindowExpr + ExecutionPlan, InputOrderMode, PhysicalExpr, SampleExec, WindowExpr, }; use datafusion_common::config::TableParquetOptions; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; @@ -576,8 +576,15 @@ impl protobuf::PhysicalPlanNode { runtime: &RuntimeEnv, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input = into_physical_plan(&sample.input, registry, runtime, extension_codec)?; - Ok(Arc::new(SampleExec::try_new(input, sample.lower_bound, sample.upper_bound, sample.with_replacement, sample.seed)?)) + let input = + into_physical_plan(&sample.input, registry, runtime, extension_codec)?; + Ok(Arc::new(SampleExec::try_new( + input, + sample.lower_bound, + sample.upper_bound, + sample.with_replacement, + sample.seed, + )?)) } fn try_into_explain_physical_plan( @@ -2839,13 +2846,15 @@ impl protobuf::PhysicalPlanNode { extension_codec, )?; Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Sample(Box::new(protobuf::SampleExecNode { - input: Some(Box::new(input)), - lower_bound: exec.lower_bound(), - upper_bound: exec.upper_bound(), - with_replacement: exec.with_replacement(), - seed: exec.seed(), - }))), + physical_plan_type: Some(PhysicalPlanType::Sample(Box::new( + protobuf::SampleExecNode { + input: Some(Box::new(input)), + lower_bound: exec.lower_bound(), + upper_bound: exec.upper_bound(), + with_replacement: exec.with_replacement(), + seed: exec.seed(), + }, + ))), }) } } diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index dea1f228887d..7e4818daefb9 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -22,14 +22,13 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - not_impl_err, plan_err, DFSchema, Diagnostic, Result, Span, Spans, TableReference + not_impl_err, plan_err, DFSchema, Diagnostic, Result, Span, Spans, TableReference, }; use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::{Subquery, SubqueryAlias}; use sqlparser::ast::{ - FunctionArg, FunctionArgExpr, Spanned, TableFactor, TableSampleKind, - TableSampleUnit, + FunctionArg, FunctionArgExpr, Spanned, TableFactor, TableSampleKind, TableSampleUnit, }; mod join; @@ -44,7 +43,11 @@ impl SqlToRel<'_, S> { let relation_span = relation.span(); let (plan, alias) = match relation { TableFactor::Table { - name, alias, args, sample, .. + name, + alias, + args, + sample, + .. } => { if let Some(func_args) = args { let tbl_func_name = @@ -90,12 +93,10 @@ impl SqlToRel<'_, S> { self.context_provider.get_table_source(table_ref.clone()), ) { (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Ok(provider)) => LogicalPlanBuilder::scan( - table_ref.clone(), - provider, - None, - )? - .build(), + (_, Ok(provider)) => { + LogicalPlanBuilder::scan(table_ref.clone(), provider, None)? + .build() + } (None, Err(e)) => { let e = e.with_diagnostic(Diagnostic::new_error( format!("table '{table_ref}' not found"), @@ -252,19 +253,22 @@ impl SqlToRel<'_, S> { return not_impl_err!("Offset sample is not supported yet"); } - let seed = sample.seed.map(|seed| { - let Ok(seed) = seed.value.to_string().parse::() else { - return plan_err!("seed must be a number: {}", seed.value); - }; - Ok(seed) - }).transpose()?; + let seed = sample + .seed + .map(|seed| { + let Ok(seed) = seed.value.to_string().parse::() else { + return plan_err!("seed must be a number: {}", seed.value); + }; + Ok(seed) + }) + .transpose()?; if let Some(bucket) = sample.bucket { if bucket.on.is_some() { // Hive-style sample, only used when the Hive table is defined with CLUSTERED BY return not_impl_err!("Bucket sample with ON is not supported yet"); } - + let Ok(bucket_num) = bucket.bucket.to_string().parse::() else { return plan_err!("bucket must be a number"); }; @@ -272,7 +276,9 @@ impl SqlToRel<'_, S> { let Ok(total_num) = bucket.total.to_string().parse::() else { return plan_err!("total must be a number"); }; - let logical_plan = LogicalPlanBuilder::from(input).sample(bucket_num as f64 / total_num as f64, None, seed)?.build()?; + let logical_plan = LogicalPlanBuilder::from(input) + .sample(bucket_num as f64 / total_num as f64, None, seed)? + .build()?; return Ok(logical_plan); } if let Some(quantity) = sample.quantity { @@ -280,42 +286,64 @@ impl SqlToRel<'_, S> { Some(TableSampleUnit::Rows) => { let value = evaluate_number::(&quantity.value); if value.is_none() { - return plan_err!("quantity must be a non-negative number: {:?}", quantity.value); + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); } let value = value.unwrap(); if value < 0 { - return plan_err!("quantity must be a non-negative number: {:?}", quantity.value); + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); } - let logical_plan = LogicalPlanBuilder::from(input).limit(0, Some(value as usize))?.build()?; + let logical_plan = LogicalPlanBuilder::from(input) + .limit(0, Some(value as usize))? + .build()?; return Ok(logical_plan); } Some(TableSampleUnit::Percent) => { let value = evaluate_number::(&quantity.value); if value.is_none() { - return plan_err!("quantity must be a number: {:?}", quantity.value); + return plan_err!( + "quantity must be a number: {:?}", + quantity.value + ); } let value = value.unwrap() / 100.0; - let logical_plan = LogicalPlanBuilder::from(input).sample(value, None, seed)?.build()?; + let logical_plan = LogicalPlanBuilder::from(input) + .sample(value, None, seed)? + .build()?; return Ok(logical_plan); } None => { // Clickhouse-style sample let value = evaluate_number::(&quantity.value); if value.is_none() { - return plan_err!("quantity must be a non-negative number: {:?}", quantity.value); + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); } let value = value.unwrap(); if value < 0.0 { - return plan_err!("quantity must be a non-negative number: {:?}", quantity.value); + return plan_err!( + "quantity must be a non-negative number: {:?}", + quantity.value + ); } if value >= 1.0 { - let logical_plan = LogicalPlanBuilder::from(input).limit(0, Some(value as usize))?.build()?; + let logical_plan = LogicalPlanBuilder::from(input) + .limit(0, Some(value as usize))? + .build()?; return Ok(logical_plan); } else { - let logical_plan = LogicalPlanBuilder::from(input).sample(value, None, seed)?.build()?; + let logical_plan = LogicalPlanBuilder::from(input) + .sample(value, None, seed)? + .build()?; return Ok(logical_plan); } - } } } @@ -350,37 +378,40 @@ fn optimize_subquery_sort(plan: LogicalPlan) -> Result> new_plan } - -fn evaluate_number + std::ops::Sub + std::ops::Mul + std::ops::Div>(expr: &sqlparser::ast::Expr) -> Option { +fn evaluate_number< + T: FromStr + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div, +>( + expr: &sqlparser::ast::Expr, +) -> Option { match expr { sqlparser::ast::Expr::BinaryOp { left, op, right } => { - let left = evaluate_number::(&left); - let right = evaluate_number::(&right); + let left = evaluate_number::(left); + let right = evaluate_number::(right); match (left, right) { - (Some(left), Some(right)) => { - match op { - sqlparser::ast::BinaryOperator::Plus => Some(left + right), - sqlparser::ast::BinaryOperator::Minus => Some(left - right), - sqlparser::ast::BinaryOperator::Multiply => Some(left * right), - sqlparser::ast::BinaryOperator::Divide => Some(left / right), - _ => None, - } - } + (Some(left), Some(right)) => match op { + sqlparser::ast::BinaryOperator::Plus => Some(left + right), + sqlparser::ast::BinaryOperator::Minus => Some(left - right), + sqlparser::ast::BinaryOperator::Multiply => Some(left * right), + sqlparser::ast::BinaryOperator::Divide => Some(left / right), + _ => None, + }, _ => None, } } - sqlparser::ast::Expr::Value(value) => { - match &value.value { - sqlparser::ast::Value::Number(value, _) => { - let value = format!("{value}"); - let Ok(value) = value.parse::() else { - return None; - }; - Some(value) - } - _ => None, + sqlparser::ast::Expr::Value(value) => match &value.value { + sqlparser::ast::Value::Number(value, _) => { + let value = value.to_string(); + let Ok(value) = value.parse::() else { + return None; + }; + Some(value) } - } + _ => None, + }, _ => None, } } From 4c1aa20e086ff856b76f0f09bc6f4d0a0c230294 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 23 Jun 2025 16:33:16 +0800 Subject: [PATCH 06/10] update --- datafusion/core/src/dataframe/mod.rs | 11 ++--- datafusion/physical-plan/src/sample.rs | 62 +++++++++++++++++++------- 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 856b775ddf2a..9423ad85d6d6 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2272,17 +2272,18 @@ impl DataFrame { /// /// # Example /// ``` - /// # use datafusion::prelude::*; - /// # use datafusion::error::Result; - /// # #[tokio::main] - /// # async fn main() -> Result<()> { + /// use datafusion::prelude::*; /// let df = dataframe!( /// "id" => [1, 2, 3], /// "name" => ["foo", "bar", "baz"] /// )?; /// let df = df.sample(0.5, Some(true), Some(42))?; /// df.show().await?; - /// # Ok(()) + /// // +----+------+ + /// // | id | name | + /// // +----+------+ + /// // | 3 | baz | + /// // +----+------+ /// ``` /// pub fn sample( diff --git a/datafusion/physical-plan/src/sample.rs b/datafusion/physical-plan/src/sample.rs index ade60a2edcd2..6bdf52890097 100644 --- a/datafusion/physical-plan/src/sample.rs +++ b/datafusion/physical-plan/src/sample.rs @@ -371,8 +371,9 @@ impl RecordBatchStream for SampleExecStream { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Int32Array, StringArray}; + use arrow::array::Int32Array; use arrow::datatypes::{Field, Schema}; + use datafusion_common::assert_batches_eq; use datafusion_execution::TaskContext; use futures::TryStreamExt; use std::sync::Arc; @@ -381,14 +382,12 @@ mod tests { async fn test_sample_exec_bernoulli() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("id", arrow::datatypes::DataType::Int32, false), - Field::new("name", arrow::datatypes::DataType::Utf8, false), ])); let batch = RecordBatch::try_new( schema.clone(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), - Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])), ], )?; @@ -404,11 +403,16 @@ mod tests { let stream = sample_exec.execute(0, context)?; let batches = stream.try_collect::>().await?; - assert!(!batches.is_empty()); - - let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - // With 60% sampling ratio and 5 input rows, we expect around 3 rows - assert!(total_rows >= 2 && total_rows <= 4); + assert_batches_eq!( + &[ + "+----+", + "| id |", + "+----+", + "| 3 |", + "+----+", + ], + &batches + ); Ok(()) } @@ -422,13 +426,13 @@ mod tests { )])); let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], )?; let input = Arc::new(crate::test::TestMemoryExec::try_new( &[vec![batch]], - schema.clone(), + Arc::clone(&schema), None, )?); @@ -438,7 +442,16 @@ mod tests { let stream = sample_exec.execute(0, context)?; let batches = stream.try_collect::>().await?; - assert!(!batches.is_empty()); + assert_batches_eq!( + &[ + "+----+", + "| id |", + "+----+", + "| 3 |", + "+----+", + ], + &batches + ); Ok(()) } @@ -456,11 +469,30 @@ mod tests { ) .unwrap(); - let result = bernoulli_sampler.sample(&batch); - assert!(result.is_ok()); + let result = bernoulli_sampler.sample(&batch).unwrap(); + assert_batches_eq!( + &[ + "+----+", + "| id |", + "+----+", + "| 4 |", + "| 5 |", + "+----+", + ], + &vec![result] + ); let mut poisson_sampler = PoissonSampler::try_new(0.5, 42).unwrap(); - let result = poisson_sampler.sample(&batch); - assert!(result.is_ok()); + let result = poisson_sampler.sample(&batch).unwrap(); + assert_batches_eq!( + &[ + "+----+", + "| id |", + "+----+", + "| 3 |", + "+----+", + ], + &vec![result] + ); } } From fbcb96dc3eeac7d39965f6c189913f7638e704f8 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 23 Jun 2025 16:48:11 +0800 Subject: [PATCH 07/10] fmt --- datafusion/physical-plan/src/sample.rs | 27 ++++++++++++-------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-plan/src/sample.rs b/datafusion/physical-plan/src/sample.rs index 6bdf52890097..67d03b01231c 100644 --- a/datafusion/physical-plan/src/sample.rs +++ b/datafusion/physical-plan/src/sample.rs @@ -380,15 +380,15 @@ mod tests { #[tokio::test] async fn test_sample_exec_bernoulli() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("id", arrow::datatypes::DataType::Int32, false), - ])); + let schema = Arc::new(Schema::new(vec![Field::new( + "id", + arrow::datatypes::DataType::Int32, + false, + )])); let batch = RecordBatch::try_new( schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), - ], + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], )?; let input = Arc::new(crate::test::TestMemoryExec::try_new( @@ -404,13 +404,7 @@ mod tests { let batches = stream.try_collect::>().await?; assert_batches_eq!( - &[ - "+----+", - "| id |", - "+----+", - "| 3 |", - "+----+", - ], + &["+----+", "| id |", "+----+", "| 3 |", "+----+",], &batches ); @@ -443,6 +437,7 @@ mod tests { let batches = stream.try_collect::>().await?; assert_batches_eq!( + #[rustfmt::skip] &[ "+----+", "| id |", @@ -471,6 +466,7 @@ mod tests { let result = bernoulli_sampler.sample(&batch).unwrap(); assert_batches_eq!( + #[rustfmt::skip] &[ "+----+", "| id |", @@ -479,12 +475,13 @@ mod tests { "| 5 |", "+----+", ], - &vec![result] + &[result] ); let mut poisson_sampler = PoissonSampler::try_new(0.5, 42).unwrap(); let result = poisson_sampler.sample(&batch).unwrap(); assert_batches_eq!( + #[rustfmt::skip] &[ "+----+", "| id |", @@ -492,7 +489,7 @@ mod tests { "| 3 |", "+----+", ], - &vec![result] + &[result] ); } } From a2ff496257b2e2e62e02441351c118251c8ecee2 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 23 Jun 2025 17:03:38 +0800 Subject: [PATCH 08/10] fmt --- datafusion/physical-plan/src/sample.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/sample.rs b/datafusion/physical-plan/src/sample.rs index 67d03b01231c..e9c63a957407 100644 --- a/datafusion/physical-plan/src/sample.rs +++ b/datafusion/physical-plan/src/sample.rs @@ -387,13 +387,13 @@ mod tests { )])); let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], )?; let input = Arc::new(crate::test::TestMemoryExec::try_new( &[vec![batch]], - schema.clone(), + Arc::clone(&schema), None, )?); From 65646515e7275dc0b164bfa3a2c55f11afec42c7 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 23 Jun 2025 17:38:51 +0800 Subject: [PATCH 09/10] fmt --- datafusion/core/src/dataframe/mod.rs | 5 +++ datafusion/core/tests/dataframe/mod.rs | 43 ++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 9423ad85d6d6..6ab03b468d28 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2273,12 +2273,17 @@ impl DataFrame { /// # Example /// ``` /// use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { /// let df = dataframe!( /// "id" => [1, 2, 3], /// "name" => ["foo", "bar", "baz"] /// )?; /// let df = df.sample(0.5, Some(true), Some(42))?; /// df.show().await?; + /// # Ok(()) + /// # } /// // +----+------+ /// // | id | name | /// // +----+------+ diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 03ccb09092da..a548675bb96f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -6147,9 +6147,18 @@ async fn test_dataframe_sample() -> Result<()> { // Test sampling 20% of rows with replacement let df_sampled = df.clone().sample(0.2, Some(true), Some(42))?; assert_batches_eq!( + #[rustfmt::skip] &[ - "+----+", "| a |", "+----+", "| 8 |", "| 10 |", "| 19 |", "| 29 |", - "| 29 |", "| 36 |", "+----+", + "+----+", + "| a |", + "+----+", + "| 8 |", + "| 10 |", + "| 19 |", + "| 29 |", + "| 29 |", + "| 36 |", + "+----+", ], &df_sampled.collect().await? ); @@ -6157,9 +6166,20 @@ async fn test_dataframe_sample() -> Result<()> { // Test sampling 20% of rows without replacement let df_sampled = df.clone().sample(0.2, Some(false), Some(42))?; assert_batches_eq!( + #[rustfmt::skip] &[ - "+----+", "| a |", "+----+", "| 5 |", "| 9 |", "| 10 |", "| 14 |", - "| 17 |", "| 19 |", "| 24 |", "| 39 |", "+----+", + "+----+", + "| a |", + "+----+", + "| 5 |", + "| 9 |", + "| 10 |", + "| 14 |", + "| 17 |", + "| 19 |", + "| 24 |", + "| 39 |", + "+----+", ], &df_sampled.collect().await? ); @@ -6167,9 +6187,20 @@ async fn test_dataframe_sample() -> Result<()> { // Test sampling with None parameters (should use defaults) let df_sampled_default = df.clone().sample(0.2, None, Some(42))?; assert_batches_eq!( + #[rustfmt::skip] &[ - "+----+", "| a |", "+----+", "| 5 |", "| 9 |", "| 10 |", "| 14 |", - "| 17 |", "| 19 |", "| 24 |", "| 39 |", "+----+", + "+----+", + "| a |", + "+----+", + "| 5 |", + "| 9 |", + "| 10 |", + "| 14 |", + "| 17 |", + "| 19 |", + "| 24 |", + "| 39 |", + "+----+", ], &df_sampled_default.collect().await? ); From 1fa72f429b8a1c4309fe764385f1b5cf16b6d61c Mon Sep 17 00:00:00 2001 From: Chen Chongchen Date: Tue, 24 Jun 2025 09:51:55 +0800 Subject: [PATCH 10/10] Update mod.rs --- datafusion/sql/src/relation/mod.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 7e4818daefb9..6d5a491864de 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -28,7 +28,8 @@ use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::{Subquery, SubqueryAlias}; use sqlparser::ast::{ - FunctionArg, FunctionArgExpr, Spanned, TableFactor, TableSampleKind, TableSampleUnit, + FunctionArg, FunctionArgExpr, Spanned, TableFactor, TableSampleKind, + TableSampleMethod, TableSampleUnit, }; mod join; @@ -244,9 +245,11 @@ impl SqlToRel<'_, S> { TableSampleKind::BeforeTableAlias(sample) => sample, TableSampleKind::AfterTableAlias(sample) => sample, }; - if sample.name.is_some() { - // Postgres-style sample. Not supported because DataFusion does not have a concept of pages like PostgreSQL. - return not_impl_err!("{} is not supported yet", sample.name.unwrap()); + if let Some(name) = &sample.name { + if *name != TableSampleMethod::Bernoulli && *name != TableSampleMethod::Row { + // Postgres-style sample. Not supported because DataFusion does not have a concept of pages like PostgreSQL. + return not_impl_err!("{} is not supported yet", name); + } } if sample.offset.is_some() { // Clickhouse-style sample. Not supported because it requires knowing the total data size.