diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index e785cab0..c0b49571 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -216,9 +216,11 @@ def expr_list_to_raw_expr_list( - expr_list: Optional[list[Expr]], + expr_list: Optional[list[Expr] | Expr], ) -> Optional[list[expr_internal.Expr]]: """Helper function to convert an optional list to raw expressions.""" + if isinstance(expr_list, Expr): + expr_list = [expr_list] return [e.expr for e in expr_list] if expr_list is not None else None @@ -230,9 +232,11 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr: def sort_list_to_raw_sort_list( - sort_list: Optional[list[Expr | SortExpr]], + sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr], ) -> Optional[list[expr_internal.SortExpr]]: """Helper function to return an optional sort list to raw variant.""" + if isinstance(sort_list, (Expr, SortExpr)): + sort_list = [sort_list] return [sort_or_default(e) for e in sort_list] if sort_list is not None else None @@ -1140,9 +1144,9 @@ class Window: def __init__( self, - partition_by: Optional[list[Expr]] = None, + partition_by: Optional[list[Expr] | Expr] = None, window_frame: Optional[WindowFrame] = None, - order_by: Optional[list[SortExpr | Expr]] = None, + order_by: Optional[list[SortExpr | Expr] | Expr | SortExpr] = None, null_treatment: Optional[NullTreatment] = None, ) -> None: """Construct a window definition. diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index f430cdf4..34068805 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -428,8 +428,8 @@ def when(when: Expr, then: Expr) -> CaseBuilder: def window( name: str, args: list[Expr], - partition_by: list[Expr] | None = None, - order_by: list[Expr | SortExpr] | None = None, + partition_by: list[Expr] | Expr | None = None, + order_by: list[Expr | SortExpr] | Expr | SortExpr | None = None, window_frame: WindowFrame | None = None, ctx: SessionContext | None = None, ) -> Expr: @@ -442,11 +442,11 @@ def window( df.select(functions.lag(col("a")).partition_by(col("b")).build()) """ args = [a.expr for a in args] - partition_by = expr_list_to_raw_expr_list(partition_by) + partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) window_frame = window_frame.window_frame if window_frame is not None else None ctx = ctx.ctx if ctx is not None else None - return Expr(f.window(name, args, partition_by, order_by_raw, window_frame, ctx)) + return Expr(f.window(name, args, partition_by_raw, order_by_raw, window_frame, ctx)) # scalar functions @@ -1723,7 +1723,7 @@ def array_agg( expression: Expr, distinct: bool = False, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Aggregate values into an array. @@ -2222,7 +2222,7 @@ def regr_syy( def first_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the first value in a group of values. @@ -2254,7 +2254,7 @@ def first_value( def last_value( expression: Expr, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the last value in a group of values. @@ -2287,7 +2287,7 @@ def nth_value( expression: Expr, n: int, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the n-th value in a group of values. @@ -2407,8 +2407,8 @@ def lead( arg: Expr, shift_offset: int = 1, default_value: Optional[Any] = None, - partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + partition_by: Optional[list[Expr] | Expr] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Create a lead window function. @@ -2442,9 +2442,7 @@ def lead( if not isinstance(default_value, pa.Scalar) and default_value is not None: default_value = pa.scalar(default_value) - partition_cols = ( - [col.expr for col in partition_by] if partition_by is not None else None - ) + partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( @@ -2452,7 +2450,7 @@ def lead( arg.expr, shift_offset, default_value, - partition_by=partition_cols, + partition_by=partition_by_raw, order_by=order_by_raw, ) ) @@ -2462,8 +2460,8 @@ def lag( arg: Expr, shift_offset: int = 1, default_value: Optional[Any] = None, - partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + partition_by: Optional[list[Expr] | Expr] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Create a lag window function. @@ -2494,9 +2492,7 @@ def lag( if not isinstance(default_value, pa.Scalar): default_value = pa.scalar(default_value) - partition_cols = ( - [col.expr for col in partition_by] if partition_by is not None else None - ) + partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( @@ -2504,15 +2500,15 @@ def lag( arg.expr, shift_offset, default_value, - partition_by=partition_cols, + partition_by=partition_by_raw, order_by=order_by_raw, ) ) def row_number( - partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + partition_by: Optional[list[Expr] | Expr] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Create a row number window function. @@ -2533,22 +2529,20 @@ def row_number( partition_by: Expressions to partition the window frame on. order_by: Set ordering within the window frame. """ - partition_cols = ( - [col.expr for col in partition_by] if partition_by is not None else None - ) + partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.row_number( - partition_by=partition_cols, + partition_by=partition_by_raw, order_by=order_by_raw, ) ) def rank( - partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + partition_by: Optional[list[Expr] | Expr] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Create a rank window function. @@ -2574,22 +2568,20 @@ def rank( partition_by: Expressions to partition the window frame on. order_by: Set ordering within the window frame. """ - partition_cols = ( - [col.expr for col in partition_by] if partition_by is not None else None - ) + partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.rank( - partition_by=partition_cols, + partition_by=partition_by_raw, order_by=order_by_raw, ) ) def dense_rank( - partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + partition_by: Optional[list[Expr] | Expr] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Create a dense_rank window function. @@ -2610,22 +2602,20 @@ def dense_rank( partition_by: Expressions to partition the window frame on. order_by: Set ordering within the window frame. """ - partition_cols = ( - [col.expr for col in partition_by] if partition_by is not None else None - ) + partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.dense_rank( - partition_by=partition_cols, + partition_by=partition_by_raw, order_by=order_by_raw, ) ) def percent_rank( - partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + partition_by: Optional[list[Expr] | Expr] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Create a percent_rank window function. @@ -2647,22 +2637,20 @@ def percent_rank( partition_by: Expressions to partition the window frame on. order_by: Set ordering within the window frame. """ - partition_cols = ( - [col.expr for col in partition_by] if partition_by is not None else None - ) + partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.percent_rank( - partition_by=partition_cols, + partition_by=partition_by_raw, order_by=order_by_raw, ) ) def cume_dist( - partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + partition_by: Optional[list[Expr] | Expr] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Create a cumulative distribution window function. @@ -2684,14 +2672,12 @@ def cume_dist( partition_by: Expressions to partition the window frame on. order_by: Set ordering within the window frame. """ - partition_cols = ( - [col.expr for col in partition_by] if partition_by is not None else None - ) + partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.cume_dist( - partition_by=partition_cols, + partition_by=partition_by_raw, order_by=order_by_raw, ) ) @@ -2699,8 +2685,8 @@ def cume_dist( def ntile( groups: int, - partition_by: Optional[list[Expr]] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + partition_by: Optional[list[Expr] | Expr] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Create a n-tile window function. @@ -2725,15 +2711,13 @@ def ntile( partition_by: Expressions to partition the window frame on. order_by: Set ordering within the window frame. """ - partition_cols = ( - [col.expr for col in partition_by] if partition_by is not None else None - ) + partition_by_raw = expr_list_to_raw_expr_list(partition_by) order_by_raw = sort_list_to_raw_sort_list(order_by) return Expr( f.ntile( Expr.literal(groups).expr, - partition_by=partition_cols, + partition_by=partition_by_raw, order_by=order_by_raw, ) ) @@ -2743,7 +2727,7 @@ def string_agg( expression: Expr, delimiter: str, filter: Optional[Expr] = None, - order_by: Optional[list[Expr | SortExpr]] = None, + order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None, ) -> Expr: """Concatenates the input strings. diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py index 49dfb38c..96269b16 100644 --- a/python/tests/test_aggregation.py +++ b/python/tests/test_aggregation.py @@ -154,6 +154,11 @@ def test_aggregation_stats(df, agg_expr, calc_expected): pa.array([[6, 4, 4]]), False, ), + ( + f.array_agg(column("b"), order_by=column("c")), + pa.array([[6, 4, 4]]), + False, + ), (f.avg(column("b"), filter=column("a") != lit(1)), pa.array([5.0]), False), (f.sum(column("b"), filter=column("a") != lit(1)), pa.array([10]), False), (f.count(column("b"), distinct=True), pa.array([2]), False), @@ -329,6 +334,15 @@ def test_bit_and_bool_fns(df, name, expr, result): ), [None, None], ), + ( + "first_value_no_list_order_by", + f.first_value( + column("b"), + order_by=column("b"), + null_treatment=NullTreatment.RESPECT_NULLS, + ), + [None, None], + ), ( "first_value_ignore_null", f.first_value( @@ -343,6 +357,11 @@ def test_bit_and_bool_fns(df, name, expr, result): f.last_value(column("a"), order_by=[column("a").sort(ascending=False)]), [0, 4], ), + ( + "last_value_no_list_ordered", + f.last_value(column("a"), order_by=column("a")), + [3, 6], + ), ( "last_value_with_null", f.last_value( @@ -366,6 +385,11 @@ def test_bit_and_bool_fns(df, name, expr, result): f.nth_value(column("a"), 2, order_by=[column("a").sort(ascending=False)]), [2, 5], ), + ( + "nth_value_no_list_ordered", + f.nth_value(column("a"), 2, order_by=column("a").sort(ascending=False)), + [2, 5], + ), ( "nth_value_with_null", f.nth_value( @@ -414,6 +438,11 @@ def test_first_last_value(df_partitioned, name, expr, result) -> None: f.string_agg(column("a"), ",", order_by=[column("b")]), "one,three,two,two", ), + ( + "string_agg", + f.string_agg(column("a"), ",", order_by=column("b")), + "one,three,two,two", + ), ], ) def test_string_agg(name, expr, result) -> None: diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index a3870ead..3cb76b68 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -551,12 +551,25 @@ def test_distinct(): ), [2, 1, 3, 4, 2, 1, 3], ), + ( + "row_w_params_no_lists", + f.row_number( + order_by=column("b"), + partition_by=column("c"), + ), + [2, 1, 3, 4, 2, 1, 3], + ), ("rank", f.rank(order_by=[column("b")]), [3, 1, 3, 5, 6, 1, 6]), ( "rank_w_params", f.rank(order_by=[column("b"), column("a")], partition_by=[column("c")]), [2, 1, 3, 4, 2, 1, 3], ), + ( + "rank_w_params_no_lists", + f.rank(order_by=column("a"), partition_by=column("c")), + [1, 2, 3, 4, 1, 2, 3], + ), ( "dense_rank", f.dense_rank(order_by=[column("b")]), @@ -567,6 +580,11 @@ def test_distinct(): f.dense_rank(order_by=[column("b"), column("a")], partition_by=[column("c")]), [2, 1, 3, 4, 2, 1, 3], ), + ( + "dense_rank_w_params_no_lists", + f.dense_rank(order_by=column("a"), partition_by=column("c")), + [1, 2, 3, 4, 1, 2, 3], + ), ( "percent_rank", f.round(f.percent_rank(order_by=[column("b")]), literal(3)), @@ -582,6 +600,14 @@ def test_distinct(): ), [0.333, 0.0, 0.667, 1.0, 0.5, 0.0, 1.0], ), + ( + "percent_rank_w_params_no_lists", + f.round( + f.percent_rank(order_by=column("a"), partition_by=column("c")), + literal(3), + ), + [0.0, 0.333, 0.667, 1.0, 0.0, 0.5, 1.0], + ), ( "cume_dist", f.round(f.cume_dist(order_by=[column("b")]), literal(3)), @@ -597,6 +623,14 @@ def test_distinct(): ), [0.5, 0.25, 0.75, 1.0, 0.667, 0.333, 1.0], ), + ( + "cume_dist_w_params_no_lists", + f.round( + f.cume_dist(order_by=column("a"), partition_by=column("c")), + literal(3), + ), + [0.25, 0.5, 0.75, 1.0, 0.333, 0.667, 1.0], + ), ( "ntile", f.ntile(2, order_by=[column("b")]), @@ -607,6 +641,11 @@ def test_distinct(): f.ntile(2, order_by=[column("b"), column("a")], partition_by=[column("c")]), [1, 1, 2, 2, 1, 1, 2], ), + ( + "ntile_w_params_no_lists", + f.ntile(2, order_by=column("b"), partition_by=column("c")), + [1, 1, 2, 2, 1, 1, 2], + ), ("lead", f.lead(column("b"), order_by=[column("b")]), [7, None, 8, 9, 9, 7, None]), ( "lead_w_params", @@ -619,6 +658,17 @@ def test_distinct(): ), [8, 7, -1, -1, -1, 9, -1], ), + ( + "lead_w_params_no_lists", + f.lead( + column("b"), + shift_offset=2, + default_value=-1, + order_by=column("b"), + partition_by=column("c"), + ), + [8, 7, -1, -1, -1, 9, -1], + ), ("lag", f.lag(column("b"), order_by=[column("b")]), [None, None, 7, 7, 8, None, 9]), ( "lag_w_params", @@ -631,6 +681,17 @@ def test_distinct(): ), [-1, -1, None, 7, -1, -1, None], ), + ( + "lag_w_params_no_lists", + f.lag( + column("b"), + shift_offset=2, + default_value=-1, + order_by=column("b"), + partition_by=column("c"), + ), + [-1, -1, None, 7, -1, -1, None], + ), ( "first_value", f.first_value(column("a")).over( @@ -638,6 +699,13 @@ def test_distinct(): ), [1, 1, 1, 1, 5, 5, 5], ), + ( + "first_value_without_list_args", + f.first_value(column("a")).over( + Window(partition_by=column("c"), order_by=column("b")) + ), + [1, 1, 1, 1, 5, 5, 5], + ), ( "last_value", f.last_value(column("a")).over(