diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 0082ed6eb9a9..591f6ac3de95 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::DFSchema; use datafusion::error::Result; +use datafusion::functions_aggregate::first_last::first_value_udaf; use datafusion::optimizer::simplify_expressions::ExprSimplifier; use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; use datafusion::prelude::*; @@ -32,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; +use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. /// @@ -44,11 +45,12 @@ use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; /// also comes with APIs for evaluation, simplification, and analysis. /// /// The code in this example shows how to: -/// 1. Create [`Exprs`] using different APIs: [`main`]` -/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`] -/// 3. Simplify expressions: [`simplify_demo`] -/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`] -/// 5. Get the types of the expressions: [`expression_type_demo`] +/// 1. Create [`Expr`]s using different APIs: [`main`] +/// 2. Use the fluent API to easily create complex [`Expr`]s: [`expr_fn_demo`] +/// 3. Evaluate [`Expr`]s against data: [`evaluate_demo`] +/// 4. Simplify expressions: [`simplify_demo`] +/// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`] +/// 6. Get the types of the expressions: [`expression_type_demo`] #[tokio::main] async fn main() -> Result<()> { // The easiest way to do create expressions is to use the @@ -63,6 +65,9 @@ async fn main() -> Result<()> { )); assert_eq!(expr, expr2); + // See how to build aggregate functions with the expr_fn API + expr_fn_demo()?; + // See how to evaluate expressions evaluate_demo()?; @@ -78,6 +83,33 @@ async fn main() -> Result<()> { Ok(()) } +/// DataFusion's `expr_fn` API makes it easy to create [`Expr`]s for the +/// full range of expression types such as aggregates and window functions. +fn expr_fn_demo() -> Result<()> { + // Let's say you want to call the "first_value" aggregate function + let first_value = first_value_udaf(); + + // For example, to create the expression `FIRST_VALUE(price)` + // These expressions can be passed to `DataFrame::aggregate` and other + // APIs that take aggregate expressions. + let agg = first_value.call(vec![col("price")]); + assert_eq!(agg.to_string(), "first_value(price)"); + + // You can use the AggregateExt trait to create more complex aggregates + // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts)` + let agg = first_value + .call(vec![col("price")]) + .order_by(vec![col("ts").sort(false, false)]) + .filter(col("quantity").gt(lit(100))) + .build()?; // build the aggregate + assert_eq!( + agg.to_string(), + "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts DESC NULLS LAST]" + ); + + Ok(()) +} + /// DataFusion can also evaluate arbitrary expressions on Arrow arrays. 
fn evaluate_demo() -> Result<()> { // For example, let's say you have some integers in an array diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 1db5aa9f235a..7085333bee03 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -15,14 +15,18 @@ // specific language governing permissions and limitations // under the License. -use arrow::util::pretty::pretty_format_columns; +use arrow::util::pretty::{pretty_format_batches, pretty_format_columns}; use arrow_array::builder::{ListBuilder, StringBuilder}; -use arrow_array::{ArrayRef, RecordBatch, StringArray, StructArray}; +use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field}; use datafusion::prelude::*; -use datafusion_common::{DFSchema, ScalarValue}; +use datafusion_common::{assert_contains, DFSchema, ScalarValue}; +use datafusion_expr::AggregateExt; use datafusion_functions::core::expr_ext::FieldAccessor; +use datafusion_functions_aggregate::first_last::first_value_udaf; +use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_functions_array::expr_ext::{IndexAccessor, SliceAccessor}; +use sqlparser::ast::NullTreatment; /// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan use std::sync::{Arc, OnceLock}; @@ -162,6 +166,183 @@ fn test_list_range() { ); } +#[tokio::test] +async fn test_aggregate_error() { + let err = first_value_udaf() + .call(vec![col("props")]) + // not a sort column + .order_by(vec![col("id")]) + .build() + .unwrap_err() + .to_string(); + assert_contains!( + err, + "Error during planning: ORDER BY expressions must be Expr::Sort" + ); +} + +#[tokio::test] +async fn test_aggregate_ext_order_by() { + let agg = first_value_udaf().call(vec![col("props")]); + + // ORDER BY id ASC + let agg_asc = agg + .clone() + .order_by(vec![col("id").sort(true, true)]) + .build() + .unwrap() + .alias("asc"); + + // ORDER BY id DESC + let agg_desc = agg + .order_by(vec![col("id").sort(false, true)]) + .build() + .unwrap() + .alias("desc"); + + evaluate_agg_test( + agg_asc, + vec![ + "+-----------------+", + "| asc |", + "+-----------------+", + "| {a: 2021-02-01} |", + "+-----------------+", + ], + ) + .await; + + evaluate_agg_test( + agg_desc, + vec![ + "+-----------------+", + "| desc |", + "+-----------------+", + "| {a: 2021-02-03} |", + "+-----------------+", + ], + ) + .await; +} + +#[tokio::test] +async fn test_aggregate_ext_filter() { + let agg = first_value_udaf() + .call(vec![col("i")]) + .order_by(vec![col("i").sort(true, true)]) + .filter(col("i").is_not_null()) + .build() + .unwrap() + .alias("val"); + + #[rustfmt::skip] + evaluate_agg_test( + agg, + vec![ + "+-----+", + "| val |", + "+-----+", + "| 5 |", + "+-----+", + ], + ) + .await; +} + +#[tokio::test] +async fn test_aggregate_ext_distinct() { + let agg = sum_udaf() + .call(vec![lit(5)]) + // distinct sum should be 5, not 15 + .distinct() + .build() + .unwrap() + .alias("distinct"); + + evaluate_agg_test( + agg, + vec![ + "+----------+", + "| distinct |", + "+----------+", + "| 5 |", + "+----------+", + ], + ) + .await; +} + +#[tokio::test] +async fn test_aggregate_ext_null_treatment() { + let agg = first_value_udaf() + .call(vec![col("i")]) + .order_by(vec![col("i").sort(true, true)]); + + let agg_respect = agg + .clone() + .null_treatment(NullTreatment::RespectNulls) + .build() + .unwrap() + .alias("respect"); + + let agg_ignore = agg + .null_treatment(NullTreatment::IgnoreNulls) + .build() 
+ .unwrap() + .alias("ignore"); + + evaluate_agg_test( + agg_respect, + vec![ + "+---------+", + "| respect |", + "+---------+", + "| |", + "+---------+", + ], + ) + .await; + + evaluate_agg_test( + agg_ignore, + vec![ + "+--------+", + "| ignore |", + "+--------+", + "| 5 |", + "+--------+", + ], + ) + .await; +} + +/// Evaluates the specified expr as an aggregate and compares the result to the +/// expected result. +async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { + let batch = test_batch(); + + let ctx = SessionContext::new(); + let group_expr = vec![]; + let agg_expr = vec![expr]; + let result = ctx + .read_batch(batch) + .unwrap() + .aggregate(group_expr, agg_expr) + .unwrap() + .collect() + .await + .unwrap(); + + let result = pretty_format_batches(&result).unwrap().to_string(); + let actual_lines = result.lines().collect::<Vec<_>>(); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); +} + /// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided /// `RecordBatch` and compares the result to the expected result. fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { @@ -189,6 +370,8 @@ fn test_batch() -> RecordBatch { TEST_BATCH .get_or_init(|| { let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3"])); + let int_array: ArrayRef = + Arc::new(Int64Array::from_iter(vec![Some(10), None, Some(5)])); // { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" } let struct_array: ArrayRef = Arc::from(StructArray::from(vec![( @@ -209,6 +392,7 @@ fn test_batch() -> RecordBatch { RecordBatch::try_from_iter(vec![ ("id", string_array), + ("i", int_array), ("props", struct_array), ("list", list_array), ]) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 1abd8c97ee10..10e2edd17b21 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -255,19 +255,23 @@ pub enum Expr { /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. /// + /// ```text /// CASE WHEN condition THEN result /// [WHEN ...] /// [ELSE result] /// END + /// ``` /// /// The second form uses a base expression and then a series of "when" clauses that match on a /// literal value. /// + /// ```text /// CASE expression /// WHEN value THEN result /// [WHEN ...] /// [ELSE result] /// END + /// ``` Case(Case), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. @@ -279,7 +283,12 @@ pub enum Expr { Sort(Sort), /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), - /// Represents the call of an aggregate built-in function with arguments. + /// Calls an aggregate function with arguments, and optional + /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. + /// + /// See also [`AggregateExt`] to set these fields. + /// + /// [`AggregateExt`]: crate::udaf::AggregateExt AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. 
WindowFunction(WindowFunction), @@ -623,6 +632,10 @@ impl AggregateFunctionDefinition { } /// Aggregate function +/// +/// See also [`AggregateExt`] to set these fields on `Expr` +/// +/// [`AggregateExt`]: crate::udaf::AggregateExt #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 8c9893b8a748..c2d40a7fe4f1 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -81,7 +81,7 @@ pub use signature::{ ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; +pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index d778203207c9..a248518c2d94 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,6 +17,7 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions +use crate::expr::AggregateFunction; use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, }; @@ -26,7 +27,8 @@ use crate::utils::AggregateOrderSensitivity; use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; +use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -139,8 +141,7 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. 
pub fn call(&self, args: Vec<Expr>) -> Expr { - // TODO: Support dictinct, filter, order by and null_treatment - Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(AggregateFunction::new_udf( Arc::new(self.clone()), args, false, @@ -606,3 +607,177 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } + +/// Extensions for configuring [`Expr::AggregateFunction`] +/// +/// Adds methods to [`Expr`] that make it easy to set optional aggregate options +/// such as `ORDER BY`, `FILTER` and `DISTINCT` +/// +/// # Example +/// ```no_run +/// # use datafusion_common::Result; +/// # use datafusion_expr::{AggregateUDF, col, Expr, lit}; +/// # use sqlparser::ast::NullTreatment; +/// # fn count(arg: Expr) -> Expr { todo!{} } +/// # fn first_value(arg: Expr) -> Expr { todo!{} } +/// # fn main() -> Result<()> { +/// use datafusion_expr::AggregateExt; +/// +/// // Create COUNT(x FILTER y > 5) +/// let agg = count(col("x")) +/// .filter(col("y").gt(lit(5))) +/// .build()?; +/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) +/// let sort_expr = col("y").sort(true, true); +/// let agg = first_value(col("x")) +/// .order_by(vec![sort_expr]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// # Ok(()) +/// # } +/// ``` +pub trait AggregateExt { + /// Add `ORDER BY <order_by>` + /// + /// Note: `order_by` must be [`Expr::Sort`] + fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder; + /// Add `FILTER <filter>` + fn filter(self, filter: Expr) -> AggregateBuilder; + /// Add `DISTINCT` + fn distinct(self) -> AggregateBuilder; + /// Add `RESPECT NULLS` or `IGNORE NULLS` + fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder; +} + +/// Implementation of [`AggregateExt`]. +/// +/// See [`AggregateExt`] for usage and examples +#[derive(Debug, Clone)] +pub struct AggregateBuilder { + udaf: Option<AggregateFunction>, + order_by: Option<Vec<Expr>>, + filter: Option<Expr>, + distinct: bool, + null_treatment: Option<NullTreatment>, +} + +impl AggregateBuilder { + /// Create a new `AggregateBuilder`, see [`AggregateExt`] + + fn new(udaf: Option<AggregateFunction>) -> Self { + Self { + udaf, + order_by: None, + filter: None, + distinct: false, + null_treatment: None, + } + } + + /// Updates and returns the in-progress [`Expr::AggregateFunction`] + /// + /// # Errors: + /// + /// Returns an error if this builder [`AggregateExt`] was used with an + /// `Expr` variant other than [`Expr::AggregateFunction`] + pub fn build(self) -> Result<Expr> { + let Self { + udaf, + order_by, + filter, + distinct, + null_treatment, + } = self; + + let Some(mut udaf) = udaf else { + return plan_err!( + "AggregateExt can only be used with Expr::AggregateFunction" + ); + }; + + if let Some(order_by) = &order_by { + for expr in order_by.iter() { + if !matches!(expr, Expr::Sort(_)) { + return plan_err!( + "ORDER BY expressions must be Expr::Sort, found {expr:?}" + ); + } + } + } + + udaf.order_by = order_by; + udaf.filter = filter.map(Box::new); + udaf.distinct = distinct; + udaf.null_treatment = null_treatment; + Ok(Expr::AggregateFunction(udaf)) + } + + /// Add `ORDER BY <order_by>` + /// + /// Note: `order_by` must be [`Expr::Sort`] + pub fn order_by(mut self, order_by: Vec<Expr>) -> AggregateBuilder { + self.order_by = Some(order_by); + self + } + + /// Add `FILTER <filter>` + pub fn filter(mut self, filter: Expr) -> AggregateBuilder { + self.filter = Some(filter); + self + } + + /// Add `DISTINCT` + pub fn distinct(mut self) -> AggregateBuilder { + self.distinct = true; + self + } + + /// Add `RESPECT NULLS` or `IGNORE NULLS` + pub fn 
null_treatment(mut self, null_treatment: NullTreatment) -> AggregateBuilder { + self.null_treatment = Some(null_treatment); + self + } +} + +impl AggregateExt for Expr { + fn order_by(self, order_by: Vec<Expr>) -> AggregateBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = AggregateBuilder::new(Some(udaf)); + builder.order_by = Some(order_by); + builder + } + _ => AggregateBuilder::new(None), + } + } + fn filter(self, filter: Expr) -> AggregateBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = AggregateBuilder::new(Some(udaf)); + builder.filter = Some(filter); + builder + } + _ => AggregateBuilder::new(None), + } + } + fn distinct(self) -> AggregateBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = AggregateBuilder::new(Some(udaf)); + builder.distinct = true; + builder + } + _ => AggregateBuilder::new(None), + } + } + fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = AggregateBuilder::new(Some(udaf)); + builder.null_treatment = Some(null_treatment); + builder + } + _ => AggregateBuilder::new(None), + } + } +} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 435d277473c4..dd38e3487264 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,20 +31,29 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Signature, TypeSignature, - Volatility, + Accumulator, AggregateExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, + TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, }; -make_udaf_expr_and_func!( - FirstValue, - first_value, - "Returns the first value in a group of values.", - first_value_udaf -); +create_func!(FirstValue, first_value_udaf); + +/// Returns the first value in a group of values. +pub fn first_value(expression: Expr, order_by: Option<Vec<Expr>>) -> Expr { + if let Some(order_by) = order_by { + first_value_udaf() + .call(vec![expression]) + .order_by(order_by) + .build() + // guaranteed to be `Expr::AggregateFunction` + .unwrap() + } else { + first_value_udaf().call(vec![expression]) + } +} pub struct FirstValue { signature: Signature, diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 6c3348d6c1d6..75bb9dc54719 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -48,24 +48,7 @@ macro_rules! 
make_udaf_expr_and_func { None, )) } - create_func!($UDAF, $AGGREGATE_UDF_FN); - }; - ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $distinct:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { - // "fluent expr_fn" style function - #[doc = $DOC] - pub fn $EXPR_FN( - $($arg: datafusion_expr::Expr,)* - distinct: bool, - ) -> datafusion_expr::Expr { - datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( - $AGGREGATE_UDF_FN(), - vec![$($arg),*], - distinct, - None, - None, - None - )) - } + create_func!($UDAF, $AGGREGATE_UDF_FN); }; ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { @@ -73,20 +56,17 @@ macro_rules! make_udaf_expr_and_func { #[doc = $DOC] pub fn $EXPR_FN( args: Vec<datafusion_expr::Expr>, - distinct: bool, - filter: Option<Box<datafusion_expr::Expr>>, - order_by: Option<Vec<datafusion_expr::Expr>>, - null_treatment: Option<sqlparser::ast::NullTreatment> ) -> datafusion_expr::Expr { datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), args, - distinct, - filter, - order_by, - null_treatment, + false, + None, + None, + None, )) } + create_func!($UDAF, $AGGREGATE_UDF_FN); }; } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 752e2b200741..b32a88635395 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -21,10 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{internal_err, Column, Result}; -use datafusion_expr::expr::AggregateFunction; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, LogicalPlanBuilder}; +use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] @@ -95,17 +94,19 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { let expr_cnt = on_expr.len(); // Construct the aggregation expression to be used to fetch the selected expressions. 
- let first_value_udaf = + let first_value_udaf: std::sync::Arc<datafusion_expr::AggregateUDF> = config.function_registry().unwrap().udaf("first_value")?; let aggr_expr = select_expr.into_iter().map(|e| { - Expr::AggregateFunction(AggregateFunction::new_udf( - first_value_udaf.clone(), - vec![e], - false, - None, - sort_expr.clone(), - None, - )) + if let Some(order_by) = &sort_expr { + first_value_udaf + .call(vec![e]) + .order_by(order_by.clone()) + .build() + // guaranteed to be `Expr::AggregateFunction` + .unwrap() + } else { + first_value_udaf.call(vec![e]) + } }); let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f32f4b04938f..4f35b82e4908 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,6 +26,8 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use prost::Message; + use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; @@ -64,8 +66,6 @@ use datafusion_proto::logical_plan::{ }; use datafusion_proto::protobuf; -use prost::Message; - #[cfg(feature = "json")] fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { let string = serde_json::to_string(proto).unwrap(); @@ -647,7 +647,8 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), - first_value(vec![lit(1)], false, None, None, None), + first_value(lit(1), None), + first_value(lit(1), Some(vec![lit(2).sort(true, true)])), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), sum(lit(1)), diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index a5fc13491677..cae9627210e5 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -304,6 +304,16 @@ select log(-1), log(0), sqrt(-1); | rollup(exprs) | Creates a grouping set for rollup sets. | | sum(expr) | Сalculates the sum of `expr`. | +## Aggregate Function Builder + +You can also use the `AggregateExt` trait to more easily build aggregate function `Expr`s. + +See `datafusion-examples/examples/expr_api.rs` for example usage. + +| Syntax | Equivalent to | +| ----------------------------------------------------------------------- | ----------------------------------- | +| first_value_udaf.call(vec![expr]).order_by(vec![expr]).build().unwrap() | first_value(expr, Some(vec![expr])) | + ## Subquery Expressions | Syntax | Description |
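A minimal end-to-end sketch of how the builder API added in this diff reads in application code (the imports mirror those used in `expr_api.rs` above; the `price`, `ts`, and `quantity` column names are illustrative only):

```rust
use datafusion::error::Result;
use datafusion::functions_aggregate::first_last::first_value_udaf;
use datafusion::prelude::*;
use datafusion_expr::AggregateExt;

/// Builds `FIRST_VALUE(price) FILTER (WHERE quantity > 100) ORDER BY ts DESC NULLS LAST`
fn first_price_expr() -> Result<Expr> {
    first_value_udaf()
        .call(vec![col("price")])
        // sort(false, false) => DESC NULLS LAST
        .order_by(vec![col("ts").sort(false, false)])
        .filter(col("quantity").gt(lit(100)))
        // build() returns the finished Expr::AggregateFunction
        .build()
}
```

The resulting `Expr` can be passed to `DataFrame::aggregate`, exactly as `expr_fn_demo` and the `evaluate_agg_test` helper do above.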