diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 43ac3992be78..398f59e35d10 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -522,6 +522,38 @@ impl DataFrame { }) } + /// Return a new `DataFrame` with duplicated rows removed as per the specified expression list + /// according to the provided sorting expressions grouped by the `DISTINCT ON` clause + /// expressions. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? + /// // Return a single row (a, b) for each distinct value of a + /// .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?; + /// # Ok(()) + /// # } + /// ``` + pub fn distinct_on( + self, + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + ) -> Result { + let plan = LogicalPlanBuilder::from(self.plan) + .distinct_on(on_expr, select_expr, sort_expr)? + .build()?; + Ok(DataFrame { + session_state: self.session_state, + plan, + }) + } + /// Return a new `DataFrame` that has statistics for a DataFrame. /// /// Only summarizes numeric datatypes at the moment and returns nulls for @@ -2359,6 +2391,91 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_distinct_on() -> Result<()> { + let t = test_table().await?; + let plan = t + .distinct_on(vec![col("c1")], vec![col("aggregate_test_100.c1")], None) + .unwrap(); + + let sql_plan = + create_plan("select distinct on (c1) c1 from aggregate_test_100").await?; + + assert_same_plan(&plan.plan.clone(), &sql_plan); + + let df_results = plan.clone().collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+", + "| c1 |", + "+----+", + "| a |", + "| b |", + "| c |", + "| d |", + "| e |", + "+----+"], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_distinct_on_sort_by() -> Result<()> { + let t = test_table().await?; + let plan = t + .select(vec![col("c1")]) + .unwrap() + .distinct_on( + vec![col("c1")], + vec![col("c1")], + Some(vec![col("c1").sort(true, true)]), + ) + .unwrap() + .sort(vec![col("c1").sort(true, true)]) + .unwrap(); + + let df_results = plan.clone().collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+", + "| c1 |", + "+----+", + "| a |", + "| b |", + "| c |", + "| d |", + "| e |", + "+----+"], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_distinct_on_sort_by_unprojected() -> Result<()> { + let t = test_table().await?; + let err = t + .select(vec![col("c1")]) + .unwrap() + .distinct_on( + vec![col("c1")], + vec![col("c1")], + Some(vec![col("c1").sort(true, true)]), + ) + .unwrap() + // try to sort on some value not present in input to distinct + .sort(vec![col("c2").sort(true, true)]) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list"); + + Ok(()) + } + #[tokio::test] async fn join() -> Result<()> { let left = test_table().await?.select_columns(&["c1", "c2"])?; diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 744a719e77be..f011e68fadb2 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -64,6 +64,7 @@ execution. The plan is evaluated (executed) when an action method is invoked, su | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | | aggregate | Perform an aggregate query with optional grouping expressions. | | distinct | Filter out duplicate rows. | +| distinct_on | Filter out duplicate rows based on provided expressions. | | drop_columns | Create a projection with all but the provided column names. | | except | Calculate the exception of two DataFrames. The two DataFrames must have exactly the same schema | | filter | Filter a DataFrame to only include rows that match the specified filter expression. |