Skip to content

Improved window and aggregate function signature #1187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@


def expr_list_to_raw_expr_list(
expr_list: Optional[list[Expr]],
expr_list: Optional[list[Expr] | Expr],
) -> Optional[list[expr_internal.Expr]]:
"""Helper function to convert an optional list to raw expressions."""
if isinstance(expr_list, Expr):
expr_list = [expr_list]
return [e.expr for e in expr_list] if expr_list is not None else None


Expand All @@ -230,9 +232,11 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:


def sort_list_to_raw_sort_list(
sort_list: Optional[list[Expr | SortExpr]],
sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr],
) -> Optional[list[expr_internal.SortExpr]]:
"""Helper function to return an optional sort list to raw variant."""
if isinstance(sort_list, (Expr, SortExpr)):
sort_list = [sort_list]
return [sort_or_default(e) for e in sort_list] if sort_list is not None else None


Expand Down Expand Up @@ -1140,9 +1144,9 @@ class Window:

def __init__(
self,
partition_by: Optional[list[Expr]] = None,
partition_by: Optional[list[Expr] | Expr] = None,
window_frame: Optional[WindowFrame] = None,
order_by: Optional[list[SortExpr | Expr]] = None,
order_by: Optional[list[SortExpr | Expr] | Expr | SortExpr] = None,
null_treatment: Optional[NullTreatment] = None,
) -> None:
"""Construct a window definition.
Expand Down
98 changes: 41 additions & 57 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ def when(when: Expr, then: Expr) -> CaseBuilder:
def window(
name: str,
args: list[Expr],
partition_by: list[Expr] | None = None,
order_by: list[Expr | SortExpr] | None = None,
partition_by: list[Expr] | Expr | None = None,
order_by: list[Expr | SortExpr] | Expr | SortExpr | None = None,
window_frame: WindowFrame | None = None,
ctx: SessionContext | None = None,
) -> Expr:
Expand All @@ -442,11 +442,11 @@ def window(
df.select(functions.lag(col("a")).partition_by(col("b")).build())
"""
args = [a.expr for a in args]
partition_by = expr_list_to_raw_expr_list(partition_by)
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)
window_frame = window_frame.window_frame if window_frame is not None else None
ctx = ctx.ctx if ctx is not None else None
return Expr(f.window(name, args, partition_by, order_by_raw, window_frame, ctx))
return Expr(f.window(name, args, partition_by_raw, order_by_raw, window_frame, ctx))


