diff --git a/Cargo.lock b/Cargo.lock index a90e3d365c06..d016500667db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2396,6 +2396,7 @@ dependencies = [ "ctor", "datafusion-common", "datafusion-expr", + "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-window", "datafusion-functions-window-common", diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 0727cf33036a..e8f931d5a7f8 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -35,25 +35,25 @@ make_udaf_expr_and_func!( Grouping, grouping, expression, - "Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.", + "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn).", grouping_udaf ); #[user_doc( doc_section(label = "General Functions"), - description = "Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set.", + description = "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn).", syntax_example = "grouping(expression)", sql_example = r#"```sql > SELECT column_name, GROUPING(column_name) AS group_column FROM table_name GROUP BY GROUPING SETS ((column_name), ()); -+-------------+-------------+ ++-------------+--------------+ | column_name | group_column | -+-------------+-------------+ -| value1 | 0 | -| value2 | 0 | -| NULL | 1 | -+-------------+-------------+ ++-------------+--------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+--------------+ ```"#, argument( name = "expression", diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs new file mode 100644 index 000000000000..6e58997a4e9a --- /dev/null +++ b/datafusion/functions/src/core/grouping.rs @@ -0,0 +1,434 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, Int32Array, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; +use arrow::compute::cast; +use arrow::datatypes::{DataType, Field, Int32Type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::any::Any; +use std::sync::Arc; + +use crate::utils::make_scalar_function; + +macro_rules! 
grouping_id { + // Packs the grouping-id bits selected by `$indices` into an Int32 array: for every row the listed bit positions of the `$array_type` value are read in list order, the first one ending up as the most significant result bit. A row whose index list is null yields the grouping id itself (cast to i32). + ($grouping_id:expr, $indices:expr, $type:ty, $array_type:ty) => {{ + let grouping_id = match $grouping_id.as_any().downcast_ref::<$array_type>() { + Some(array) => array, + None => { + return exec_err!( + "grouping function requires {} grouping_id array", + stringify!($type) + ) + } + }; + grouping_id + .iter() + .zip($indices.iter()) + .map(|(grouping_id, indices)| { + grouping_id.map(|grouping_id| { + let mut result = 0 as $type; + match indices { + Some(indices) => { + // Null entries are skipped. An index that is negative or not below the bit width of `$type` selects no bit; a plain `>>` with such an amount panics in debug builds and masks the shift amount in release builds. + for index in indices.as_primitive::<Int32Type>().iter().flatten() { + let bit = u32::try_from(index) + .ok() + .and_then(|shift| grouping_id.checked_shr(shift)) + .map_or(0, |shifted| shifted & 1); + result = (result << 1) | bit; + } + } + None => { + result = grouping_id; + } + } + result as i32 + }) + }) + .collect() + }}; +} + +#[user_doc( + doc_section(label = "Other Functions"), + description = "Developer API: Returns the level of grouping, equals to (((grouping_id >> array[0]) & 1) << (n-1)) + (((grouping_id >> array[1]) & 1) << (n-2)) + ... + (((grouping_id >> array[n-1]) & 1) << 0). 
Returns grouping_id if indices is not provided.", + syntax_example = "grouping(grouping_id[, indices])", + sql_example = r#"```sql +> SELECT grouping(__grouping_id, make_array(0)) FROM table GROUP BY GROUPING SETS ((a), (b)); ++----------------+ +| grouping | ++----------------+ +| 1 | +| 0 | ++----------------+ +```"#, + argument( + name = "grouping_id", + description = "The internal grouping ID column (UInt8/16/32/64)" + ), + argument( + name = "indices", + description = "The indices of the column in the grouping set (Int32)" + ) +)] +#[derive(Debug)] +pub struct GroupingFunc { + signature: Signature, +} + +impl Default for GroupingFunc { + fn default() -> Self { + GroupingFunc::new() + } +} + +impl GroupingFunc { + /// Creates the function with a user-defined, immutable signature; the accepted argument types are decided in `coerce_types`. + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for GroupingFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "grouping" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + /// The result is always `Int32`, whatever the argument types. + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(DataType::Int32) + } + + /// Evaluates the call by handing the arguments to `grouping_inner` via `make_scalar_function`. + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + make_scalar_function(grouping_inner, vec![])(&args.args) + } + + fn short_circuits(&self) -> bool { + false + } + + /// Accepts one unsigned-integer argument, optionally followed by a list of integers, and coerces the list to a list of `Int32`. + fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { + if arg_types.len() != 2 && arg_types.len() != 1 { + return exec_err!( + "grouping function requires 1 or 2 arguments, got {}", + arg_types.len() + ); + } + + if !arg_types[0].is_unsigned_integer() { + return exec_err!( + "grouping function requires unsigned integer for first argument, got {}", + arg_types[0] + ); + } + + if arg_types.len() == 1 { + return Ok(vec![arg_types[0].clone()]); + } + + let DataType::List(field) = &arg_types[1] else { + return exec_err!( + "grouping function requires list for second argument, got {}", + arg_types[1] + ); + }; + + if !field.data_type().is_integer() { + return exec_err!( + "grouping 
function requires list of integers for second argument, got {}", + arg_types[1] + ); + } + + Ok(vec![ + arg_types[0].clone(), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))), + ]) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// Evaluates `grouping(grouping_id[, indices])` on arrays. +/// +/// With one argument the grouping id is cast to Int32. With two arguments the second must be a list of Int32 bit positions, and the selected bits are packed by the `grouping_id!` macro. Any other argument count or type is reported as an execution error. +fn grouping_inner(args: &[ArrayRef]) -> Result<ArrayRef> { + if args.len() != 2 && args.len() != 1 { + return exec_err!( + "grouping function requires 1 or 2 arguments, got {}", + args.len() + ); + } + + if args.len() == 1 { + return Ok(cast(&args[0], &DataType::Int32)?); + } + + let grouping_id = &args[0]; + // Evaluation can be reached without `coerce_types` having run (the tests below call `invoke_with_args` directly), so check the list here rather than letting a downcast panic. + let Some(indices) = args[1].as_list_opt::<i32>() else { + return exec_err!( + "grouping function requires list for second argument, got {}", + args[1].data_type() + ); + }; + if indices.value_type() != DataType::Int32 { + return exec_err!( + "grouping function requires list of Int32 for second argument, got {}", + args[1].data_type() + ); + } + + let result: Int32Array = match grouping_id.data_type() { + DataType::UInt8 => grouping_id!(grouping_id, indices, u8, UInt8Array), + DataType::UInt16 => grouping_id!(grouping_id, indices, u16, UInt16Array), + DataType::UInt32 => grouping_id!(grouping_id, indices, u32, UInt32Array), + DataType::UInt64 => grouping_id!(grouping_id, indices, u64, UInt64Array), + _ => { + return exec_err!( + "grouping function requires UInt8/16/32/64 for grouping_id, got {}", + grouping_id.data_type() + ) + } + }; + + Ok(Arc::new(result)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{ + array::{ + Int32Array, ListArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::Int32Type, + }; + use datafusion_common::{Result, ScalarValue}; + + #[test] + fn test_grouping_uint8() -> Result<()> { + let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::<Int32Type, _, _>(indices), + ))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::<Vec<_>>(); + + let func = GroupingFunc::new(); + let result = 
func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::<Vec<_>>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); + Ok(()) + } + + #[test] + fn test_grouping_uint16() -> Result<()> { + let grouping_id = UInt16Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::<Int32Type, _, _>(indices), + ))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::<Vec<_>>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::<Vec<_>>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + // With indices [0, 1], bit 0 of the id becomes the high result bit and bit 1 the low one: 1 -> 2, 2 -> 1, 3 -> 3, 4 -> 0. + let result = result.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); + Ok(()) + } + + #[test] + fn test_grouping_uint32() -> Result<()> { + let grouping_id = UInt32Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::<Int32Type, _, _>(indices), + ))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| 
Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::<Vec<_>>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::<Vec<_>>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); + Ok(()) + } + + #[test] + fn test_grouping_uint64() -> Result<()> { + let grouping_id = UInt64Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::<Int32Type, _, _>(indices), + ))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::<Vec<_>>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::<Vec<_>>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + // Same expectation as the narrower id types: with indices [0, 1], ids 1, 2, 3, 4 map to 2, 1, 3, 0. + let result = result.as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); + Ok(()) + } + + #[test] + fn test_grouping_with_invalid_args() -> Result<()> { + let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0)])]; + + // Test with too many arguments + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + 
ListArray::from_iter_primitive::<Int32Type, _, _>(indices.clone()), + ))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::<Vec<_>>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::<Vec<_>>(), + number_rows: 4, + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + }); + assert!(result.is_err()); + + // Test with invalid array type + let args = vec![ + ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(1)]))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::<Int32Type, _, _>(indices), + ))), + ]; + + // Describe these two arguments instead of reusing the three fields built for the previous call. + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::<Vec<_>>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::<Vec<_>>(), + number_rows: 1, + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + }); + assert!(result.is_err()); + Ok(()) + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index db080cd62847..0a1e0b6b0717 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -27,6 +27,7 @@ pub mod expr_ext; pub mod getfield; pub mod greatest; mod greatest_least_utils; +pub mod grouping; pub mod least; pub mod named_struct; pub mod nullif; @@ -55,6 +56,7 @@ make_udf_function!(least::LeastFunc, least); make_udf_function!(union_extract::UnionExtractFun, union_extract); make_udf_function!(union_tag::UnionTagFunc, union_tag); make_udf_function!(version::VersionFunc, version); +make_udf_function!(grouping::GroupingFunc, grouping); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 60358d20e2a1..cd21a1f7a50b 100644 --- 
a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,6 +45,7 @@ arrow = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index fa7ff1b8b19d..46177884a26d 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -18,25 +18,21 @@ //! Analyzed rule to replace TableScan references //! such as DataFrames and Views and inlines the LogicalPlan. -use std::cmp::Ordering; use std::collections::HashMap; use std::sync::Arc; use crate::analyzer::AnalyzerRule; -use arrow::datatypes::DataType; +use arrow::array::ListArray; +use arrow::datatypes::Int32Type; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{ - internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, -}; +use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{AggregateFunction, Alias}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::utils::grouping_set_to_exprlist; -use datafusion_expr::{ - bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, - Expr, Projection, -}; +use datafusion_expr::{Aggregate, Expr, Projection}; +use datafusion_functions::core::grouping; use itertools::Itertools; /// Replaces grouping aggregation function with value derived from internal grouping id @@ -193,18 +189,6 @@ fn grouping_function_on_id( } let group_by_expr_count = group_by_expr.len(); - let literal = 
|value: usize| { - if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8), None) - } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16), None) - } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32), None) - } else { - Expr::Literal(ScalarValue::from(value as u64), None) - } - }; - let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); // The grouping call is exactly our internal grouping id if args.len() == group_by_expr_count @@ -214,35 +198,23 @@ fn grouping_function_on_id( .enumerate() .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) { - return Ok(cast(grouping_id_column, DataType::Int32)); + return Ok(grouping().call(vec![grouping_id_column])); } - args.iter() - .rev() - .enumerate() - .map(|(arg_idx, expr)| { - group_by_expr.get(expr).map(|group_by_idx| { - let group_by_bit = - bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); - match group_by_idx.cmp(&arg_idx) { - Ordering::Less => { - bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) - } - Ordering::Greater => { - bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) - } - Ordering::Equal => group_by_bit, - } - }) - }) - .collect::>>() - .and_then(|bit_exprs| { - bit_exprs - .into_iter() - .reduce(bitwise_or) - .map(|expr| cast(expr, DataType::Int32)) - }) - .ok_or_else(|| { - internal_datafusion_err!("Grouping sets should contains at least one element") + let args = args + .iter() + .flat_map(|expr| { + group_by_expr + .get(expr) + .map(|group_by_idx| Some(*group_by_idx as i32)) }) + .collect::>(); + + let indices = Expr::Literal( + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::( + vec![Some(args)], + ))), + None, + ); + Ok(grouping().call(vec![grouping_id_column, indices])) } diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index b778db46769d..898bcf50f976 100644 --- a/datafusion/sql/Cargo.toml +++ 
b/datafusion/sql/Cargo.toml @@ -48,6 +48,7 @@ arrow = { workspace = true } bigdecimal = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate = { workspace = true } indexmap = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } @@ -57,7 +58,6 @@ sqlparser = { workspace = true } [dev-dependencies] ctor = { workspace = true } datafusion-functions = { workspace = true, default-features = true } -datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } datafusion-functions-window = { workspace = true } env_logger = { workspace = true } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 89fa392c183f..f3613320c67d 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -21,16 +21,20 @@ use super::{ dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser, }; +use arrow::array::{Array, Int32Array}; use datafusion_common::{ internal_err, tree_node::{Transformed, TransformedResult, TreeNode}, Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, - LogicalPlanBuilder, Projection, SortExpr, Unnest, Window, + expr::{self, AggregateFunction}, + utils::grouping_set_to_exprlist, + Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Projection, SortExpr, Unnest, + Window, }; +use datafusion_functions_aggregate::grouping::grouping_udaf; use indexmap::IndexSet; use sqlparser::ast; use sqlparser::tokenizer::Span; @@ -195,6 +199,32 @@ pub(crate) fn unproject_agg_exprs( agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result { + // replace grouping function + let expr = expr.transform(|sub_expr| { + match sub_expr { + Expr::ScalarFunction(grouping) if grouping.name() == 
"grouping" => { + if grouping.args.len() != 1 && grouping.args.len() != 2 { + return internal_err!("Grouping function must have one or two arguments"); + } + let grouping_expr = grouping_set_to_exprlist(&agg.group_expr)?; + let args = if grouping.args.len() == 1 { + grouping_expr.iter().map(|e| (*e).clone()).collect() + } else if let Expr::Literal(ScalarValue::List(list), _) = &grouping.args[1] { + if list.len() != 1 { + return internal_err!("The second argument of grouping function must be a list with exactly one element"); + } + // Report a malformed index list as an error instead of panicking on a failed downcast, a null entry, or an out-of-range index. + let first = list.value(0); + let Some(indices) = first.as_any().downcast_ref::<Int32Array>() else { + return internal_err!("The second argument of grouping function must be a list of Int32"); + }; + // NOTE(review): index `i` is resolved to the i-th expression returned by `grouping_set_to_exprlist`; confirm this is the same numbering the analyzer uses when it builds the index list. + indices + .iter() + .map(|index| { + match index + .and_then(|i| usize::try_from(i).ok()) + .and_then(|i| grouping_expr.get(i)) + { + Some(e) => Ok((*e).clone()), + None => internal_err!( + "Grouping index {:?} is out of range for {} grouping expressions", + index, + grouping_expr.len() + ), + } + }) + .collect::<Result<Vec<_>>>()? + } else { + return internal_err!("The second argument of grouping function must be a list"); + }; + Ok(Transformed::yes(Expr::AggregateFunction(AggregateFunction::new_udf( + grouping_udaf(), args, false, None, None, None)))) + } + _ => Ok(Transformed::no(sub_expr)) + } + }) + .map(|e| e.data)?; expr.transform(|sub_expr| { if let Expr::Column(c) = sub_expr { if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index b4697c2fe473..100ab40e7e7f 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::ListArray; +use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use datafusion_common::{ - assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, + assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::test::function_stub::{ count_udaf, max_udaf, min_udaf, sum, sum_udaf, }; use datafusion_expr::{ - cast, col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan, - LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + cast, col, lit, table_scan, wildcard, Aggregate, EmptyRelation, Expr, Extension, + LogicalPlan, LogicalPlanBuilder, Union, UserDefinedLogicalNode, + UserDefinedLogicalNodeCore, }; +use datafusion_functions::core::grouping; use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; use datafusion_functions_nested::make_array::make_array_udf; @@ -2561,3 +2564,59 @@ fn test_not_ilike_filter_with_escape() { @"SELECT person.first_name FROM person WHERE person.first_name NOT ILIKE 'A!_%' ESCAPE '!'" ); } + +#[test] +fn test_grouping() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + Field::new("c3", DataType::Int32, false), + ]); + let table_scan = table_scan(Some("test"), &schema, Some(vec![0, 1, 2]))?.build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![Expr::GroupingSet( + datafusion_expr::GroupingSet::GroupingSets(vec![ + vec![col("c1"), col("c2")], + vec![col("c1")], + vec![col("c2")], + vec![], + ]), + )], + vec![sum(col("c3"))], + )? 
+ .build()?; + + // Index lists handed to the scalar `grouping` function; the expected SQL below shows 0 resolving to c1 and 1 to c2. + let group1 = + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>( + vec![Some(vec![Some(0)])], + ))); + let group2 = + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>( + vec![Some(vec![Some(1)])], + ))); + let group3 = + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>( + vec![Some(vec![Some(0), Some(1)])], + ))); + let project = LogicalPlanBuilder::from(plan) + .project(vec![ + grouping() + .call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group1)]) + .alias("grouping(test.c1)"), + grouping() + .call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group2)]) + .alias("grouping(test.c2)"), + grouping() + .call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group3)]) + .alias("grouping(test.c1,test.c2)"), + ])? + .build()?; + let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); + let sql = unparser.plan_to_sql(&project)?; + assert_snapshot!( + sql, + @r#"SELECT grouping("test"."c1") AS "grouping(test.c1)", grouping("test"."c2") AS "grouping(test.c2)", grouping("test"."c1", "test"."c2") AS "grouping(test.c1,test.c2)" FROM "test" GROUP BY GROUPING SETS (("test"."c1", "test"."c2"), ("test"."c1"), ("test"."c2"), ())"# + ); + Ok(()) +} diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 774a4fae6bf3..c2d48dd78fdd 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -248,7 +248,7 @@ first_value(expression [ORDER BY expression]) ### `grouping` -Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set. +Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn). 
```sql grouping(expression) @@ -264,13 +264,13 @@ grouping(expression) > SELECT column_name, GROUPING(column_name) AS group_column FROM table_name GROUP BY GROUPING SETS ((column_name), ()); -+-------------+-------------+ ++-------------+--------------+ | column_name | group_column | -+-------------+-------------+ -| value1 | 0 | -| value2 | 0 | -| NULL | 1 | -+-------------+-------------+ ++-------------+--------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+--------------+ ``` ### `last_value`