Skip to content

Commit 09525e5

Browse files
authored
Support quantile in cudf-polars grouped aggregations (#18634)
- Closes #17123 Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: #18634
1 parent 00cd766 commit 09525e5

File tree

6 files changed

+33
-3
lines changed

6 files changed

+33
-3
lines changed

python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,13 @@ def __init__(
6666
else plc.types.NullPolicy.INCLUDE
6767
)
6868
elif name == "quantile":
69-
_, quantile = self.children
69+
child, quantile = self.children
7070
if not isinstance(quantile, Literal):
7171
raise NotImplementedError("Only support literal quantile values")
72+
if options == "equiprobable":
73+
raise NotImplementedError("Quantile with equiprobable interpolation")
74+
if plc.traits.is_duration(child.dtype):
75+
raise NotImplementedError("Quantile with duration data type")
7276
req = plc.aggregation.quantile(
7377
quantiles=[quantile.value.as_py()], interp=Agg.interp_mapping[options]
7478
)

python/cudf_polars/cudf_polars/dsl/ir.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,10 @@ def do_evaluate(
10121012
# A count aggregation, we need a column so use a key column
10131013
col = keys[0].obj
10141014
elif isinstance(value, expr.Agg):
1015-
(child,) = value.children
1015+
if value.name == "quantile":
1016+
child = value.children[0]
1017+
else:
1018+
(child,) = value.children
10161019
col = child.evaluate(df).obj
10171020
else:
10181021
# Anything else, we pre-evaluate

python/cudf_polars/cudf_polars/dsl/utils/aggregations.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,12 @@ def decompose_single_agg(
7272
if isinstance(agg, (expr.Literal, expr.LiteralColumn)):
7373
return [], named_expr, False
7474
if isinstance(agg, expr.Agg):
75-
(child,) = agg.children
75+
if agg.name == "quantile":
76+
# Second child the requested quantile (which is asserted
77+
# to be a literal on construction)
78+
child = agg.children[0]
79+
else:
80+
(child,) = agg.children
7681
needs_masking = agg.name in {"min", "max"} and plc.traits.is_floating_point(
7782
child.dtype
7883
)

python/cudf_polars/cudf_polars/testing/plugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def pytest_configure(config: pytest.Config) -> None:
160160
"tests/unit/operations/arithmetic/test_list.py::test_list_arithmetic_values[exec_op_with_expr_no_type_coercion-broadcast_both-none]": "cudf-polars doesn't nullify division by zero",
161161
"tests/unit/operations/arithmetic/test_list.py::test_list_arithmetic_values[exec_op_with_expr_no_type_coercion-broadcast_none-none]": "cudf-polars doesn't nullify division by zero",
162162
"tests/unit/operations/test_abs.py::test_abs_duration": "Need to raise for unsupported uops on timelike values",
163+
"tests/unit/operations/test_group_by.py::test_group_by_shorthand_quantile": "libcudf quantiles are round to nearest ties to even, polars quantiles are round to nearest ties away from zero",
163164
"tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input7-expected7-Float32-Float32]": "Mismatching dtypes, needs cudf#15852",
164165
"tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input10-expected10-Date-output_dtype10]": "Unsupported groupby-agg for a particular dtype",
165166
"tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input11-expected11-input_dtype11-output_dtype11]": "Unsupported groupby-agg for a particular dtype",

python/cudf_polars/tests/expressions/test_agg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ def test_quantile_invalid_q(df):
125125
assert_ir_translation_raises(q, NotImplementedError)
126126

127127

128+
def test_quantile_equiprobable_unsupported(df):
129+
expr = pl.col("a").quantile(0.5, interpolation="equiprobable")
130+
q = df.select(expr)
131+
assert_ir_translation_raises(q, NotImplementedError)
132+
133+
134+
def test_quantile_duration_unsupported():
135+
df = pl.LazyFrame({"a": pl.Series([1, 2, 3, 4], dtype=pl.Duration("ns"))})
136+
q = df.select(pl.col("a").quantile(0.5))
137+
assert_ir_translation_raises(q, NotImplementedError)
138+
139+
128140
@pytest.mark.parametrize(
129141
"op", [pl.Expr.min, pl.Expr.nan_min, pl.Expr.max, pl.Expr.nan_max]
130142
)

python/cudf_polars/tests/test_groupby.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def keys(request):
8181
[pl.col("float").round(decimals=1).sum()],
8282
[pl.col("int").first(), pl.col("float").last()],
8383
[pl.col("int").sum(), pl.col("string").str.replace("h", "foo", literal=True)],
84+
[pl.col("float").quantile(0.3, interpolation="nearest")],
85+
[pl.col("float").quantile(0.3, interpolation="higher")],
86+
[pl.col("float").quantile(0.3, interpolation="lower")],
87+
[pl.col("float").quantile(0.3, interpolation="midpoint")],
88+
[pl.col("float").quantile(0.3, interpolation="linear")],
8489
[
8590
pl.col("datetime").max(),
8691
pl.col("datetime").max().dt.is_leap_year().alias("leapyear"),

0 commit comments

Comments
 (0)