# scalar functions
Expand Down Expand Up @@ -1723,7 +1723,7 @@ def array_agg(
expression: Expr,
distinct: bool = False,
filter: Optional[Expr] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Aggregate values into an array.

Expand Down Expand Up @@ -2222,7 +2222,7 @@ def regr_syy(
def first_value(
expression: Expr,
filter: Optional[Expr] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the first value in a group of values.
Expand Down Expand Up @@ -2254,7 +2254,7 @@ def first_value(
def last_value(
expression: Expr,
filter: Optional[Expr] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the last value in a group of values.
Expand Down Expand Up @@ -2287,7 +2287,7 @@ def nth_value(
expression: Expr,
n: int,
filter: Optional[Expr] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the n-th value in a group of values.
Expand Down Expand Up @@ -2407,8 +2407,8 @@ def lead(
arg: Expr,
shift_offset: int = 1,
default_value: Optional[Any] = None,
partition_by: Optional[list[Expr]] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
partition_by: Optional[list[Expr] | Expr] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Create a lead window function.

Expand Down Expand Up @@ -2442,17 +2442,15 @@ def lead(
if not isinstance(default_value, pa.Scalar) and default_value is not None:
default_value = pa.scalar(default_value)

partition_cols = (
[col.expr for col in partition_by] if partition_by is not None else None
)
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)

return Expr(
f.lead(
arg.expr,
shift_offset,
default_value,
partition_by=partition_cols,
partition_by=partition_by_raw,
order_by=order_by_raw,
)
)
Expand All @@ -2462,8 +2460,8 @@ def lag(
arg: Expr,
shift_offset: int = 1,
default_value: Optional[Any] = None,
partition_by: Optional[list[Expr]] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
partition_by: Optional[list[Expr] | Expr] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Create a lag window function.

Expand Down Expand Up @@ -2494,25 +2492,23 @@ def lag(
if not isinstance(default_value, pa.Scalar):
default_value = pa.scalar(default_value)

partition_cols = (
[col.expr for col in partition_by] if partition_by is not None else None
)
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)

return Expr(
f.lag(
arg.expr,
shift_offset,
default_value,
partition_by=partition_cols,
partition_by=partition_by_raw,
order_by=order_by_raw,
)
)


def row_number(
partition_by: Optional[list[Expr]] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
partition_by: Optional[list[Expr] | Expr] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Create a row number window function.

Expand All @@ -2533,22 +2529,20 @@ def row_number(
partition_by: Expressions to partition the window frame on.
order_by: Set ordering within the window frame.
"""
partition_cols = (
[col.expr for col in partition_by] if partition_by is not None else None
)
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)

return Expr(
f.row_number(
partition_by=partition_cols,
partition_by=partition_by_raw,
order_by=order_by_raw,
)
)


def rank(
partition_by: Optional[list[Expr]] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
partition_by: Optional[list[Expr] | Expr] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Create a rank window function.

Expand All @@ -2574,22 +2568,20 @@ def rank(
partition_by: Expressions to partition the window frame on.
order_by: Set ordering within the window frame.
"""
partition_cols = (
[col.expr for col in partition_by] if partition_by is not None else None
)
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)

return Expr(
f.rank(
partition_by=partition_cols,
partition_by=partition_by_raw,
order_by=order_by_raw,
)
)


def dense_rank(
partition_by: Optional[list[Expr]] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
partition_by: Optional[list[Expr] | Expr] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Create a dense_rank window function.

Expand All @@ -2610,22 +2602,20 @@ def dense_rank(
partition_by: Expressions to partition the window frame on.
order_by: Set ordering within the window frame.
"""
partition_cols = (
[col.expr for col in partition_by] if partition_by is not None else None
)
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)

return Expr(
f.dense_rank(
partition_by=partition_cols,
partition_by=partition_by_raw,
order_by=order_by_raw,
)
)


def percent_rank(
partition_by: Optional[list[Expr]] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
partition_by: Optional[list[Expr] | Expr] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Create a percent_rank window function.

Expand All @@ -2647,22 +2637,20 @@ def percent_rank(
partition_by: Expressions to partition the window frame on.
order_by: Set ordering within the window frame.
"""
partition_cols = (
[col.expr for col in partition_by] if partition_by is not None else None
)
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)

return Expr(
f.percent_rank(
partition_by=partition_cols,
partition_by=partition_by_raw,
order_by=order_by_raw,
)
)


def cume_dist(
partition_by: Optional[list[Expr]] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
partition_by: Optional[list[Expr] | Expr] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Create a cumulative distribution window function.

Expand All @@ -2684,23 +2672,21 @@ def cume_dist(
partition_by: Expressions to partition the window frame on.
order_by: Set ordering within the window frame.
"""
partition_cols = (
[col.expr for col in partition_by] if partition_by is not None else None
)
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)

return Expr(
f.cume_dist(
partition_by=partition_cols,
partition_by=partition_by_raw,
order_by=order_by_raw,
)
)


def ntile(
groups: int,
partition_by: Optional[list[Expr]] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
partition_by: Optional[list[Expr] | Expr] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Create a n-tile window function.

Expand All @@ -2725,15 +2711,13 @@ def ntile(
partition_by: Expressions to partition the window frame on.
order_by: Set ordering within the window frame.
"""
partition_cols = (
[col.expr for col in partition_by] if partition_by is not None else None
)
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
order_by_raw = sort_list_to_raw_sort_list(order_by)

return Expr(
f.ntile(
Expr.literal(groups).expr,
partition_by=partition_cols,
partition_by=partition_by_raw,
order_by=order_by_raw,
)
)
Expand All @@ -2743,7 +2727,7 @@ def string_agg(
expression: Expr,
delimiter: str,
filter: Optional[Expr] = None,
order_by: Optional[list[Expr | SortExpr]] = None,
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
) -> Expr:
"""Concatenates the input strings.

Expand Down
29 changes: 29 additions & 0 deletions python/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ def test_aggregation_stats(df, agg_expr, calc_expected):
pa.array([[6, 4, 4]]),
False,
),
(
f.array_agg(column("b"), order_by=column("c")),
pa.array([[6, 4, 4]]),
False,
),
(f.avg(column("b"), filter=column("a") != lit(1)), pa.array([5.0]), False),
(f.sum(column("b"), filter=column("a") != lit(1)), pa.array([10]), False),
(f.count(column("b"), distinct=True), pa.array([2]), False),
Expand Down Expand Up @@ -329,6 +334,15 @@ def test_bit_and_bool_fns(df, name, expr, result):
),
[None, None],
),
(
"first_value_no_list_order_by",
f.first_value(
column("b"),
order_by=column("b"),
null_treatment=NullTreatment.RESPECT_NULLS,
),
[None, None],
),
(
"first_value_ignore_null",
f.first_value(
Expand All @@ -343,6 +357,11 @@ def test_bit_and_bool_fns(df, name, expr, result):
f.last_value(column("a"), order_by=[column("a").sort(ascending=False)]),
[0, 4],
),
(
"last_value_no_list_ordered",
f.last_value(column("a"), order_by=column("a")),
[3, 6],
),
(
"last_value_with_null",
f.last_value(
Expand All @@ -366,6 +385,11 @@ def test_bit_and_bool_fns(df, name, expr, result):
f.nth_value(column("a"), 2, order_by=[column("a").sort(ascending=False)]),
[2, 5],
),
(
"nth_value_no_list_ordered",
f.nth_value(column("a"), 2, order_by=column("a").sort(ascending=False)),
[2, 5],
),
(
"nth_value_with_null",
f.nth_value(
Expand Down Expand Up @@ -414,6 +438,11 @@ def test_first_last_value(df_partitioned, name, expr, result) -> None:
f.string_agg(column("a"), ",", order_by=[column("b")]),
"one,three,two,two",
),
(
"string_agg",
f.string_agg(column("a"), ",", order_by=column("b")),
"one,three,two,two",
),
],
)
def test_string_agg(name, expr, result) -> None:
Expand Down
Loading