From 590a401bc0eeb7a0f4c89d90c07f66e8c7fc0bdd Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 27 May 2025 21:49:58 +0200 Subject: [PATCH 01/10] First pass at incorporating boxed WindowFunction --- datafusion/expr/src/expr.rs | 53 +++++++++++-------- datafusion/expr/src/expr_fn.rs | 10 ++-- datafusion/expr/src/logical_plan/plan.rs | 22 ++++---- datafusion/expr/src/tree_node.rs | 37 ++++++------- datafusion/expr/src/udwf.rs | 2 +- datafusion/expr/src/utils.rs | 3 +- datafusion/functions-aggregate/src/count.rs | 2 +- datafusion/functions-window/src/planner.rs | 15 +++--- .../optimizer/src/analyzer/type_coercion.rs | 25 ++++----- datafusion/sql/src/expr/function.rs | 2 +- datafusion/sql/src/unparser/expr.rs | 29 +++++----- datafusion/sql/src/utils.rs | 20 ++++--- 12 files changed, 123 insertions(+), 97 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index a081a5430d40..600a6a9ea70e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -330,7 +330,7 @@ pub enum Expr { /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), /// Call a window function with a set of arguments. - WindowFunction(WindowFunction), + WindowFunction(Box), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -378,6 +378,13 @@ impl From for Expr { } } +/// Create an [`Expr`] from a [`WindowFunction`] +impl From for Expr { + fn from(value: WindowFunction) -> Self { + Expr::WindowFunction(Box::new(value)) + } +} + /// Create an [`Expr`] from an optional qualifier and a [`FieldRef`]. This is /// useful for creating [`Expr`] from a [`DFSchema`]. /// @@ -2086,32 +2093,27 @@ impl NormalizeEq for Expr { _ => false, } } - ( - Expr::WindowFunction(WindowFunction { + (Expr::WindowFunction(left), Expr::WindowFunction(other)) => { + let WindowFunction { fun: self_fun, - params: self_params, - }), - Expr::WindowFunction(WindowFunction { - fun: other_fun, - params: other_params, - }), - ) => { - let ( - WindowFunctionParams { + params: WindowFunctionParams { args: self_args, window_frame: self_window_frame, partition_by: self_partition_by, order_by: self_order_by, null_treatment: self_null_treatment, - }, - WindowFunctionParams { + } + } = left.as_ref(); + let WindowFunction { + fun: other_fun, + params: WindowFunctionParams { args: other_args, window_frame: other_window_frame, partition_by: other_partition_by, order_by: other_order_by, null_treatment: other_null_treatment, - }, - ) = (self_params, other_params); + } + } = other.as_ref(); self_fun.name() == other_fun.name() && self_window_frame == other_window_frame @@ -2356,14 +2358,17 @@ impl HashNode for Expr { distinct.hash(state); null_treatment.hash(state); } - Expr::WindowFunction(WindowFunction { fun, params }) => { - let WindowFunctionParams { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: WindowFunctionParams { args: _args, partition_by: _, order_by: _, window_frame, null_treatment, - } = params; + } + } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); null_treatment.hash(state); @@ -2646,7 +2651,9 @@ impl Display for SchemaDisplay<'_> { Ok(()) } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { WindowFunctionDefinition::AggregateUDF(fun) => { match fun.window_function_schema_name(params) { Ok(name) => { @@ -2691,6 +2698,7 @@ impl Display for SchemaDisplay<'_> { write!(f, " {window_frame}") } + } }, } } @@ -3026,7 +3034,9 @@ impl Display for Expr { // Expr::ScalarFunction(ScalarFunction { func, args }) => { // write!(f, "{}", func.display_name(args).unwrap()) // } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { WindowFunctionDefinition::AggregateUDF(fun) => { match fun.window_function_display_name(params) { Ok(name) => { @@ -3066,6 +3076,7 @@ impl Display for Expr { window_frame.end_bound ) } + } }, Expr::AggregateFunction(AggregateFunction { func, params }) => { match func.display_name(params) { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index cee356a2b42c..11eefe6eaa5b 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -831,7 +831,7 @@ impl ExprFuncBuilder { params: WindowFunctionParams { args, .. }, }) => { let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun, params: WindowFunctionParams { args, @@ -895,7 +895,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -935,7 +935,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -948,7 +948,7 @@ impl ExprFunctionExt for Expr { fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.partition_by = Some(partition_by); builder } @@ -959,7 +959,7 @@ impl ExprFunctionExt for Expr { fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.window_frame = Some(window_frame); builder } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 2fe1c0d7398f..1e067eea3534 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2422,18 +2422,22 @@ impl Window { .iter() .enumerate() .filter_map(|(idx, expr)| { - if let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = expr else { + return None; + }; + let WindowFunction { fun: WindowFunctionDefinition::WindowUDF(udwf), params: WindowFunctionParams { partition_by, .. }, - }) = expr - { - // When there is no PARTITION BY, row number will be unique - // across the entire table. - if udwf.name() == "row_number" && partition_by.is_empty() { - return Some(idx + input_len); - } + } = window_fun.as_ref() else { + return None; + }; + // When there is no PARTITION BY, row number will be unique + // across the entire table. + if udwf.name() == "row_number" && partition_by.is_empty() { + return Some(idx + input_len); + } else { + None } - None }) .map(|idx| { FunctionalDependence::new(vec![idx], vec![], false) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f20dab7e165f..c0247288aee9 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -92,14 +92,16 @@ impl TreeNode for Expr { (expr, when_then_expr, else_expr).apply_ref_elements(f), Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) => (args, filter, order_by).apply_ref_elements(f), - Expr::WindowFunction(WindowFunction { - params : WindowFunctionParams { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { args, partition_by, order_by, - ..}, ..}) => { + .. + } = &window_fun.as_ref().params; (args, partition_by, order_by).apply_ref_elements(f) } + Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } @@ -230,27 +232,26 @@ impl TreeNode for Expr { ))) })? } - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params : WindowFunctionParams { args, partition_by, order_by, window_frame, null_treatment, + }} = *window_fun; + (args, partition_by, order_by).map_elements(f)?.update_data( + |(new_args, new_partition_by, new_order_by)| { + Expr::from(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() }, - }) => (args, partition_by, order_by).map_elements(f)?.update_data( - |(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }, - ), + ) + }, Expr::AggregateFunction(AggregateFunction { func, params: diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index a52438fcc99c..04cad12a1231 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -133,7 +133,7 @@ impl WindowUDF { pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 552ce1502d46..8e02f27ed118 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -579,7 +579,8 @@ pub fn group_window_expr_by_sort_keys( ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { - Expr::WindowFunction( WindowFunction{ params: WindowFunctionParams { partition_by, order_by, ..}, .. }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params; let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 42078c735578..4ce844ad52ce 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -123,7 +123,7 @@ pub fn count_all() -> Expr { /// let expr = col(expr.schema_name().to_string()); /// ``` pub fn count_all_window() -> Expr { - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![Expr::Literal(COUNT_STAR_EXPANSION)], )) diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 1ddd8b27c420..a6d992390545 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -43,7 +43,7 @@ impl ExprPlanner for WindowFunctionPlanner { null_treatment, } = raw_expr; - let origin_expr = Expr::WindowFunction(WindowFunction { + let origin_expr = Expr::from(WindowFunction { fun: func_def, params: WindowFunctionParams { args, @@ -56,7 +56,11 @@ impl ExprPlanner for WindowFunctionPlanner { let saved_name = NamePreserver::new_for_projection().save(&origin_expr); - let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = origin_expr + else { + unreachable!("") + }; + let WindowFunction { fun, params: WindowFunctionParams { @@ -66,10 +70,7 @@ impl ExprPlanner for WindowFunctionPlanner { window_frame, null_treatment, }, - }) = origin_expr - else { - unreachable!("") - }; + } = *window_fun; let raw_expr = RawWindowExpr { func_def: fun, args, @@ -95,7 +96,7 @@ impl ExprPlanner for WindowFunctionPlanner { null_treatment, } = raw_expr; - let new_expr = Expr::WindowFunction(WindowFunction::new( + let new_expr = Expr::from(WindowFunction::new( func_def, vec![Expr::Literal(COUNT_STAR_EXPANSION)], )) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c17e6c766cc3..c7e7784cf863 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -539,17 +539,18 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ), ))) } - Expr::WindowFunction(WindowFunction { - fun, - params: - expr::WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + expr::WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = *window_fun; let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -565,7 +566,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }; Ok(Transformed::yes( - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 97ff7bf19904..071e95940035 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -361,7 +361,7 @@ impl SqlToRel<'_, S> { null_treatment, } = window_expr; - return Expr::WindowFunction(expr::WindowFunction::new(func_def, args)) + return Expr::from(expr::WindowFunction::new(func_def, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 41cb811d19d9..661e8581ac06 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -189,17 +189,18 @@ impl Unparser<'_> { } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - .. - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + .. + }, + } = window_fun.as_ref(); let func_name = fun.name(); let args = self.function_args_to_sql(args)?; @@ -2019,7 +2020,7 @@ mod tests { "count(*) FILTER (WHERE true)", ), ( - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), params: WindowFunctionParams { args: vec![col("col")], @@ -2033,7 +2034,7 @@ mod tests { ), ( #[expect(deprecated)] - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), params: WindowFunctionParams { args: vec![Expr::Wildcard { @@ -2902,7 +2903,7 @@ mod tests { let func = WindowFunctionDefinition::WindowUDF(rank_udwf()); let mut window_func = WindowFunction::new(func, vec![]); window_func.params.order_by = vec![Sort::new(col("a"), true, true)]; - let expr = Expr::WindowFunction(window_func); + let expr = Expr::from(window_func); let ast = unparser.expr_to_sql(&expr)?; let actual = ast.to_string(); diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 8496be1d7f9a..067da40cf9a8 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -241,15 +241,21 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let all_partition_keys = window_exprs .iter() .map(|expr| match expr { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { partition_by, .. }, - .. - }) => Ok(partition_by), - Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { - Expr::WindowFunction(WindowFunction { + Expr::WindowFunction(window_fun) => { + let WindowFunction { params: WindowFunctionParams { partition_by, .. }, .. - }) => Ok(partition_by), + } = window_fun.as_ref(); + Ok(partition_by) + } + Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + params: WindowFunctionParams { partition_by, .. }, + .. + } = window_fun.as_ref(); + Ok(partition_by) + } expr => exec_err!("Impossibly got non-window expr {expr:?}"), }, expr => exec_err!("Impossibly got non-window expr {expr:?}"), From 0eaf367911d4406311429bb2d228cf2a8c59a02a Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 27 May 2025 22:25:05 +0200 Subject: [PATCH 02/10] Second pass --- datafusion/core/src/physical_planner.rs | 57 +++++++++---------- datafusion/core/tests/dataframe/mod.rs | 2 +- .../user_defined/user_defined_aggregates.rs | 8 +-- datafusion/expr/src/expr.rs | 18 ++++++ .../simplify_expressions/expr_simplifier.rs | 13 ++--- .../proto/src/logical_plan/from_proto.rs | 4 +- datafusion/proto/src/logical_plan/to_proto.rs | 25 ++++---- .../consumer/expr/window_function.rs | 2 +- 8 files changed, 73 insertions(+), 56 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fbb4250fc4df..61d1fee79472 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -574,27 +574,25 @@ impl DefaultPhysicalPlanner { let input_exec = children.one()?; let get_sort_keys = |expr: &Expr| match expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } Expr::Alias(Alias { expr, .. }) => { // Convert &Box to &T match &**expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } _ => unreachable!(), } } @@ -1506,17 +1504,18 @@ pub fn create_window_expr_with_name( let name = name.into(); let physical_schema: &Schema = &logical_schema.into(); match e { - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = window_fun.as_ref(); let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index dfd11fcb096f..298c5bd63ac3 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -907,7 +907,7 @@ async fn window_using_aggregates() -> Result<()> { vec![col("c3")], ); - Expr::WindowFunction(w) + Expr::from(w) .null_treatment(NullTreatment::IgnoreNulls) .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 203fb6e85237..b07b5f5cc4b0 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -1104,22 +1104,22 @@ async fn test_metadata_based_aggregate_as_window() -> Result<()> { ))); let df = df.select(vec![ - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)), vec![col("no_metadata")], )) .alias("meta_no_in_no_out"), - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(no_output_meta_udf), vec![col("with_metadata")], )) .alias("meta_with_in_no_out"), - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)), vec![col("no_metadata")], )) .alias("meta_no_in_with_out"), - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(with_output_meta_udf), vec![col("with_metadata")], )) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 600a6a9ea70e..7a3ab78d7706 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -24,6 +24,7 @@ use std::mem; use std::sync::Arc; use crate::expr_fn::binary_expr; +use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; @@ -882,6 +883,16 @@ impl WindowFunctionDefinition { WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + match self { + WindowFunctionDefinition::AggregateUDF(_) => None, + WindowFunctionDefinition::WindowUDF(udwf) => udwf.simplify(), + } + } } impl Display for WindowFunctionDefinition { @@ -953,6 +964,13 @@ impl WindowFunction { }, } } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + self.fun.simplify() + } } /// EXISTS expression diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 04ca47130998..fc14496655d3 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1523,14 +1523,13 @@ impl TreeNodeRewriter for Simplifier<'_, S> { (_, expr) => Transformed::no(expr), }, - Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(ref udwf), - .. - }) => match (udwf.simplify(), expr) { - (Some(simplify_function), Expr::WindowFunction(wf)) => { - Transformed::yes(simplify_function(wf, info)?) + Expr::WindowFunction(ref window_fun) => { + match (window_fun.simplify(), expr) { + (Some(simplify_function), Expr::WindowFunction(wf)) => { + Transformed::yes(simplify_function(*wf, info)?) + } + (_, expr) => Transformed::no(expr), } - (_, expr) => Transformed::no(expr), }, // diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 9f0489d6b0ea..38546fa38064 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -302,7 +302,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -321,7 +321,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 841c31fa035f..18073516610c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -302,18 +302,19 @@ pub fn serialize_expr( expr_type: Some(ExprType::SimilarTo(pb)), } } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - params: - expr::WindowFunctionParams { - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let expr::WindowFunction { + ref fun, + params: + expr::WindowFunctionParams { + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }, + } = window_fun.as_ref(); let (window_function, fun_definition) = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { let mut buf = Vec::new(); diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs index 10a92a686b59..4a7fde256b6c 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -99,7 +99,7 @@ pub async fn from_window_function( from_substrait_func_args(consumer, &window.arguments, input_schema).await? }; - Ok(Expr::WindowFunction(expr::WindowFunction { + Ok(Expr::from(expr::WindowFunction { fun, params: WindowFunctionParams { args, From 467d75871603c3026fac333618fa83b1ed278a4c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 28 May 2025 07:17:26 +0200 Subject: [PATCH 03/10] Adjust tests --- datafusion/expr/src/utils.rs | 16 ++++++++-------- .../optimizer/src/optimize_projections/mod.rs | 4 ++-- datafusion/optimizer/src/push_down_filter.rs | 18 +++++++++--------- .../simplify_expressions/expr_simplifier.rs | 4 ++-- .../tests/cases/roundtrip_logical_plan.rs | 14 +++++++------- 5 files changed, 28 insertions(+), 28 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 8e02f27ed118..f0e4dc6e9c63 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1280,19 +1280,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1310,25 +1310,25 @@ mod tests { let age_asc = Sort::new(col("age"), true, true); let name_desc = Sort::new(col("name"), false, true); let created_at_desc = Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 926315eb8629..ba583a8d7123 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -2144,7 +2144,7 @@ mod tests { fn test_window() -> Result<()> { let table_scan = test_table_scan()?; - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) @@ -2152,7 +2152,7 @@ mod tests { .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index bbf0b0dd810e..7c352031bce6 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1584,7 +1584,7 @@ mod tests { fn filter_move_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1615,7 +1615,7 @@ mod tests { fn filter_move_complex_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1645,7 +1645,7 @@ mod tests { fn filter_move_partial_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1677,7 +1677,7 @@ mod tests { fn filter_expression_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1710,7 +1710,7 @@ mod tests { fn filter_order_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1742,7 +1742,7 @@ mod tests { fn filter_multiple_windows_common_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1753,7 +1753,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1784,7 +1784,7 @@ mod tests { fn filter_multiple_windows_disjoint_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1795,7 +1795,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fc14496655d3..23ab4f48d09b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -4389,7 +4389,7 @@ mod tests { WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + Expr::from(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -4398,7 +4398,7 @@ mod tests { WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + Expr::from(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 369700bded04..1dbc584403b7 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2359,7 +2359,7 @@ fn roundtrip_window() { let ctx = SessionContext::new(); // 1. without window_frame - let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2370,7 +2370,7 @@ fn roundtrip_window() { .unwrap(); // 2. with default window_frame - let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2387,7 +2387,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr3 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2404,7 +2404,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr4 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) @@ -2454,7 +2454,7 @@ fn roundtrip_window() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr5 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], )) @@ -2535,7 +2535,7 @@ fn roundtrip_window() { let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); - let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr6 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], )) @@ -2545,7 +2545,7 @@ fn roundtrip_window() { .build() .unwrap(); - let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + let text_expr7 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) From 341a08eb6e8605e44d1c2cc0e51bfd95bec35e62 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 28 May 2025 07:18:08 +0200 Subject: [PATCH 04/10] cargo fmt --- datafusion/expr/src/expr.rs | 199 ++++++++++-------- datafusion/expr/src/logical_plan/plan.rs | 3 +- datafusion/expr/src/tree_node.rs | 24 ++- datafusion/functions-window/src/planner.rs | 3 +- .../simplify_expressions/expr_simplifier.rs | 16 +- 5 files changed, 129 insertions(+), 116 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7a3ab78d7706..d02392366c3e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -883,7 +883,7 @@ impl WindowFunctionDefinition { WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } - + /// Return the the inner window simplification function, if any /// /// See [`WindowFunctionSimplification`] for more information @@ -2114,23 +2114,25 @@ impl NormalizeEq for Expr { (Expr::WindowFunction(left), Expr::WindowFunction(other)) => { let WindowFunction { fun: self_fun, - params: WindowFunctionParams { - args: self_args, - window_frame: self_window_frame, - partition_by: self_partition_by, - order_by: self_order_by, - null_treatment: self_null_treatment, - } + params: + WindowFunctionParams { + args: self_args, + window_frame: self_window_frame, + partition_by: self_partition_by, + order_by: self_order_by, + null_treatment: self_null_treatment, + }, } = left.as_ref(); let WindowFunction { fun: other_fun, - params: WindowFunctionParams { - args: other_args, - window_frame: other_window_frame, - partition_by: other_partition_by, - order_by: other_order_by, - null_treatment: other_null_treatment, - } + params: + WindowFunctionParams { + args: other_args, + window_frame: other_window_frame, + partition_by: other_partition_by, + order_by: other_order_by, + null_treatment: other_null_treatment, + }, } = other.as_ref(); self_fun.name() == other_fun.name() @@ -2379,13 +2381,14 @@ impl HashNode for Expr { Expr::WindowFunction(window_fun) => { let WindowFunction { fun, - params: WindowFunctionParams { - args: _args, - partition_by: _, - order_by: _, - window_frame, - null_treatment, - } + params: + WindowFunctionParams { + args: _args, + partition_by: _, + order_by: _, + window_frame, + null_treatment, + }, } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); @@ -2672,52 +2675,59 @@ impl Display for SchemaDisplay<'_> { Expr::WindowFunction(window_fun) => { let WindowFunction { fun, params } = window_fun.as_ref(); match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_schema_name(params) { - Ok(name) => { - write!(f, "{name}") - } - Err(e) => { - write!(f, "got error from window_function_schema_name {e}") + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_schema_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_schema_name {e}" + ) + } } } - } - _ => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - write!( - f, - "{}({})", - fun, - schema_name_from_exprs_comma_separated_without_space(args)? - )?; - - if let Some(null_treatment) = null_treatment { - write!(f, " {null_treatment}")?; - } + _ => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; - if !partition_by.is_empty() { write!( f, - " PARTITION BY [{}]", - schema_name_from_exprs(partition_by)? + "{}({})", + fun, + schema_name_from_exprs_comma_separated_without_space(args)? )?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; - }; + if let Some(null_treatment) = null_treatment { + write!(f, " {null_treatment}")?; + } + + if !partition_by.is_empty() { + write!( + f, + " PARTITION BY [{}]", + schema_name_from_exprs(partition_by)? + )?; + } + + if !order_by.is_empty() { + write!( + f, + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + )?; + }; - write!(f, " {window_frame}") + write!(f, " {window_frame}") + } } } - }, } } } @@ -3055,47 +3065,50 @@ impl Display for Expr { Expr::WindowFunction(window_fun) => { let WindowFunction { fun, params } = window_fun.as_ref(); match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_display_name(params) { - Ok(name) => { - write!(f, "{name}") - } - Err(e) => { - write!(f, "got error from window_function_display_name {e}") + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_display_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_display_name {e}" + ) + } } } - } - WindowFunctionDefinition::WindowUDF(fun) => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - fmt_function(f, &fun.to_string(), false, args, true)?; - - if let Some(nt) = null_treatment { - write!(f, "{nt}")?; - } + WindowFunctionDefinition::WindowUDF(fun) => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; + + fmt_function(f, &fun.to_string(), false, args, true)?; + + if let Some(nt) = null_treatment { + write!(f, "{nt}")?; + } - if !partition_by.is_empty() { - write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + if !partition_by.is_empty() { + write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; + } + if !order_by.is_empty() { + write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + } + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + ) } - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, - window_frame.start_bound, - window_frame.end_bound - ) } } - }, Expr::AggregateFunction(AggregateFunction { func, params }) => { match func.display_name(params) { Ok(name) => { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1e067eea3534..c3e88a4f4012 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2428,7 +2428,8 @@ impl Window { let WindowFunction { fun: WindowFunctionDefinition::WindowUDF(udwf), params: WindowFunctionParams { partition_by, .. }, - } = window_fun.as_ref() else { + } = window_fun.as_ref() + else { return None; }; // When there is no PARTITION BY, row number will be unique diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index c0247288aee9..bfdc19394576 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -101,7 +101,7 @@ impl TreeNode for Expr { } = &window_fun.as_ref().params; (args, partition_by, order_by).apply_ref_elements(f) } - + Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } @@ -233,14 +233,18 @@ impl TreeNode for Expr { })? } Expr::WindowFunction(window_fun) => { - let WindowFunction { fun, params : WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }} = *window_fun; - (args, partition_by, order_by).map_elements(f)?.update_data( + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = *window_fun; + (args, partition_by, order_by).map_elements(f)?.update_data( |(new_args, new_partition_by, new_order_by)| { Expr::from(WindowFunction::new(fun, new_args)) .partition_by(new_partition_by) @@ -251,7 +255,7 @@ impl TreeNode for Expr { .unwrap() }, ) - }, + } Expr::AggregateFunction(AggregateFunction { func, params: diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index a6d992390545..8fca0114f65e 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -56,8 +56,7 @@ impl ExprPlanner for WindowFunctionPlanner { let saved_name = NamePreserver::new_for_projection().save(&origin_expr); - let Expr::WindowFunction(window_fun) = origin_expr - else { + let Expr::WindowFunction(window_fun) = origin_expr else { unreachable!("") }; let WindowFunction { diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 23ab4f48d09b..263a44fdd66e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1523,13 +1523,11 @@ impl TreeNodeRewriter for Simplifier<'_, S> { (_, expr) => Transformed::no(expr), }, - Expr::WindowFunction(ref window_fun) => { - match (window_fun.simplify(), expr) { - (Some(simplify_function), Expr::WindowFunction(wf)) => { - Transformed::yes(simplify_function(*wf, info)?) - } - (_, expr) => Transformed::no(expr), + Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) { + (Some(simplify_function), Expr::WindowFunction(wf)) => { + Transformed::yes(simplify_function(*wf, info)?) } + (_, expr) => Transformed::no(expr), }, // @@ -4388,8 +4386,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::from(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -4397,8 +4394,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::from(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); From 95c87a477d700bfdf35b634e007bd747652f3c1d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 28 May 2025 07:18:34 +0200 Subject: [PATCH 05/10] fmt --- datafusion-examples/examples/advanced_udwf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 8330e783319d..0a9536c246ef 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -190,7 +190,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// default implementation will not be called (left as `todo!()`) fn simplify(&self) -> Option { let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { + Ok(Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), params: WindowFunctionParams { args: window_function.params.args, From d14128bc41335c143cddf3a1a226daa1055030fa Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 28 May 2025 19:40:25 +0200 Subject: [PATCH 06/10] Add test --- datafusion/expr/src/expr.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d02392366c3e..57da3178307e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3634,4 +3634,20 @@ mod test { rename: opt_rename, } } + + #[test] + fn test_size_of_expr() { + // because Expr is such a widely used struct in DataFusion + // it is important to keep its size as small as possible + // + // If this test fails when you change `Expr`, please try + // `Box`ing the fields to make `Expr` smaller + // See https://github.com/apache/datafusion/issues/14256 for details + // TODO: used to be 112 + assert_eq!(size_of::(), 144); + assert_eq!(size_of::(), 64); + assert_eq!(size_of::(), 24); // 3 ptrs + assert_eq!(size_of::>(), 24); + assert_eq!(size_of::>(), 8); + } } From 16dee4d470dda048ab4dc894b3f94e6fbaf5b7b5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 May 2025 17:13:50 +0200 Subject: [PATCH 07/10] clippy --- datafusion/expr/src/logical_plan/plan.rs | 2 +- datafusion/expr/src/utils.rs | 6 ++---- .../src/simplify_expressions/expr_simplifier.rs | 10 ++++------ 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c3e88a4f4012..691f5684a11c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2435,7 +2435,7 @@ impl Window { // When there is no PARTITION BY, row number will be unique // across the entire table. if udwf.name() == "row_number" && partition_by.is_empty() { - return Some(idx + input_len); + Some(idx + input_len) } else { None } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index f0e4dc6e9c63..bf6d65fcbd2b 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -21,7 +21,7 @@ use std::cmp::Ordering; use std::collections::{BTreeSet, HashSet}; use std::sync::Arc; -use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction, WindowFunctionParams}; +use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams}; use crate::expr_rewriter::strip_outer_reference; use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, @@ -1264,9 +1264,7 @@ pub fn collect_subquery_cols( mod tests { use super::*; use crate::{ - col, cube, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::max_udaf, test::function_stub::min_udaf, - test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, + col, cube, expr::WindowFunction, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::{max_udaf, min_udaf, sum_udaf}, Cast, ExprFunctionExt, WindowFunctionDefinition }; use arrow::datatypes::{UnionFields, UnionMode}; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 263a44fdd66e..0182b1e305f3 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -34,11 +34,11 @@ use datafusion_common::{ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, - Operator, Volatility, WindowFunctionDefinition, + Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ - expr::{InList, InSubquery, WindowFunction}, + expr::{InList, InSubquery}, utils::{iter_conjunction, iter_conjunction_owned}, }; use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast}; @@ -2145,12 +2145,10 @@ mod tests { use crate::test::test_table_scan_with_name; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ - function::{ + expr::WindowFunction, function::{ AccumulatorArgs, AggregateFunctionSimplification, WindowFunctionSimplification, - }, - interval_arithmetic::Interval, - *, + }, interval_arithmetic::Interval, * }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; From a416818ffcd2de562920627f1c8e405dbe2b76d3 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 May 2025 17:19:22 +0200 Subject: [PATCH 08/10] clippy --- datafusion/datasource/src/file_scan_config.rs | 2 +- .../optimizer/src/simplify_expressions/expr_simplifier.rs | 6 +++--- .../optimizer/src/simplify_expressions/unwrap_cast.rs | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index f93792b7facc..f268608cdaf3 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -2385,7 +2385,7 @@ mod tests { // Setup sort expression let exec_props = ExecutionProps::new(); let df_schema = DFSchema::try_from_qualified_schema("test", schema.as_ref())?; - let sort_expr = vec![col("value").sort(true, false)]; + let sort_expr = [col("value").sort(true, false)]; let physical_sort_exprs: Vec<_> = sort_expr .iter() diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 3343c99629ef..40ae25de52ef 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1825,7 +1825,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { info, &left, op, &right, ) && op.supports_propagation() => { - unwrap_cast_in_comparison_for_binary(info, left, right, op)? + unwrap_cast_in_comparison_for_binary(info, *left, *right, op)? } // literal op try_cast/cast(expr as data_type) // --> @@ -1838,8 +1838,8 @@ impl TreeNodeRewriter for Simplifier<'_, S> { { unwrap_cast_in_comparison_for_binary( info, - right, - left, + *right, + *left, op.swap().unwrap(), )? } diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 37116018cdca..b70b19bae6df 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -69,11 +69,11 @@ use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast}; pub(super) fn unwrap_cast_in_comparison_for_binary( info: &S, - cast_expr: Box, - literal: Box, + cast_expr: Expr, + literal: Expr, op: Operator, ) -> Result> { - match (*cast_expr, *literal) { + match (cast_expr, literal) { ( Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }), Expr::Literal(lit_value), From fcd1cf51fa72fde06d27612f941b908512dd2781 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 May 2025 17:20:19 +0200 Subject: [PATCH 09/10] fmt --- datafusion/expr/src/utils.rs | 6 +++++- .../optimizer/src/simplify_expressions/expr_simplifier.rs | 7 +++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index bf6d65fcbd2b..6f44e37d0523 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1264,7 +1264,11 @@ pub fn collect_subquery_cols( mod tests { use super::*; use crate::{ - col, cube, expr::WindowFunction, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::{max_udaf, min_udaf, sum_udaf}, Cast, ExprFunctionExt, WindowFunctionDefinition + col, cube, + expr::WindowFunction, + expr_vec_fmt, grouping_set, lit, rollup, + test::function_stub::{max_udaf, min_udaf, sum_udaf}, + Cast, ExprFunctionExt, WindowFunctionDefinition, }; use arrow::datatypes::{UnionFields, UnionMode}; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 40ae25de52ef..fa565a973f6b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -2147,10 +2147,13 @@ mod tests { use arrow::datatypes::FieldRef; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ - expr::WindowFunction, function::{ + expr::WindowFunction, + function::{ AccumulatorArgs, AggregateFunctionSimplification, WindowFunctionSimplification, - }, interval_arithmetic::Interval, * + }, + interval_arithmetic::Interval, + *, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; From 3ad082c3306884d13c795e7b37c70bf79963d355 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 May 2025 17:56:07 +0200 Subject: [PATCH 10/10] Adjust comments --- datafusion/expr/src/expr.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9f248de58f3b..dcd5380b4859 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3642,8 +3642,7 @@ mod test { // // If this test fails when you change `Expr`, please try // `Box`ing the fields to make `Expr` smaller - // See https://github.com/apache/datafusion/issues/14256 for details - // TODO: used to be 112 + // See https://github.com/apache/datafusion/issues/16199 for details assert_eq!(size_of::(), 144); assert_eq!(size_of::(), 64); assert_eq!(size_of::(), 24); // 3 ptrs