diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 4f00e04e7e99..f7316ddc1bec 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -191,7 +191,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// default implementation will not be called (left as `todo!()`) fn simplify(&self) -> Option { let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { + Ok(Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), params: WindowFunctionParams { args: window_function.params.args, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fbb4250fc4df..61d1fee79472 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -574,27 +574,25 @@ impl DefaultPhysicalPlanner { let input_exec = children.one()?; let get_sort_keys = |expr: &Expr| match expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } Expr::Alias(Alias { expr, .. }) => { // Convert &Box to &T match &**expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } _ => unreachable!(), } } @@ -1506,17 +1504,18 @@ pub fn create_window_expr_with_name( let name = name.into(); let physical_schema: &Schema = &logical_schema.into(); match e { - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = window_fun.as_ref(); let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index dfd11fcb096f..298c5bd63ac3 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -907,7 +907,7 @@ async fn window_using_aggregates() -> Result<()> { vec![col("c3")], ); - Expr::WindowFunction(w) + Expr::from(w) .null_treatment(NullTreatment::IgnoreNulls) .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index ae517795ab95..aa5a72c0fb45 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -1106,22 +1106,22 @@ async fn test_metadata_based_aggregate_as_window() -> Result<()> { ))); let df = df.select(vec![ - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)), vec![col("no_metadata")], )) .alias("meta_no_in_no_out"), - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(no_output_meta_udf), vec![col("with_metadata")], )) .alias("meta_with_in_no_out"), - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)), vec![col("no_metadata")], )) .alias("meta_no_in_with_out"), - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(with_output_meta_udf), vec![col("with_metadata")], )) diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index f93792b7facc..f268608cdaf3 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -2385,7 +2385,7 @@ mod tests { // Setup sort expression let exec_props = ExecutionProps::new(); let df_schema = DFSchema::try_from_qualified_schema("test", schema.as_ref())?; - let sort_expr = vec![col("value").sort(true, false)]; + let sort_expr = [col("value").sort(true, false)]; let physical_sort_exprs: Vec<_> = sort_expr .iter() diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index fe5ea2ecd5b8..dcd5380b4859 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -24,6 +24,7 @@ use std::mem; use std::sync::Arc; use crate::expr_fn::binary_expr; +use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; @@ -330,7 +331,7 @@ pub enum Expr { /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), /// Call a window function with a set of arguments. - WindowFunction(WindowFunction), + WindowFunction(Box), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -378,6 +379,13 @@ impl From for Expr { } } +/// Create an [`Expr`] from a [`WindowFunction`] +impl From for Expr { + fn from(value: WindowFunction) -> Self { + Expr::WindowFunction(Box::new(value)) + } +} + /// Create an [`Expr`] from an optional qualifier and a [`FieldRef`]. This is /// useful for creating [`Expr`] from a [`DFSchema`]. /// @@ -875,6 +883,16 @@ impl WindowFunctionDefinition { WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + match self { + WindowFunctionDefinition::AggregateUDF(_) => None, + WindowFunctionDefinition::WindowUDF(udwf) => udwf.simplify(), + } + } } impl Display for WindowFunctionDefinition { @@ -946,6 +964,13 @@ impl WindowFunction { }, } } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + self.fun.simplify() + } } /// EXISTS expression @@ -2086,32 +2111,29 @@ impl NormalizeEq for Expr { _ => false, } } - ( - Expr::WindowFunction(WindowFunction { + (Expr::WindowFunction(left), Expr::WindowFunction(other)) => { + let WindowFunction { fun: self_fun, - params: self_params, - }), - Expr::WindowFunction(WindowFunction { + params: + WindowFunctionParams { + args: self_args, + window_frame: self_window_frame, + partition_by: self_partition_by, + order_by: self_order_by, + null_treatment: self_null_treatment, + }, + } = left.as_ref(); + let WindowFunction { fun: other_fun, - params: other_params, - }), - ) => { - let ( - WindowFunctionParams { - args: self_args, - window_frame: self_window_frame, - partition_by: self_partition_by, - order_by: self_order_by, - null_treatment: self_null_treatment, - }, - WindowFunctionParams { - args: other_args, - window_frame: other_window_frame, - partition_by: other_partition_by, - order_by: other_order_by, - null_treatment: other_null_treatment, - }, - ) = (self_params, other_params); + params: + WindowFunctionParams { + args: other_args, + window_frame: other_window_frame, + partition_by: other_partition_by, + order_by: other_order_by, + null_treatment: other_null_treatment, + }, + } = other.as_ref(); self_fun.name() == other_fun.name() && self_window_frame == other_window_frame @@ -2356,14 +2378,18 @@ impl HashNode for Expr { distinct.hash(state); null_treatment.hash(state); } - Expr::WindowFunction(WindowFunction { fun, params }) => { - let WindowFunctionParams { - args: _args, - partition_by: _, - order_by: _, - window_frame, - null_treatment, - } = params; + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args: _args, + partition_by: _, + order_by: _, + window_frame, + null_treatment, + }, + } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); null_treatment.hash(state); @@ -2646,52 +2672,62 @@ impl Display for SchemaDisplay<'_> { Ok(()) } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_schema_name(params) { - Ok(name) => { - write!(f, "{name}") + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_schema_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_schema_name {e}" + ) + } } - Err(e) => { - write!(f, "got error from window_function_schema_name {e}") - } - } - } - _ => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - write!( - f, - "{}({})", - fun, - schema_name_from_exprs_comma_separated_without_space(args)? - )?; - - if let Some(null_treatment) = null_treatment { - write!(f, " {null_treatment}")?; } + _ => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; - if !partition_by.is_empty() { write!( f, - " PARTITION BY [{}]", - schema_name_from_exprs(partition_by)? + "{}({})", + fun, + schema_name_from_exprs_comma_separated_without_space(args)? )?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; - }; + if let Some(null_treatment) = null_treatment { + write!(f, " {null_treatment}")?; + } + + if !partition_by.is_empty() { + write!( + f, + " PARTITION BY [{}]", + schema_name_from_exprs(partition_by)? + )?; + } - write!(f, " {window_frame}") + if !order_by.is_empty() { + write!( + f, + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + )?; + }; + + write!(f, " {window_frame}") + } } - }, + } } } } @@ -3026,47 +3062,53 @@ impl Display for Expr { // Expr::ScalarFunction(ScalarFunction { func, args }) => { // write!(f, "{}", func.display_name(args).unwrap()) // } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_display_name(params) { - Ok(name) => { - write!(f, "{name}") - } - Err(e) => { - write!(f, "got error from window_function_display_name {e}") + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_display_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_display_name {e}" + ) + } } } - } - WindowFunctionDefinition::WindowUDF(fun) => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - fmt_function(f, &fun.to_string(), false, args, true)?; - - if let Some(nt) = null_treatment { - write!(f, "{nt}")?; - } + WindowFunctionDefinition::WindowUDF(fun) => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; + + fmt_function(f, &fun.to_string(), false, args, true)?; + + if let Some(nt) = null_treatment { + write!(f, "{nt}")?; + } - if !partition_by.is_empty() { - write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + if !partition_by.is_empty() { + write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; + } + if !order_by.is_empty() { + write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + } + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + ) } - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, - window_frame.start_bound, - window_frame.end_bound - ) } - }, + } Expr::AggregateFunction(AggregateFunction { func, params }) => { match func.display_name(params) { Ok(name) => { @@ -3592,4 +3634,19 @@ mod test { rename: opt_rename, } } + + #[test] + fn test_size_of_expr() { + // because Expr is such a widely used struct in DataFusion + // it is important to keep its size as small as possible + // + // If this test fails when you change `Expr`, please try + // `Box`ing the fields to make `Expr` smaller + // See https://github.com/apache/datafusion/issues/16199 for details + assert_eq!(size_of::(), 144); + assert_eq!(size_of::(), 64); + assert_eq!(size_of::(), 24); // 3 ptrs + assert_eq!(size_of::>(), 24); + assert_eq!(size_of::>(), 8); + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 67e80a8d9bba..5182ccb15c0a 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -832,7 +832,7 @@ impl ExprFuncBuilder { params: WindowFunctionParams { args, .. }, }) => { let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun, params: WindowFunctionParams { args, @@ -896,7 +896,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -936,7 +936,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -949,7 +949,7 @@ impl ExprFunctionExt for Expr { fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.partition_by = Some(partition_by); builder } @@ -960,7 +960,7 @@ impl ExprFunctionExt for Expr { fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.window_frame = Some(window_frame); builder } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 2fe1c0d7398f..691f5684a11c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2422,18 +2422,23 @@ impl Window { .iter() .enumerate() .filter_map(|(idx, expr)| { - if let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = expr else { + return None; + }; + let WindowFunction { fun: WindowFunctionDefinition::WindowUDF(udwf), params: WindowFunctionParams { partition_by, .. }, - }) = expr - { - // When there is no PARTITION BY, row number will be unique - // across the entire table. - if udwf.name() == "row_number" && partition_by.is_empty() { - return Some(idx + input_len); - } + } = window_fun.as_ref() + else { + return None; + }; + // When there is no PARTITION BY, row number will be unique + // across the entire table. + if udwf.name() == "row_number" && partition_by.is_empty() { + Some(idx + input_len) + } else { + None } - None }) .map(|idx| { FunctionalDependence::new(vec![idx], vec![], false) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f20dab7e165f..bfdc19394576 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -92,14 +92,16 @@ impl TreeNode for Expr { (expr, when_then_expr, else_expr).apply_ref_elements(f), Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) => (args, filter, order_by).apply_ref_elements(f), - Expr::WindowFunction(WindowFunction { - params : WindowFunctionParams { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { args, partition_by, order_by, - ..}, ..}) => { + .. + } = &window_fun.as_ref().params; (args, partition_by, order_by).apply_ref_elements(f) } + Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } @@ -230,27 +232,30 @@ impl TreeNode for Expr { ))) })? } - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = *window_fun; + (args, partition_by, order_by).map_elements(f)?.update_data( + |(new_args, new_partition_by, new_order_by)| { + Expr::from(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() }, - }) => (args, partition_by, order_by).map_elements(f)?.update_data( - |(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }, - ), + ) + } Expr::AggregateFunction(AggregateFunction { func, params: diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c0187735d602..155de232285e 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -133,7 +133,7 @@ impl WindowUDF { pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 552ce1502d46..6f44e37d0523 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -21,7 +21,7 @@ use std::cmp::Ordering; use std::collections::{BTreeSet, HashSet}; use std::sync::Arc; -use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction, WindowFunctionParams}; +use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams}; use crate::expr_rewriter::strip_outer_reference; use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, @@ -579,7 +579,8 @@ pub fn group_window_expr_by_sort_keys( ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { - Expr::WindowFunction( WindowFunction{ params: WindowFunctionParams { partition_by, order_by, ..}, .. }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params; let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), @@ -1263,9 +1264,11 @@ pub fn collect_subquery_cols( mod tests { use super::*; use crate::{ - col, cube, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::max_udaf, test::function_stub::min_udaf, - test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, + col, cube, + expr::WindowFunction, + expr_vec_fmt, grouping_set, lit, rollup, + test::function_stub::{max_udaf, min_udaf, sum_udaf}, + Cast, ExprFunctionExt, WindowFunctionDefinition, }; use arrow::datatypes::{UnionFields, UnionMode}; @@ -1279,19 +1282,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1309,25 +1312,25 @@ mod tests { let age_asc = Sort::new(col("age"), true, true); let name_desc = Sort::new(col("name"), false, true); let created_at_desc = Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index eccd0cd05187..df31465e4a37 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -124,7 +124,7 @@ pub fn count_all() -> Expr { /// let expr = col(expr.schema_name().to_string()); /// ``` pub fn count_all_window() -> Expr { - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![Expr::Literal(COUNT_STAR_EXPANSION)], )) diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 1ddd8b27c420..8fca0114f65e 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -43,7 +43,7 @@ impl ExprPlanner for WindowFunctionPlanner { null_treatment, } = raw_expr; - let origin_expr = Expr::WindowFunction(WindowFunction { + let origin_expr = Expr::from(WindowFunction { fun: func_def, params: WindowFunctionParams { args, @@ -56,7 +56,10 @@ impl ExprPlanner for WindowFunctionPlanner { let saved_name = NamePreserver::new_for_projection().save(&origin_expr); - let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = origin_expr else { + unreachable!("") + }; + let WindowFunction { fun, params: WindowFunctionParams { @@ -66,10 +69,7 @@ impl ExprPlanner for WindowFunctionPlanner { window_frame, null_treatment, }, - }) = origin_expr - else { - unreachable!("") - }; + } = *window_fun; let raw_expr = RawWindowExpr { func_def: fun, args, @@ -95,7 +95,7 @@ impl ExprPlanner for WindowFunctionPlanner { null_treatment, } = raw_expr; - let new_expr = Expr::WindowFunction(WindowFunction::new( + let new_expr = Expr::from(WindowFunction::new( func_def, vec![Expr::Literal(COUNT_STAR_EXPANSION)], )) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c8246ecebd54..7034982956ae 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -539,17 +539,18 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ), ))) } - Expr::WindowFunction(WindowFunction { - fun, - params: - expr::WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + expr::WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = *window_fun; let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -565,7 +566,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }; Ok(Transformed::yes( - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 926315eb8629..ba583a8d7123 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -2144,7 +2144,7 @@ mod tests { fn test_window() -> Result<()> { let table_scan = test_table_scan()?; - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) @@ -2152,7 +2152,7 @@ mod tests { .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index bbf0b0dd810e..7c352031bce6 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1584,7 +1584,7 @@ mod tests { fn filter_move_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1615,7 +1615,7 @@ mod tests { fn filter_move_complex_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1645,7 +1645,7 @@ mod tests { fn filter_move_partial_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1677,7 +1677,7 @@ mod tests { fn filter_expression_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1710,7 +1710,7 @@ mod tests { fn filter_order_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1742,7 +1742,7 @@ mod tests { fn filter_multiple_windows_common_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1753,7 +1753,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1784,7 +1784,7 @@ mod tests { fn filter_multiple_windows_disjoint_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1795,7 +1795,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 4e4e3d316c26..fa565a973f6b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -34,11 +34,11 @@ use datafusion_common::{ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, - Operator, Volatility, WindowFunctionDefinition, + Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ - expr::{InList, InSubquery, WindowFunction}, + expr::{InList, InSubquery}, utils::{iter_conjunction, iter_conjunction_owned}, }; use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast}; @@ -1523,12 +1523,9 @@ impl TreeNodeRewriter for Simplifier<'_, S> { (_, expr) => Transformed::no(expr), }, - Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(ref udwf), - .. - }) => match (udwf.simplify(), expr) { + Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) { (Some(simplify_function), Expr::WindowFunction(wf)) => { - Transformed::yes(simplify_function(wf, info)?) + Transformed::yes(simplify_function(*wf, info)?) } (_, expr) => Transformed::no(expr), }, @@ -1828,7 +1825,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { info, &left, op, &right, ) && op.supports_propagation() => { - unwrap_cast_in_comparison_for_binary(info, left, right, op)? + unwrap_cast_in_comparison_for_binary(info, *left, *right, op)? } // literal op try_cast/cast(expr as data_type) // --> @@ -1841,8 +1838,8 @@ impl TreeNodeRewriter for Simplifier<'_, S> { { unwrap_cast_in_comparison_for_binary( info, - right, - left, + *right, + *left, op.swap().unwrap(), )? } @@ -2150,6 +2147,7 @@ mod tests { use arrow::datatypes::FieldRef; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ + expr::WindowFunction, function::{ AccumulatorArgs, AggregateFunctionSimplification, WindowFunctionSimplification, @@ -4389,8 +4387,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -4398,8 +4395,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 37116018cdca..b70b19bae6df 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -69,11 +69,11 @@ use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast}; pub(super) fn unwrap_cast_in_comparison_for_binary( info: &S, - cast_expr: Box, - literal: Box, + cast_expr: Expr, + literal: Expr, op: Operator, ) -> Result> { - match (*cast_expr, *literal) { + match (cast_expr, literal) { ( Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }), Expr::Literal(lit_value), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 9f0489d6b0ea..38546fa38064 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -302,7 +302,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -321,7 +321,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 841c31fa035f..18073516610c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -302,18 +302,19 @@ pub fn serialize_expr( expr_type: Some(ExprType::SimilarTo(pb)), } } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - params: - expr::WindowFunctionParams { - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let expr::WindowFunction { + ref fun, + params: + expr::WindowFunctionParams { + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }, + } = window_fun.as_ref(); let (window_function, fun_definition) = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { let mut buf = Vec::new(); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b515ef6e38de..3edf152f4c71 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2359,7 +2359,7 @@ fn roundtrip_window() { let ctx = SessionContext::new(); // 1. without window_frame - let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2370,7 +2370,7 @@ fn roundtrip_window() { .unwrap(); // 2. with default window_frame - let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2387,7 +2387,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr3 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2404,7 +2404,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr4 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) @@ -2454,7 +2454,7 @@ fn roundtrip_window() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr5 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], )) @@ -2539,7 +2539,7 @@ fn roundtrip_window() { let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); - let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr6 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], )) @@ -2549,7 +2549,7 @@ fn roundtrip_window() { .build() .unwrap(); - let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + let text_expr7 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 97ff7bf19904..071e95940035 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -361,7 +361,7 @@ impl SqlToRel<'_, S> { null_treatment, } = window_expr; - return Expr::WindowFunction(expr::WindowFunction::new(func_def, args)) + return Expr::from(expr::WindowFunction::new(func_def, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 41cb811d19d9..661e8581ac06 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -189,17 +189,18 @@ impl Unparser<'_> { } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - .. - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + .. + }, + } = window_fun.as_ref(); let func_name = fun.name(); let args = self.function_args_to_sql(args)?; @@ -2019,7 +2020,7 @@ mod tests { "count(*) FILTER (WHERE true)", ), ( - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), params: WindowFunctionParams { args: vec![col("col")], @@ -2033,7 +2034,7 @@ mod tests { ), ( #[expect(deprecated)] - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), params: WindowFunctionParams { args: vec![Expr::Wildcard { @@ -2902,7 +2903,7 @@ mod tests { let func = WindowFunctionDefinition::WindowUDF(rank_udwf()); let mut window_func = WindowFunction::new(func, vec![]); window_func.params.order_by = vec![Sort::new(col("a"), true, true)]; - let expr = Expr::WindowFunction(window_func); + let expr = Expr::from(window_func); let ast = unparser.expr_to_sql(&expr)?; let actual = ast.to_string(); diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 8496be1d7f9a..067da40cf9a8 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -241,15 +241,21 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let all_partition_keys = window_exprs .iter() .map(|expr| match expr { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { partition_by, .. }, - .. - }) => Ok(partition_by), - Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { - Expr::WindowFunction(WindowFunction { + Expr::WindowFunction(window_fun) => { + let WindowFunction { params: WindowFunctionParams { partition_by, .. }, .. - }) => Ok(partition_by), + } = window_fun.as_ref(); + Ok(partition_by) + } + Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + params: WindowFunctionParams { partition_by, .. }, + .. + } = window_fun.as_ref(); + Ok(partition_by) + } expr => exec_err!("Impossibly got non-window expr {expr:?}"), }, expr => exec_err!("Impossibly got non-window expr {expr:?}"), diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs index 10a92a686b59..4a7fde256b6c 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -99,7 +99,7 @@ pub async fn from_window_function( from_substrait_func_args(consumer, &window.arguments, input_schema).await? }; - Ok(Expr::WindowFunction(expr::WindowFunction { + Ok(Expr::from(expr::WindowFunction { fun, params: WindowFunctionParams { args,