From d8d30c5209a6d4d569f0d13c43830670b45ff399 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 19 May 2025 12:13:50 +0800 Subject: [PATCH 01/18] grouping udf --- datafusion/core/tests/sql/sql_api.rs | 34 +++ datafusion/functions/src/core/grouping.rs | 270 ++++++++++++++++++++++ datafusion/functions/src/core/mod.rs | 2 + 3 files changed, 306 insertions(+) create mode 100644 datafusion/functions/src/core/grouping.rs diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index ec086bcc50c7..999faae850fb 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -17,6 +17,7 @@ use datafusion::prelude::*; +use datafusion_sql::unparser::plan_to_sql; use tempfile::TempDir; #[tokio::test] @@ -206,3 +207,36 @@ async fn ddl_can_not_be_planned_by_session_state() { "This feature is not implemented: Unsupported logical plan: DropTable" ); } + +#[tokio::test] +async fn unparse_group_by_with_filter() { + let ctx = SessionContext::new(); + let df = ctx.sql("CREATE TABLE test (c1 VARCHAR,c2 VARCHAR,c3 INT) as values ('a','A',1), ('b','B',2);").await.unwrap(); + let _ = df.collect().await.unwrap(); + + ctx.sql("set datafusion.sql_parser.dialect = 'Postgres';").await.unwrap(); + + let df = ctx.sql(r#"select + c1, + c2, + CASE WHEN grouping(c1) = 1 THEN sum(c3) filter (where c2 = 'A') ELSE NULL END as gx, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + grouping sets ( + (c1, c2), + (c1), + (c2), + () + ) +order by + c1, c2, g0, g1, g2, g3;"#).await.unwrap(); + let plan = df.into_optimized_plan().unwrap(); + let sql = plan_to_sql(&plan).unwrap().to_string(); + println!("!!!{}", sql); + assert_eq!(sql, "SELECT c1,c2,c3 FROM test GROUP BY c1,c2,c3 HAVING c3 > 1;"); +} diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs new file mode 100644 index 000000000000..64a8ba07994d --- /dev/null +++ b/datafusion/functions/src/core/grouping.rs @@ -0,0 +1,270 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, Int32Array, UInt8Array, UInt16Array, UInt32Array, UInt64Array}; +use arrow::datatypes::DataType; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_macros::user_doc; +use std::any::Any; +use std::sync::Arc; + +#[user_doc( + doc_section(label = "Aggregate Functions"), + description = "Returns 1 if the specified column is not included in the grouping set, 0 if it is included.", + syntax_example = "grouping(grouping_id, indices)", + sql_example = r#"```sql +> SELECT grouping(grouping_id, 0) FROM table GROUP BY GROUPING SETS ((a), (b)); ++----------------+ +| grouping | ++----------------+ +| 1 | +| 0 | ++----------------+ +```"#, + argument( + name = "grouping_id", + description = "The internal grouping ID column (UInt8/16/32/64)" + ), + argument( + name = "indices", + description = "The indices of the column in the grouping set (Int32)" + ) +)] +#[derive(Debug)] +pub struct GroupingFunc { + signature: Signature, +} + +impl Default for GroupingFunc { + fn default() -> Self { + GroupingFunc::new() + } +} + +impl GroupingFunc { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for GroupingFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "grouping" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn return_type_from_args(&self, _args: ReturnTypeArgs) -> Result { + Ok(ReturnInfo::new(DataType::Int32, false)) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = args.args; + if args.len() != 2 { + return exec_err!( + "grouping function requires exactly 2 arguments, got {}", + args.len() + ); + } + + let grouping_id = match &args[0] { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(_) => { + return exec_err!("grouping function requires array input for grouping_id") + } + }; + + let indices = match &args[1] { + ColumnarValue::Scalar(scalar) => { + match scalar { + ScalarValue::List(array) => { + // Get the values array from the list array + let Some(values) = array.values().as_any().downcast_ref::() else { + return exec_err!("grouping function requires Int32 indices array") + }; + values + } + ScalarValue::FixedSizeList(array) => { + // Get the values array from the list array + let Some(values) = array.values().as_any().downcast_ref::() else { + return exec_err!("grouping function requires Int32 indices array") + }; + values + } + _ => { + return exec_err!("grouping function requires list of Int32 indices") + } + } + } + ColumnarValue::Array(_) => { + return exec_err!("grouping function requires scalar input for indices") + } + }; + + if indices.null_count() > 0 { + return exec_err!("grouping function requires non-null indices array"); + } + + let result: Int32Array = match grouping_id.data_type() { + DataType::UInt8 => { + let grouping_id = match grouping_id.as_any().downcast_ref::() { + Some(array) => array, + None => return exec_err!("grouping function requires UInt8 grouping_id array"), + }; + grouping_id + .iter() + .map(|grouping_id| { + grouping_id.map(|grouping_id| { + let mut result = 0u8; + for (i, index) in indices.iter().enumerate() { + if let Some(index) = index { + let bit = (grouping_id >> index) & 1; + result |= bit << i; + } + } + result as i32 + }) + }) + .collect() + } + DataType::UInt16 => { + let grouping_id = match grouping_id.as_any().downcast_ref::() { + Some(array) => array, + None => return exec_err!("grouping function requires UInt16 grouping_id array"), + }; + grouping_id + .iter() + .map(|grouping_id| { + grouping_id.map(|grouping_id| { + let mut result = 0u16; + for (i, index) in indices.iter().enumerate() { + if let Some(index) = index { + let bit = (grouping_id >> index) & 1; + result |= bit << i; + } + } + result as i32 + }) + }) + .collect() + } + DataType::UInt32 => { + let grouping_id = match grouping_id.as_any().downcast_ref::() { + Some(array) => array, + None => return exec_err!("grouping function requires UInt32 grouping_id array"), + }; + grouping_id + .iter() + .map(|grouping_id| { + grouping_id.map(|grouping_id| { + let mut result = 0u32; + for (i, index) in indices.iter().enumerate() { + if let Some(index) = index { + let bit = (grouping_id >> index) & 1; + result |= bit << i; + } + } + result as i32 + }) + }) + .collect() + } + DataType::UInt64 => { + let grouping_id = match grouping_id.as_any().downcast_ref::() { + Some(array) => array, + None => return exec_err!("grouping function requires UInt64 grouping_id array"), + }; + grouping_id + .iter() + .map(|grouping_id| { + grouping_id.map(|grouping_id| { + let mut result = 0u64; + for (i, index) in indices.iter().enumerate() { + if let Some(index) = index { + let bit = (grouping_id >> index) & 1; + result |= bit << i; + } + } + result as i32 + }) + }) + .collect() + } + _ => { + return exec_err!( + "grouping function requires UInt8/16/32/64 for grouping_id, got {}", + grouping_id.data_type() + ) + } + }; + + Ok(ColumnarValue::Array(Arc::new(result))) + } + + fn short_circuits(&self) -> bool { + false + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return exec_err!( + "grouping function requires exactly 2 arguments, got {}", + arg_types.len() + ); + } + + match arg_types[0] { + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {} + _ => { + return exec_err!( + "grouping function requires UInt8/16/32/64 for first argument, got {}", + arg_types[0] + ) + } + } + + if arg_types[1] != DataType::Int32 { + return exec_err!( + "grouping function requires Int32 for second argument, got {}", + arg_types[1] + ); + } + + Ok(arg_types.to_vec()) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index c6329b1ee0af..b09dbc338d1a 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -37,6 +37,7 @@ pub mod planner; pub mod r#struct; pub mod union_extract; pub mod version; +pub mod grouping; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); @@ -53,6 +54,7 @@ make_udf_function!(greatest::GreatestFunc, greatest); make_udf_function!(least::LeastFunc, least); make_udf_function!(union_extract::UnionExtractFun, union_extract); make_udf_function!(version::VersionFunc, version); +make_udf_function!(grouping::GroupingFunc, grouping); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; From 6c540e76466beff1855ee653dc09331b851e6d77 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 19 May 2025 16:52:28 +0800 Subject: [PATCH 02/18] update --- Cargo.lock | 2 + datafusion/functions/src/core/grouping.rs | 128 ++++++------------ datafusion/optimizer/Cargo.toml | 2 + .../src/analyzer/resolve_grouping_function.rs | 59 ++------ 4 files changed, 55 insertions(+), 136 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 299ea0dc4c6f..cae64ed62dc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2393,7 +2393,9 @@ dependencies = [ "ctor", "datafusion-common", "datafusion-expr", + "datafusion-functions", "datafusion-functions-aggregate", + "datafusion-functions-nested", "datafusion-functions-window", "datafusion-functions-window-common", "datafusion-physical-expr", diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index 64a8ba07994d..4ebd940d3e16 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -26,12 +26,36 @@ use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; +macro_rules! grouping_id { + ($grouping_id:expr, $indices:expr, $type:ty, $array_type:ty) => {{ + let grouping_id = match $grouping_id.as_any().downcast_ref::<$array_type>() { + Some(array) => array, + None => return exec_err!("grouping function requires {} grouping_id array", stringify!($type)), + }; + grouping_id + .iter() + .map(|grouping_id| { + grouping_id.map(|grouping_id| { + let mut result = 0 as $type; + for (i, index) in $indices.iter().enumerate() { + if let Some(index) = index { + let bit = (grouping_id >> index) & 1; + result |= bit << i; + } + } + result as i32 + }) + }) + .collect() + }}; +} + #[user_doc( - doc_section(label = "Aggregate Functions"), - description = "Returns 1 if the specified column is not included in the grouping set, 0 if it is included.", - syntax_example = "grouping(grouping_id, indices)", + doc_section(label = "Other Functions"), + description = "[Developer API] Returns 1 if the specified column is not included in the grouping set, 0 if it is included.", + syntax_example = "grouping(grouping_id[, indices])", sql_example = r#"```sql -> SELECT grouping(grouping_id, 0) FROM table GROUP BY GROUPING SETS ((a), (b)); +> SELECT grouping(__grouping_id, make_array(0)) FROM table GROUP BY GROUPING SETS ((a), (b)); +----------------+ | grouping | +----------------+ @@ -90,9 +114,9 @@ impl ScalarUDFImpl for GroupingFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = args.args; - if args.len() != 2 { + if args.len() != 2 && args.len() != 1 { return exec_err!( - "grouping function requires exactly 2 arguments, got {}", + "grouping function requires 1 or 2 arguments, got {}", args.len() ); } @@ -104,6 +128,10 @@ impl ScalarUDFImpl for GroupingFunc { } }; + if args.len() == 1 { + return args[0].cast_to(grouping_id.data_type(), None); + } + let indices = match &args[1] { ColumnarValue::Scalar(scalar) => { match scalar { @@ -136,90 +164,10 @@ impl ScalarUDFImpl for GroupingFunc { } let result: Int32Array = match grouping_id.data_type() { - DataType::UInt8 => { - let grouping_id = match grouping_id.as_any().downcast_ref::() { - Some(array) => array, - None => return exec_err!("grouping function requires UInt8 grouping_id array"), - }; - grouping_id - .iter() - .map(|grouping_id| { - grouping_id.map(|grouping_id| { - let mut result = 0u8; - for (i, index) in indices.iter().enumerate() { - if let Some(index) = index { - let bit = (grouping_id >> index) & 1; - result |= bit << i; - } - } - result as i32 - }) - }) - .collect() - } - DataType::UInt16 => { - let grouping_id = match grouping_id.as_any().downcast_ref::() { - Some(array) => array, - None => return exec_err!("grouping function requires UInt16 grouping_id array"), - }; - grouping_id - .iter() - .map(|grouping_id| { - grouping_id.map(|grouping_id| { - let mut result = 0u16; - for (i, index) in indices.iter().enumerate() { - if let Some(index) = index { - let bit = (grouping_id >> index) & 1; - result |= bit << i; - } - } - result as i32 - }) - }) - .collect() - } - DataType::UInt32 => { - let grouping_id = match grouping_id.as_any().downcast_ref::() { - Some(array) => array, - None => return exec_err!("grouping function requires UInt32 grouping_id array"), - }; - grouping_id - .iter() - .map(|grouping_id| { - grouping_id.map(|grouping_id| { - let mut result = 0u32; - for (i, index) in indices.iter().enumerate() { - if let Some(index) = index { - let bit = (grouping_id >> index) & 1; - result |= bit << i; - } - } - result as i32 - }) - }) - .collect() - } - DataType::UInt64 => { - let grouping_id = match grouping_id.as_any().downcast_ref::() { - Some(array) => array, - None => return exec_err!("grouping function requires UInt64 grouping_id array"), - }; - grouping_id - .iter() - .map(|grouping_id| { - grouping_id.map(|grouping_id| { - let mut result = 0u64; - for (i, index) in indices.iter().enumerate() { - if let Some(index) = index { - let bit = (grouping_id >> index) & 1; - result |= bit << i; - } - } - result as i32 - }) - }) - .collect() - } + DataType::UInt8 => grouping_id!(grouping_id, indices, u8, UInt8Array), + DataType::UInt16 => grouping_id!(grouping_id, indices, u16, UInt16Array), + DataType::UInt32 => grouping_id!(grouping_id, indices, u32, UInt32Array), + DataType::UInt64 => grouping_id!(grouping_id, indices, u64, UInt64Array), _ => { return exec_err!( "grouping function requires UInt8/16/32/64 for grouping_id, got {}", diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 60358d20e2a1..6c98384c9345 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,6 +45,8 @@ arrow = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } +datafusion-functions-nested = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index f8a818563609..dfea88224914 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -18,13 +18,11 @@ //! Analyzed rule to replace TableScan references //! such as DataFrames and Views and inlines the LogicalPlan. -use std::cmp::Ordering; use std::collections::HashMap; use std::sync::Arc; use crate::analyzer::AnalyzerRule; -use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ @@ -33,11 +31,10 @@ use datafusion_common::{ use datafusion_expr::expr::{AggregateFunction, Alias}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::utils::grouping_set_to_exprlist; -use datafusion_expr::{ - bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, - Expr, Projection, -}; +use datafusion_expr::{Aggregate, Expr, Projection}; use itertools::Itertools; +use datafusion_functions::core::grouping; +use datafusion_functions_nested::make_array::make_array_udf; /// Replaces grouping aggregation function with value derived from internal grouping id #[derive(Default, Debug)] @@ -193,18 +190,6 @@ fn grouping_function_on_id( } let group_by_expr_count = group_by_expr.len(); - let literal = |value: usize| { - if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8)) - } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16)) - } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32)) - } else { - Expr::Literal(ScalarValue::from(value as u64)) - } - }; - let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); // The grouping call is exactly our internal grouping id if args.len() == group_by_expr_count @@ -214,35 +199,17 @@ fn grouping_function_on_id( .enumerate() .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) { - return Ok(cast(grouping_id_column, DataType::Int32)); + return Ok(grouping().call(vec![grouping_id_column])); } - args.iter() - .rev() - .enumerate() - .map(|(arg_idx, expr)| { - group_by_expr.get(expr).map(|group_by_idx| { - let group_by_bit = - bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); - match group_by_idx.cmp(&arg_idx) { - Ordering::Less => { - bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) - } - Ordering::Greater => { - bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) - } - Ordering::Equal => group_by_bit, - } - }) - }) - .collect::>>() - .and_then(|bit_exprs| { - bit_exprs - .into_iter() - .reduce(bitwise_or) - .map(|expr| cast(expr, DataType::Int32)) - }) - .ok_or_else(|| { - internal_datafusion_err!("Grouping sets should contains at least one element") + args.iter().map(|expr| { + group_by_expr.get(expr).map(|group_by_idx| { + Expr::Literal(ScalarValue::from(*group_by_idx as i32)) }) + }).collect::>>() + .and_then(|exprs| { + Some(grouping().call(vec![grouping_id_column, make_array_udf().call(exprs)])) + }).ok_or_else(|| { + internal_datafusion_err!("Grouping sets should contains at least one element") + }) } From 7bbc6168dd77522a471d47d28384da5cb0c981c9 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 19 May 2025 17:42:42 +0800 Subject: [PATCH 03/18] add test --- datafusion/functions/src/core/grouping.rs | 178 ++++++++++++++++++++++ 1 file changed, 178 insertions(+) diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index 4ebd940d3e16..36af08809787 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -216,3 +216,181 @@ impl ScalarUDFImpl for GroupingFunc { } } +#[cfg(test)] +mod tests { + use super::*; + use arrow::{array::{Int32Array, ListArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array}, datatypes::Int32Type}; + use datafusion_common::Result; + + #[test] + fn test_grouping_uint8() -> Result<()> { + let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ]; + + let func = GroupingFunc::new(); + let return_type = DataType::Int32; + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + return_type: &return_type, + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.values().to_vec(), vec![1, 2, 3, 3]); + Ok(()) + } + + #[test] + fn test_grouping_uint16() -> Result<()> { + let grouping_id = UInt16Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ]; + + let func = GroupingFunc::new(); + let return_type = DataType::Int32; + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + return_type: &return_type, + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.values().to_vec(), vec![1, 2, 3, 3]); + Ok(()) + } + + #[test] + fn test_grouping_uint32() -> Result<()> { + let grouping_id = UInt32Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ]; + + let func = GroupingFunc::new(); + let return_type = DataType::Int32; + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + return_type: &return_type, + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.values().to_vec(), vec![1, 2, 3, 3]); + Ok(()) + } + + #[test] + fn test_grouping_uint64() -> Result<()> { + let grouping_id = UInt64Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ]; + + let func = GroupingFunc::new(); + let return_type = DataType::Int32; + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + return_type: &return_type, + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.values().to_vec(), vec![1, 2, 3, 3]); + Ok(()) + } + + #[test] + fn test_grouping_with_null_indices() -> Result<()> { + let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), None])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ]; + + let func = GroupingFunc::new(); + let return_type = DataType::Int32; + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + return_type: &return_type, + }); + assert!(result.is_err()); + Ok(()) + } + + #[test] + fn test_grouping_with_invalid_args() -> Result<()> { + let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0)])]; + + // Test with too many arguments + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), + ]; + + let func = GroupingFunc::new(); + let return_type = DataType::Int32; + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + return_type: &return_type, + }); + assert!(result.is_err()); + + // Test with invalid array type + let args = vec![ + ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(1)]))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ]; + + let func = GroupingFunc::new(); + let return_type = DataType::Int32; + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 1, + return_type: &return_type, + }); + assert!(result.is_err()); + Ok(()) + } +} + From fe6b56d3fc5a215eed19e863da91b280f8ea84f4 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Tue, 20 May 2025 22:29:52 +0800 Subject: [PATCH 04/18] WIP grouping optimization --- datafusion/functions-aggregate/src/grouping.rs | 16 ++++++++-------- datafusion/functions/src/core/grouping.rs | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 445774ff11e7..1f869bdcf4b6 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -35,25 +35,25 @@ make_udaf_expr_and_func!( Grouping, grouping, expression, - "Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.", + "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn)", grouping_udaf ); #[user_doc( doc_section(label = "General Functions"), - description = "Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set.", + description = "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn)", syntax_example = "grouping(expression)", sql_example = r#"```sql > SELECT column_name, GROUPING(column_name) AS group_column FROM table_name GROUP BY GROUPING SETS ((column_name), ()); -+-------------+-------------+ ++-------------+--------------+ | column_name | group_column | -+-------------+-------------+ -| value1 | 0 | -| value2 | 0 | -| NULL | 1 | -+-------------+-------------+ ++-------------+--------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+--------------+ ```"#, argument( name = "expression", diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index 36af08809787..bacff72e91c7 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -52,7 +52,7 @@ macro_rules! grouping_id { #[user_doc( doc_section(label = "Other Functions"), - description = "[Developer API] Returns 1 if the specified column is not included in the grouping set, 0 if it is included.", + description = "Developer API: Returns the level of grouping, equals to (((grouping_id >> array[0]) & 1) << (n-1)) + (((grouping_id >> array[1]) & 1) << (n-2)) + ... + (((grouping_id >> array[n-1]) & 1) << 0). Returns grouping_id if indices is not provided.", syntax_example = "grouping(grouping_id[, indices])", sql_example = r#"```sql > SELECT grouping(__grouping_id, make_array(0)) FROM table GROUP BY GROUPING SETS ((a), (b)); From 501bb0de719b29f110871fbad586b3fe57bb770d Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Tue, 20 May 2025 23:24:00 +0800 Subject: [PATCH 05/18] update grouping --- datafusion/functions/src/core/grouping.rs | 115 +++++++++------------- 1 file changed, 47 insertions(+), 68 deletions(-) diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index bacff72e91c7..1cb34951c548 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, Int32Array, UInt8Array, UInt16Array, UInt32Array, UInt64Array}; -use arrow::datatypes::DataType; -use datafusion_common::{exec_err, Result, ScalarValue}; +use arrow::array::{Array, ArrayRef, Int32Array, PrimitiveArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array}; +use arrow::compute::cast; +use arrow::datatypes::{DataType, Int32Type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::function::Hint; use datafusion_expr::{ ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -26,6 +28,8 @@ use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; +use crate::utils::make_scalar_function; + macro_rules! grouping_id { ($grouping_id:expr, $indices:expr, $type:ty, $array_type:ty) => {{ let grouping_id = match $grouping_id.as_any().downcast_ref::<$array_type>() { @@ -113,70 +117,7 @@ impl ScalarUDFImpl for GroupingFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = args.args; - if args.len() != 2 && args.len() != 1 { - return exec_err!( - "grouping function requires 1 or 2 arguments, got {}", - args.len() - ); - } - - let grouping_id = match &args[0] { - ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(_) => { - return exec_err!("grouping function requires array input for grouping_id") - } - }; - - if args.len() == 1 { - return args[0].cast_to(grouping_id.data_type(), None); - } - - let indices = match &args[1] { - ColumnarValue::Scalar(scalar) => { - match scalar { - ScalarValue::List(array) => { - // Get the values array from the list array - let Some(values) = array.values().as_any().downcast_ref::() else { - return exec_err!("grouping function requires Int32 indices array") - }; - values - } - ScalarValue::FixedSizeList(array) => { - // Get the values array from the list array - let Some(values) = array.values().as_any().downcast_ref::() else { - return exec_err!("grouping function requires Int32 indices array") - }; - values - } - _ => { - return exec_err!("grouping function requires list of Int32 indices") - } - } - } - ColumnarValue::Array(_) => { - return exec_err!("grouping function requires scalar input for indices") - } - }; - - if indices.null_count() > 0 { - return exec_err!("grouping function requires non-null indices array"); - } - - let result: Int32Array = match grouping_id.data_type() { - DataType::UInt8 => grouping_id!(grouping_id, indices, u8, UInt8Array), - DataType::UInt16 => grouping_id!(grouping_id, indices, u16, UInt16Array), - DataType::UInt32 => grouping_id!(grouping_id, indices, u32, UInt32Array), - DataType::UInt64 => grouping_id!(grouping_id, indices, u64, UInt64Array), - _ => { - return exec_err!( - "grouping function requires UInt8/16/32/64 for grouping_id, got {}", - grouping_id.data_type() - ) - } - }; - - Ok(ColumnarValue::Array(Arc::new(result))) + make_scalar_function(grouping_inner, vec![Hint::Pad, Hint::AcceptsSingular])(&args.args) } fn short_circuits(&self) -> bool { @@ -200,8 +141,14 @@ impl ScalarUDFImpl for GroupingFunc { ) } } + let DataType::List(field) = &arg_types[1] else { + return exec_err!( + "grouping function requires Int32 for second argument, got {}", + arg_types[1] + ); + }; - if arg_types[1] != DataType::Int32 { + if field.data_type() != &DataType::Int32 { return exec_err!( "grouping function requires Int32 for second argument, got {}", arg_types[1] @@ -216,6 +163,38 @@ impl ScalarUDFImpl for GroupingFunc { } } +fn grouping_inner(args: &[ArrayRef]) -> Result { + if args.len() != 2 && args.len() != 1 { + return exec_err!( + "grouping function requires 1 or 2 arguments, got {}", + args.len() + ); + } + + if args.len() == 1 { + return cast(&args[0], &DataType::Int32).map_err(|e| e.into()); + } + + let grouping_id = &args[0]; + let indices = &args[1]; + let indices = indices.as_any().downcast_ref::>().unwrap(); + + let result: Int32Array = match grouping_id.data_type() { + DataType::UInt8 => grouping_id!(grouping_id, indices, u8, UInt8Array), + DataType::UInt16 => grouping_id!(grouping_id, indices, u16, UInt16Array), + DataType::UInt32 => grouping_id!(grouping_id, indices, u32, UInt32Array), + DataType::UInt64 => grouping_id!(grouping_id, indices, u64, UInt64Array), + _ => { + return exec_err!( + "grouping function requires UInt8/16/32/64 for grouping_id, got {}", + grouping_id.data_type() + ) + } + }; + + Ok(Arc::new(result)) +} + #[cfg(test)] mod tests { use super::*; From 84820460244bb5e970b45d75dbf578ec40705393 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 21 May 2025 21:25:01 +0800 Subject: [PATCH 06/18] fix bug --- datafusion/functions/src/core/grouping.rs | 59 +++++++++++++---------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index 1cb34951c548..a953320aea52 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, Int32Array, PrimitiveArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array}; +use arrow::array::{Array, ArrayRef, AsArray, Int32Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array}; use arrow::compute::cast; -use arrow::datatypes::{DataType, Int32Type}; +use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::function::Hint; use datafusion_expr::{ ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -38,13 +37,21 @@ macro_rules! grouping_id { }; grouping_id .iter() - .map(|grouping_id| { + .zip($indices.iter()) + .map(|(grouping_id, indices)| { grouping_id.map(|grouping_id| { let mut result = 0 as $type; - for (i, index) in $indices.iter().enumerate() { - if let Some(index) = index { - let bit = (grouping_id >> index) & 1; - result |= bit << i; + match indices { + Some(indices) => { + for index in indices.as_primitive::().iter() { + if let Some(index) = index { + let bit = (grouping_id >> index) & 1; + result = (result << 1) | bit; + } + } + } + None => { + result = grouping_id; } } result as i32 @@ -117,7 +124,7 @@ impl ScalarUDFImpl for GroupingFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(grouping_inner, vec![Hint::Pad, Hint::AcceptsSingular])(&args.args) + make_scalar_function(grouping_inner, vec![])(&args.args) } fn short_circuits(&self) -> bool { @@ -125,37 +132,39 @@ impl ScalarUDFImpl for GroupingFunc { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { + if arg_types.len() != 2 && arg_types.len() != 1 { return exec_err!( - "grouping function requires exactly 2 arguments, got {}", + "grouping function requires 1 or 2 arguments, got {}", arg_types.len() ); } - match arg_types[0] { - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {} - _ => { - return exec_err!( - "grouping function requires UInt8/16/32/64 for first argument, got {}", - arg_types[0] - ) - } + if !arg_types[0].is_unsigned_integer() { + return exec_err!( + "grouping function requires unsigned integer for first argument, got {}", + arg_types[0] + ) + } + + if arg_types.len() == 1 { + return Ok(vec![arg_types[0].clone()]); } + let DataType::List(field) = &arg_types[1] else { return exec_err!( - "grouping function requires Int32 for second argument, got {}", + "grouping function requires list for second argument, got {}", arg_types[1] ); }; - if field.data_type() != &DataType::Int32 { + if !field.data_type().is_integer() { return exec_err!( - "grouping function requires Int32 for second argument, got {}", + "grouping function requires list of integers for second argument, got {}", arg_types[1] ); } - Ok(arg_types.to_vec()) + Ok(vec![arg_types[0].clone(), DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))),]) } fn documentation(&self) -> Option<&Documentation> { @@ -177,7 +186,7 @@ fn grouping_inner(args: &[ArrayRef]) -> Result { let grouping_id = &args[0]; let indices = &args[1]; - let indices = indices.as_any().downcast_ref::>().unwrap(); + let indices = indices.as_list::(); let result: Int32Array = match grouping_id.data_type() { DataType::UInt8 => grouping_id!(grouping_id, indices, u8, UInt8Array), @@ -199,7 +208,7 @@ fn grouping_inner(args: &[ArrayRef]) -> Result { mod tests { use super::*; use arrow::{array::{Int32Array, ListArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array}, datatypes::Int32Type}; - use datafusion_common::Result; + use datafusion_common::{Result, ScalarValue}; #[test] fn test_grouping_uint8() -> Result<()> { From 7b79382163ba1a6a2fbc270ec13b531bda85b9ff Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 21 May 2025 22:53:57 +0800 Subject: [PATCH 07/18] merge master --- datafusion/functions/src/core/grouping.rs | 73 ++++++++++++++++++----- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index a953320aea52..f4a0d423702b 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -20,7 +20,7 @@ use arrow::compute::cast; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -119,10 +119,6 @@ impl ScalarUDFImpl for GroupingFunc { Ok(DataType::Int32) } - fn return_type_from_args(&self, _args: ReturnTypeArgs) -> Result { - Ok(ReturnInfo::new(DataType::Int32, false)) - } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(grouping_inner, vec![])(&args.args) } @@ -217,15 +213,22 @@ mod tests { let args = vec![ ColumnarValue::Array(Arc::new(grouping_id)), - ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices.clone())))), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let func = GroupingFunc::new(); let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, - return_type: &return_type, + arg_fields: arg_fields_owned, + return_field: &Field::new("f", DataType::Int32, true), })?; let result = match result { @@ -248,12 +251,19 @@ mod tests { ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let func = GroupingFunc::new(); let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, - return_type: &return_type, + arg_fields: arg_fields_owned, + return_field: &Field::new("f", DataType::Int32, true), })?; let result = match result { @@ -276,12 +286,19 @@ mod tests { ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let func = GroupingFunc::new(); let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, - return_type: &return_type, + arg_fields: arg_fields_owned, + return_field: &Field::new("f", DataType::Int32, true), })?; let result = match result { @@ -304,12 +321,19 @@ mod tests { ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let func = GroupingFunc::new(); let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, + arg_fields: arg_fields_owned, number_rows: 4, - return_type: &return_type, + return_field: &Field::new("f", DataType::Int32, true), })?; let result = match result { @@ -332,12 +356,19 @@ mod tests { ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let func = GroupingFunc::new(); let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, + arg_fields: arg_fields_owned, number_rows: 4, - return_type: &return_type, + return_field: &Field::new("f", DataType::Int32, true), }); assert!(result.is_err()); Ok(()) @@ -351,16 +382,29 @@ mod tests { // Test with too many arguments let args = vec![ ColumnarValue::Array(Arc::new(grouping_id)), - ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices.clone())))), ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let func = GroupingFunc::new(); let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, + arg_fields: arg_fields_owned, number_rows: 4, - return_type: &return_type, + return_field: &Field::new("f", DataType::Int32, true), }); assert!(result.is_err()); @@ -374,8 +418,9 @@ mod tests { let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, + arg_fields: arg_fields_owned, number_rows: 1, - return_type: &return_type, + return_field: &Field::new("f", DataType::Int32, true), }); assert!(result.is_err()); Ok(()) From c723433440bd882194850d3104c5a1baba9cbf3c Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 21 May 2025 22:59:27 +0800 Subject: [PATCH 08/18] update --- datafusion/functions/src/core/grouping.rs | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index f4a0d423702b..459928d0b4db 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -223,11 +223,10 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, - arg_fields: arg_fields_owned, + arg_fields: arg_fields_owned.iter().collect::>(), return_field: &Field::new("f", DataType::Int32, true), })?; @@ -258,11 +257,10 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, - arg_fields: arg_fields_owned, + arg_fields: arg_fields_owned.iter().collect::>(), return_field: &Field::new("f", DataType::Int32, true), })?; @@ -293,11 +291,10 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, - arg_fields: arg_fields_owned, + arg_fields: arg_fields_owned.iter().collect::>(), return_field: &Field::new("f", DataType::Int32, true), })?; @@ -328,10 +325,9 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, - arg_fields: arg_fields_owned, + arg_fields: arg_fields_owned.iter().collect::>(), number_rows: 4, return_field: &Field::new("f", DataType::Int32, true), })?; @@ -363,10 +359,9 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, - arg_fields: arg_fields_owned, + arg_fields: arg_fields_owned.iter().collect::>(), number_rows: 4, return_field: &Field::new("f", DataType::Int32, true), }); @@ -399,10 +394,9 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, - arg_fields: arg_fields_owned, + arg_fields: arg_fields_owned.iter().collect::>(), number_rows: 4, return_field: &Field::new("f", DataType::Int32, true), }); @@ -415,10 +409,9 @@ mod tests { ]; let func = GroupingFunc::new(); - let return_type = DataType::Int32; let result = func.invoke_with_args(ScalarFunctionArgs { args, - arg_fields: arg_fields_owned, + arg_fields: arg_fields_owned.iter().collect::>(), number_rows: 1, return_field: &Field::new("f", DataType::Int32, true), }); From 516d255239b63c4c94c10a78114c821a1b59f05c Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 21 May 2025 23:09:15 +0800 Subject: [PATCH 09/18] fix test --- datafusion/functions/src/core/grouping.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index 459928d0b4db..72466a544666 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -236,7 +236,7 @@ mod tests { }; let result = result.as_any().downcast_ref::().unwrap(); - assert_eq!(result.values().to_vec(), vec![1, 2, 3, 3]); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); Ok(()) } @@ -270,7 +270,7 @@ mod tests { }; let result = result.as_any().downcast_ref::().unwrap(); - assert_eq!(result.values().to_vec(), vec![1, 2, 3, 3]); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); Ok(()) } @@ -304,7 +304,7 @@ mod tests { }; let result = result.as_any().downcast_ref::().unwrap(); - assert_eq!(result.values().to_vec(), vec![1, 2, 3, 3]); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); Ok(()) } @@ -338,7 +338,7 @@ mod tests { }; let result = result.as_any().downcast_ref::().unwrap(); - assert_eq!(result.values().to_vec(), vec![1, 2, 3, 3]); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); Ok(()) } @@ -381,12 +381,6 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ]; - let arg_fields_owned = args - .iter() - .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) - .collect::>(); - let arg_fields_owned = args .iter() .enumerate() From 0c7184aa35e5fc0b3a492bd01a328e64f1c9916a Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 21 May 2025 23:12:36 +0800 Subject: [PATCH 10/18] update test --- datafusion/functions/src/core/grouping.rs | 27 ----------------------- 1 file changed, 27 deletions(-) diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index 72466a544666..def157b645d5 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -342,33 +342,6 @@ mod tests { Ok(()) } - #[test] - fn test_grouping_with_null_indices() -> Result<()> { - let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); - let indices = vec![Some(vec![Some(0), None])]; - - let args = vec![ - ColumnarValue::Array(Arc::new(grouping_id)), - ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), - ]; - - let arg_fields_owned = args - .iter() - .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) - .collect::>(); - - let func = GroupingFunc::new(); - let result = func.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields: arg_fields_owned.iter().collect::>(), - number_rows: 4, - return_field: &Field::new("f", DataType::Int32, true), - }); - assert!(result.is_err()); - Ok(()) - } - #[test] fn test_grouping_with_invalid_args() -> Result<()> { let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); From a9d4ffc9596f4218fdae5962d84b18fa48a7b399 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 21 May 2025 23:28:17 +0800 Subject: [PATCH 11/18] fmt --- Cargo.lock | 1 - datafusion/functions/src/core/grouping.rs | 85 ++++++++++++------- datafusion/functions/src/core/mod.rs | 2 +- datafusion/optimizer/Cargo.toml | 1 - .../src/analyzer/resolve_grouping_function.rs | 30 ++++--- 5 files changed, 73 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 37b0a4c93946..52010bdfbbd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2391,7 +2391,6 @@ dependencies = [ "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", - "datafusion-functions-nested", "datafusion-functions-window", "datafusion-functions-window-common", "datafusion-physical-expr", diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index def157b645d5..1ed2adb943a6 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, AsArray, Int32Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array}; +use arrow::array::{ + Array, ArrayRef, AsArray, Int32Array, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; use arrow::compute::cast; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarFunctionArgs, - ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -33,7 +36,12 @@ macro_rules! grouping_id { ($grouping_id:expr, $indices:expr, $type:ty, $array_type:ty) => {{ let grouping_id = match $grouping_id.as_any().downcast_ref::<$array_type>() { Some(array) => array, - None => return exec_err!("grouping function requires {} grouping_id array", stringify!($type)), + None => { + return exec_err!( + "grouping function requires {} grouping_id array", + stringify!($type) + ) + } }; grouping_id .iter() @@ -139,7 +147,7 @@ impl ScalarUDFImpl for GroupingFunc { return exec_err!( "grouping function requires unsigned integer for first argument, got {}", arg_types[0] - ) + ); } if arg_types.len() == 1 { @@ -160,7 +168,10 @@ impl ScalarUDFImpl for GroupingFunc { ); } - Ok(vec![arg_types[0].clone(), DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))),]) + Ok(vec![ + arg_types[0].clone(), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))), + ]) } fn documentation(&self) -> Option<&Documentation> { @@ -203,17 +214,24 @@ fn grouping_inner(args: &[ArrayRef]) -> Result { #[cfg(test)] mod tests { use super::*; - use arrow::{array::{Int32Array, ListArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array}, datatypes::Int32Type}; + use arrow::{ + array::{ + Int32Array, ListArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::Int32Type, + }; use datafusion_common::{Result, ScalarValue}; #[test] fn test_grouping_uint8() -> Result<()> { let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); let indices = vec![Some(vec![Some(0), Some(1)])]; - + let args = vec![ ColumnarValue::Array(Arc::new(grouping_id)), - ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices.clone())))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices.clone()), + ))), ]; let arg_fields_owned = args @@ -223,13 +241,13 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let result = func.invoke_with_args(ScalarFunctionArgs { + let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, arg_fields: arg_fields_owned.iter().collect::>(), return_field: &Field::new("f", DataType::Int32, true), })?; - + let result = match result { ColumnarValue::Array(array) => array, _ => panic!("Expected array result"), @@ -244,10 +262,12 @@ mod tests { fn test_grouping_uint16() -> Result<()> { let grouping_id = UInt16Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); let indices = vec![Some(vec![Some(0), Some(1)])]; - + let args = vec![ ColumnarValue::Array(Arc::new(grouping_id)), - ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices), + ))), ]; let arg_fields_owned = args @@ -257,13 +277,13 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let result = func.invoke_with_args(ScalarFunctionArgs { + let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, arg_fields: arg_fields_owned.iter().collect::>(), return_field: &Field::new("f", DataType::Int32, true), })?; - + let result = match result { ColumnarValue::Array(array) => array, _ => panic!("Expected array result"), @@ -278,10 +298,12 @@ mod tests { fn test_grouping_uint32() -> Result<()> { let grouping_id = UInt32Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); let indices = vec![Some(vec![Some(0), Some(1)])]; - + let args = vec![ ColumnarValue::Array(Arc::new(grouping_id)), - ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices), + ))), ]; let arg_fields_owned = args @@ -291,13 +313,13 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let result = func.invoke_with_args(ScalarFunctionArgs { + let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, arg_fields: arg_fields_owned.iter().collect::>(), return_field: &Field::new("f", DataType::Int32, true), })?; - + let result = match result { ColumnarValue::Array(array) => array, _ => panic!("Expected array result"), @@ -312,10 +334,12 @@ mod tests { fn test_grouping_uint64() -> Result<()> { let grouping_id = UInt64Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); let indices = vec![Some(vec![Some(0), Some(1)])]; - + let args = vec![ ColumnarValue::Array(Arc::new(grouping_id)), - ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices), + ))), ]; let arg_fields_owned = args @@ -325,13 +349,13 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let result = func.invoke_with_args(ScalarFunctionArgs { + let result = func.invoke_with_args(ScalarFunctionArgs { args, arg_fields: arg_fields_owned.iter().collect::>(), number_rows: 4, return_field: &Field::new("f", DataType::Int32, true), })?; - + let result = match result { ColumnarValue::Array(array) => array, _ => panic!("Expected array result"), @@ -346,11 +370,13 @@ mod tests { fn test_grouping_with_invalid_args() -> Result<()> { let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); let indices = vec![Some(vec![Some(0)])]; - + // Test with too many arguments let args = vec![ ColumnarValue::Array(Arc::new(grouping_id)), - ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices.clone())))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices.clone()), + ))), ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ]; @@ -361,7 +387,7 @@ mod tests { .collect::>(); let func = GroupingFunc::new(); - let result = func.invoke_with_args(ScalarFunctionArgs { + let result = func.invoke_with_args(ScalarFunctionArgs { args, arg_fields: arg_fields_owned.iter().collect::>(), number_rows: 4, @@ -372,11 +398,13 @@ mod tests { // Test with invalid array type let args = vec![ ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(1)]))), - ColumnarValue::Scalar(ScalarValue::List(Arc::new(ListArray::from_iter_primitive::(indices)))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices), + ))), ]; let func = GroupingFunc::new(); - let result = func.invoke_with_args(ScalarFunctionArgs { + let result = func.invoke_with_args(ScalarFunctionArgs { args, arg_fields: arg_fields_owned.iter().collect::>(), number_rows: 1, @@ -386,4 +414,3 @@ mod tests { Ok(()) } } - diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 473258b2c4c5..0a1e0b6b0717 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -27,6 +27,7 @@ pub mod expr_ext; pub mod getfield; pub mod greatest; mod greatest_least_utils; +pub mod grouping; pub mod least; pub mod named_struct; pub mod nullif; @@ -38,7 +39,6 @@ pub mod r#struct; pub mod union_extract; pub mod union_tag; pub mod version; -pub mod grouping; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 6c98384c9345..cd21a1f7a50b 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -46,7 +46,6 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } -datafusion-functions-nested = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index dfea88224914..c77309cc2f97 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -23,18 +23,17 @@ use std::sync::Arc; use crate::analyzer::AnalyzerRule; +use arrow::array::ListArray; +use arrow::datatypes::Int32Type; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{ - internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, -}; +use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{AggregateFunction, Alias}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::utils::grouping_set_to_exprlist; use datafusion_expr::{Aggregate, Expr, Projection}; -use itertools::Itertools; use datafusion_functions::core::grouping; -use datafusion_functions_nested::make_array::make_array_udf; +use itertools::Itertools; /// Replaces grouping aggregation function with value derived from internal grouping id #[derive(Default, Debug)] @@ -202,14 +201,17 @@ fn grouping_function_on_id( return Ok(grouping().call(vec![grouping_id_column])); } - args.iter().map(|expr| { - group_by_expr.get(expr).map(|group_by_idx| { - Expr::Literal(ScalarValue::from(*group_by_idx as i32)) + let args = args + .iter() + .flat_map(|expr| { + group_by_expr + .get(expr) + .map(|group_by_idx| Some(*group_by_idx as i32)) }) - }).collect::>>() - .and_then(|exprs| { - Some(grouping().call(vec![grouping_id_column, make_array_udf().call(exprs)])) - }).ok_or_else(|| { - internal_datafusion_err!("Grouping sets should contains at least one element") - }) + .collect::>(); + + let indices = Expr::Literal(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(args)]), + ))); + Ok(grouping().call(vec![grouping_id_column, indices])) } From ee4cf29bb378ac05c0f3d9c7eeb323ac96f96a26 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Thu, 22 May 2025 08:47:53 +0800 Subject: [PATCH 12/18] unparse --- datafusion/sql/Cargo.toml | 2 +- datafusion/sql/src/unparser/utils.rs | 32 +++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index b778db46769d..898bcf50f976 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -48,6 +48,7 @@ arrow = { workspace = true } bigdecimal = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate = { workspace = true } indexmap = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } @@ -57,7 +58,6 @@ sqlparser = { workspace = true } [dev-dependencies] ctor = { workspace = true } datafusion-functions = { workspace = true, default-features = true } -datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } datafusion-functions-window = { workspace = true } env_logger = { workspace = true } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 37f0a7797200..2349e276b985 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -21,16 +21,18 @@ use super::{ dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser, }; +use arrow::array::{Array, Int32Array}; use datafusion_common::{ internal_err, tree_node::{Transformed, TransformedResult, TreeNode}, Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, + expr::{self, AggregateFunction}, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Projection, SortExpr, Unnest, Window, }; +use datafusion_functions_aggregate::grouping::grouping_udaf; use indexmap::IndexSet; use sqlparser::ast; use sqlparser::tokenizer::Span; @@ -195,6 +197,34 @@ pub(crate) fn unproject_agg_exprs( agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result { + // replace grouping function + let expr = expr.transform(|sub_expr| { + match sub_expr { + Expr::ScalarFunction(grouping) if grouping.name() == "grouping" => { + if grouping.args.len() != 1 && grouping.args.len() != 2 { + return internal_err!("Grouping function must have one or two arguments"); + } + let grouping_expr = grouping_set_to_exprlist(&agg.group_expr)?; + let args = if grouping.args.len() == 1 { + agg.group_expr.clone() + } else { + if let Expr::Literal(ScalarValue::List(list)) = &grouping.args[1] { + if list.len() != 1 { + return internal_err!("The second argument of grouping function must be a list with exactly one element"); + } + let values = list.value(0).as_any().downcast_ref::().unwrap().values().to_vec(); + values.iter().map(|i: &i32| grouping_expr[*i as usize].clone()).collect() + } else { + return internal_err!("The second argument of grouping function must be a list"); + } + }; + return Ok(Transformed::yes(Expr::AggregateFunction(AggregateFunction::new_udf( + grouping_udaf(), args, false, None, None, None)))); + } + _ => Ok(Transformed::no(sub_expr)) + } + }) + .map(|e| e.data)?; expr.transform(|sub_expr| { if let Expr::Column(c) = sub_expr { if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { From 208f6a607c3d663e2782671077a7d90a94983835 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Fri, 23 May 2025 14:50:25 +0800 Subject: [PATCH 13/18] fix bug --- datafusion/sql/src/unparser/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 2349e276b985..601b2d42b678 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -206,7 +206,7 @@ pub(crate) fn unproject_agg_exprs( } let grouping_expr = grouping_set_to_exprlist(&agg.group_expr)?; let args = if grouping.args.len() == 1 { - agg.group_expr.clone() + grouping_expr.iter().map(|e| (*e).clone()).collect() } else { if let Expr::Literal(ScalarValue::List(list)) = &grouping.args[1] { if list.len() != 1 { From fc3d2fa97ad9a34059a8bca43a562b3b91f87c84 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Fri, 23 May 2025 15:20:28 +0800 Subject: [PATCH 14/18] update --- datafusion/core/tests/sql/sql_api.rs | 34 -------------- datafusion/sql/tests/cases/plan_to_sql.rs | 57 +++++++++++++++++++++-- 2 files changed, 52 insertions(+), 39 deletions(-) diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index 999faae850fb..ec086bcc50c7 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -17,7 +17,6 @@ use datafusion::prelude::*; -use datafusion_sql::unparser::plan_to_sql; use tempfile::TempDir; #[tokio::test] @@ -207,36 +206,3 @@ async fn ddl_can_not_be_planned_by_session_state() { "This feature is not implemented: Unsupported logical plan: DropTable" ); } - -#[tokio::test] -async fn unparse_group_by_with_filter() { - let ctx = SessionContext::new(); - let df = ctx.sql("CREATE TABLE test (c1 VARCHAR,c2 VARCHAR,c3 INT) as values ('a','A',1), ('b','B',2);").await.unwrap(); - let _ = df.collect().await.unwrap(); - - ctx.sql("set datafusion.sql_parser.dialect = 'Postgres';").await.unwrap(); - - let df = ctx.sql(r#"select - c1, - c2, - CASE WHEN grouping(c1) = 1 THEN sum(c3) filter (where c2 = 'A') ELSE NULL END as gx, - grouping(c1) as g0, - grouping(c2) as g1, - grouping(c1, c2) as g2, - grouping(c2, c1) as g3 -from - test -group by - grouping sets ( - (c1, c2), - (c1), - (c2), - () - ) -order by - c1, c2, g0, g1, g2, g3;"#).await.unwrap(); - let plan = df.into_optimized_plan().unwrap(); - let sql = plan_to_sql(&plan).unwrap().to_string(); - println!("!!!{}", sql); - assert_eq!(sql, "SELECT c1,c2,c3 FROM test GROUP BY c1,c2,c3 HAVING c3 > 1;"); -} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index b541018436d2..932530ae8b92 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,19 +15,19 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::ListArray; +use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use datafusion_common::{ - assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, - TableReference, + assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference }; use datafusion_expr::test::function_stub::{ count_udaf, max_udaf, min_udaf, sum, sum_udaf, }; use datafusion_expr::{ - cast, col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan, - LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + cast, col, lit, table_scan, wildcard, Aggregate, EmptyRelation, Expr, Extension, LogicalPlan, LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore }; use datafusion_functions::unicode; +use datafusion_functions::core::grouping; use datafusion_functions_aggregate::grouping::grouping_udaf; use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_nested::map::map_udf; @@ -2516,3 +2516,50 @@ fn test_not_ilike_filter_with_escape() { @"SELECT person.first_name FROM person WHERE person.first_name NOT ILIKE 'A!_%' ESCAPE '!'" ); } + +#[test] +fn test_grouping() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + Field::new("c3", DataType::Int32, false), + ]); + let table_scan = table_scan(Some("test"), &schema, Some(vec![0, 1, 2]))?.build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![ + Expr::GroupingSet(datafusion_expr::GroupingSet::GroupingSets(vec![ + vec![col("c1"), col("c2")], + vec![col("c1")], + vec![col("c2")], + vec![], + ])) + ], + vec![ + sum(col("c3")) + ])? + .build()?; + + let group1 = ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![Some(0)])]), + )); + let group2 = ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![Some(1)])]), + )); + let group3 = ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![Some(0), Some(1)])]), + )); + let project = LogicalPlanBuilder::from(plan) + .project(vec![ + grouping().call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group1)]).alias("grouping(test.c1)"), + grouping().call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group2)]).alias("grouping(test.c2)"), + grouping().call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group3)]).alias("grouping(test.c1,test.c2)"), + ])?.build()?; + let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); + let sql = unparser.plan_to_sql(&project)?; + assert_snapshot!( + sql, + @r#"SELECT grouping("test"."c1") AS "grouping(test.c1)", grouping("test"."c2") AS "grouping(test.c2)", grouping("test"."c1", "test"."c2") AS "grouping(test.c1,test.c2)" FROM "test" GROUP BY GROUPING SETS (("test"."c1", "test"."c2"), ("test"."c1"), ("test"."c2"), ())"# + ); + Ok(()) +} From 12ad681313adb1263f0abaeedd9f556c2b18f284 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Fri, 23 May 2025 15:30:33 +0800 Subject: [PATCH 15/18] clippy --- datafusion/sql/src/unparser/utils.rs | 26 ++++---- datafusion/sql/tests/cases/plan_to_sql.rs | 72 +++++++++++++---------- 2 files changed, 55 insertions(+), 43 deletions(-) diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 601b2d42b678..fb3716b9dc27 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -28,8 +28,10 @@ use datafusion_common::{ Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - expr::{self, AggregateFunction}, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, - LogicalPlanBuilder, Projection, SortExpr, Unnest, Window, + expr::{self, AggregateFunction}, + utils::grouping_set_to_exprlist, + Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Projection, SortExpr, Unnest, + Window, }; use datafusion_functions_aggregate::grouping::grouping_udaf; @@ -207,19 +209,17 @@ pub(crate) fn unproject_agg_exprs( let grouping_expr = grouping_set_to_exprlist(&agg.group_expr)?; let args = if grouping.args.len() == 1 { grouping_expr.iter().map(|e| (*e).clone()).collect() - } else { - if let Expr::Literal(ScalarValue::List(list)) = &grouping.args[1] { - if list.len() != 1 { - return internal_err!("The second argument of grouping function must be a list with exactly one element"); - } - let values = list.value(0).as_any().downcast_ref::().unwrap().values().to_vec(); - values.iter().map(|i: &i32| grouping_expr[*i as usize].clone()).collect() - } else { - return internal_err!("The second argument of grouping function must be a list"); + } else if let Expr::Literal(ScalarValue::List(list)) = &grouping.args[1] { + if list.len() != 1 { + return internal_err!("The second argument of grouping function must be a list with exactly one element"); } + let values = list.value(0).as_any().downcast_ref::().unwrap().values().to_vec(); + values.iter().map(|i: &i32| grouping_expr[*i as usize].clone()).collect() + } else { + return internal_err!("The second argument of grouping function must be a list"); }; - return Ok(Transformed::yes(Expr::AggregateFunction(AggregateFunction::new_udf( - grouping_udaf(), args, false, None, None, None)))); + Ok(Transformed::yes(Expr::AggregateFunction(AggregateFunction::new_udf( + grouping_udaf(), args, false, None, None, None)))) } _ => Ok(Transformed::no(sub_expr)) } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 932530ae8b92..78db923402ea 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -18,16 +18,19 @@ use arrow::array::ListArray; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use datafusion_common::{ - assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference + assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, + TableReference, }; use datafusion_expr::test::function_stub::{ count_udaf, max_udaf, min_udaf, sum, sum_udaf, }; use datafusion_expr::{ - cast, col, lit, table_scan, wildcard, Aggregate, EmptyRelation, Expr, Extension, LogicalPlan, LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore + cast, col, lit, table_scan, wildcard, Aggregate, EmptyRelation, Expr, Extension, + LogicalPlan, LogicalPlanBuilder, Union, UserDefinedLogicalNode, + UserDefinedLogicalNodeCore, }; -use datafusion_functions::unicode; use datafusion_functions::core::grouping; +use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_nested::map::map_udf; @@ -2526,35 +2529,44 @@ fn test_grouping() -> Result<()> { ]); let table_scan = table_scan(Some("test"), &schema, Some(vec![0, 1, 2]))?.build()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate( - vec![ - Expr::GroupingSet(datafusion_expr::GroupingSet::GroupingSets(vec![ - vec![col("c1"), col("c2")], - vec![col("c1")], - vec![col("c2")], - vec![], - ])) - ], - vec![ - sum(col("c3")) - ])? - .build()?; + .aggregate( + vec![Expr::GroupingSet( + datafusion_expr::GroupingSet::GroupingSets(vec![ + vec![col("c1"), col("c2")], + vec![col("c1")], + vec![col("c2")], + vec![], + ]), + )], + vec![sum(col("c3"))], + )? + .build()?; - let group1 = ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![Some(0)])]), - )); - let group2 = ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![Some(1)])]), - )); - let group3 = ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![Some(0), Some(1)])]), - )); + let group1 = + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::( + vec![Some(vec![Some(0)])], + ))); + let group2 = + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::( + vec![Some(vec![Some(1)])], + ))); + let group3 = + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::( + vec![Some(vec![Some(0), Some(1)])], + ))); let project = LogicalPlanBuilder::from(plan) - .project(vec![ - grouping().call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group1)]).alias("grouping(test.c1)"), - grouping().call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group2)]).alias("grouping(test.c2)"), - grouping().call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group3)]).alias("grouping(test.c1,test.c2)"), - ])?.build()?; + .project(vec![ + grouping() + .call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group1)]) + .alias("grouping(test.c1)"), + grouping() + .call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group2)]) + .alias("grouping(test.c2)"), + grouping() + .call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group3)]) + .alias("grouping(test.c1,test.c2)"), + ])? + .build()?; let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); let sql = unparser.plan_to_sql(&project)?; assert_snapshot!( From fe431e8ec9027887f0f323c952b1e775053ecac7 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Fri, 23 May 2025 19:47:40 +0800 Subject: [PATCH 16/18] update doc --- docs/source/user-guide/sql/aggregate_functions.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 774a4fae6bf3..c2d48dd78fdd 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -248,7 +248,7 @@ first_value(expression [ORDER BY expression]) ### `grouping` -Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set. +Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn). ```sql grouping(expression) @@ -264,13 +264,13 @@ grouping(expression) > SELECT column_name, GROUPING(column_name) AS group_column FROM table_name GROUP BY GROUPING SETS ((column_name), ()); -+-------------+-------------+ ++-------------+--------------+ | column_name | group_column | -+-------------+-------------+ -| value1 | 0 | -| value2 | 0 | -| NULL | 1 | -+-------------+-------------+ ++-------------+--------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+--------------+ ``` ### `last_value` From 28103dbf6e1d4b67d3a755b5ab51c5dea57326c4 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Fri, 23 May 2025 20:13:46 +0800 Subject: [PATCH 17/18] update doc --- datafusion/functions-aggregate/src/grouping.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 1f869bdcf4b6..4cbb02c5373f 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -35,13 +35,13 @@ make_udaf_expr_and_func!( Grouping, grouping, expression, - "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn)", + "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn).", grouping_udaf ); #[user_doc( doc_section(label = "General Functions"), - description = "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn)", + description = "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn).", syntax_example = "grouping(expression)", sql_example = r#"```sql > SELECT column_name, GROUPING(column_name) AS group_column From 3373cd0f7e06e8ebe72b2a7cd1273a88de016663 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sun, 15 Jun 2025 14:22:21 +0800 Subject: [PATCH 18/18] update --- datafusion/functions/src/core/grouping.rs | 42 +++++++++++++------ .../src/analyzer/resolve_grouping_function.rs | 21 +++------- datafusion/sql/src/unparser/utils.rs | 2 +- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs index 1ed2adb943a6..6e58997a4e9a 100644 --- a/datafusion/functions/src/core/grouping.rs +++ b/datafusion/functions/src/core/grouping.rs @@ -244,8 +244,11 @@ mod tests { let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, - arg_fields: arg_fields_owned.iter().collect::>(), - return_field: &Field::new("f", DataType::Int32, true), + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), })?; let result = match result { @@ -280,8 +283,11 @@ mod tests { let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, - arg_fields: arg_fields_owned.iter().collect::>(), - return_field: &Field::new("f", DataType::Int32, true), + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), })?; let result = match result { @@ -316,8 +322,11 @@ mod tests { let result = func.invoke_with_args(ScalarFunctionArgs { args, number_rows: 4, - arg_fields: arg_fields_owned.iter().collect::>(), - return_field: &Field::new("f", DataType::Int32, true), + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), })?; let result = match result { @@ -351,9 +360,12 @@ mod tests { let func = GroupingFunc::new(); let result = func.invoke_with_args(ScalarFunctionArgs { args, - arg_fields: arg_fields_owned.iter().collect::>(), + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), number_rows: 4, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), })?; let result = match result { @@ -389,9 +401,12 @@ mod tests { let func = GroupingFunc::new(); let result = func.invoke_with_args(ScalarFunctionArgs { args, - arg_fields: arg_fields_owned.iter().collect::>(), + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), number_rows: 4, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), }); assert!(result.is_err()); @@ -406,9 +421,12 @@ mod tests { let func = GroupingFunc::new(); let result = func.invoke_with_args(ScalarFunctionArgs { args, - arg_fields: arg_fields_owned.iter().collect::>(), + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), number_rows: 1, - return_field: &Field::new("f", DataType::Int32, true), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), }); assert!(result.is_err()); Ok(()) diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 0d510ec49836..46177884a26d 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -189,18 +189,6 @@ fn grouping_function_on_id( } let group_by_expr_count = group_by_expr.len(); - let literal = |value: usize| { - if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8), None) - } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16), None) - } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32), None) - } else { - Expr::Literal(ScalarValue::from(value as u64), None) - } - }; - let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); // The grouping call is exactly our internal grouping id if args.len() == group_by_expr_count @@ -222,8 +210,11 @@ fn grouping_function_on_id( }) .collect::>(); - let indices = Expr::Literal(ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(args)]), - ))); + let indices = Expr::Literal( + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::( + vec![Some(args)], + ))), + None, + ); Ok(grouping().call(vec![grouping_id_column, indices])) } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 4c84f326f826..f3613320c67d 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -209,7 +209,7 @@ pub(crate) fn unproject_agg_exprs( let grouping_expr = grouping_set_to_exprlist(&agg.group_expr)?; let args = if grouping.args.len() == 1 { grouping_expr.iter().map(|e| (*e).clone()).collect() - } else if let Expr::Literal(ScalarValue::List(list)) = &grouping.args[1] { + } else if let Expr::Literal(ScalarValue::List(list), _) = &grouping.args[1] { if list.len() != 1 { return internal_err!("The second argument of grouping function must be a list with exactly one element"); }