From f5c7ed08cea3fe79abf10adf8e5ba182e437a72b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 8 Jul 2025 07:35:04 -0400 Subject: [PATCH 1/3] Add field to dataframe join to indicate if we should keep duplicate keys --- python/datafusion/dataframe.py | 32 +++++++++++++++++++----------- python/tests/test_dataframe.py | 1 - src/dataframe.rs | 36 +++++++++++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 13 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 61cb0943..64774c97 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -643,6 +643,7 @@ def join( left_on: None = None, right_on: None = None, join_keys: None = None, + keep_duplicate_keys: bool = False, ) -> DataFrame: ... @overload @@ -655,6 +656,7 @@ def join( left_on: str | Sequence[str], right_on: str | Sequence[str], join_keys: tuple[list[str], list[str]] | None = None, + keep_duplicate_keys: bool = False, ) -> DataFrame: ... @overload @@ -667,6 +669,7 @@ def join( join_keys: tuple[list[str], list[str]], left_on: None = None, right_on: None = None, + keep_duplicate_keys: bool = False, ) -> DataFrame: ... def join( @@ -678,6 +681,7 @@ def join( left_on: str | Sequence[str] | None = None, right_on: str | Sequence[str] | None = None, join_keys: tuple[list[str], list[str]] | None = None, + keep_duplicate_keys: bool = False, ) -> DataFrame: """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`. @@ -690,11 +694,23 @@ def join( "right", "full", "semi", "anti". left_on: Join column of the left dataframe. right_on: Join column of the right dataframe. + keep_duplicate_keys: When False, the columns from the right DataFrame + that have identical names in the ``on`` fields to the left DataFrame + will be dropped. join_keys: Tuple of two lists of column names to join on. [Deprecated] Returns: DataFrame after join. """ + if join_keys is not None: + warnings.warn( + "`join_keys` is deprecated, use `on` or `left_on` with `right_on`", + category=DeprecationWarning, + stacklevel=2, + ) + left_on = join_keys[0] + right_on = join_keys[1] + # This check is to prevent breaking API changes where users prior to # DF 43.0.0 would pass the join_keys as a positional argument instead # of a keyword argument. @@ -705,18 +721,10 @@ def join( and isinstance(on[1], list) ): # We know this is safe because we've checked the types - join_keys = on # type: ignore[assignment] + left_on = on[0] + right_on = on[1] on = None - if join_keys is not None: - warnings.warn( - "`join_keys` is deprecated, use `on` or `left_on` with `right_on`", - category=DeprecationWarning, - stacklevel=2, - ) - left_on = join_keys[0] - right_on = join_keys[1] - if on is not None: if left_on is not None or right_on is not None: error_msg = "`left_on` or `right_on` should not provided with `on`" @@ -735,7 +743,9 @@ def join( if isinstance(right_on, str): right_on = [right_on] - return DataFrame(self.df.join(right.df, how, left_on, right_on)) + return DataFrame( + self.df.join(right.df, how, left_on, right_on, keep_duplicate_keys) + ) def join_on( self, diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index a3870ead..46f794e7 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -400,7 +400,6 @@ def test_unnest_without_nulls(nested_df): assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9]) -@pytest.mark.filterwarnings("ignore:`join_keys`:DeprecationWarning") def test_join(): ctx = SessionContext() diff --git a/src/dataframe.rs b/src/dataframe.rs index ab4749e3..f3d6e253 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -566,6 +566,7 @@ impl PyDataFrame { how: &str, left_on: Vec, right_on: Vec, + keep_duplicate_keys: bool, ) -> PyDataFusionResult { let join_type = match how { "inner" => JoinType::Inner, @@ -584,13 +585,46 @@ impl PyDataFrame { let left_keys = left_on.iter().map(|s| s.as_ref()).collect::>(); let right_keys = right_on.iter().map(|s| s.as_ref()).collect::>(); - let df = self.df.as_ref().clone().join( + let mut df = self.df.as_ref().clone().join( right.df.as_ref().clone(), join_type, &left_keys, &right_keys, None, )?; + + if !keep_duplicate_keys { + let mutual_keys = left_keys + .iter() + .zip(right_keys.iter()) + .filter(|(l, r)| l == r) + .map(|(key, _)| *key) + .collect::>(); + + let fields_to_drop = mutual_keys + .iter() + .map(|name| { + df.logical_plan() + .schema() + .qualified_fields_with_unqualified_name(name) + }) + .filter(|r| r.len() == 2) + .map(|r| r[1]) + .collect::>(); + + let expr: Vec = df + .logical_plan() + .schema() + .fields() + .into_iter() + .enumerate() + .map(|(idx, _)| df.logical_plan().schema().qualified_field(idx)) + .filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f))) + .map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field)))) + .collect(); + df = df.select(expr)?; + } + Ok(Self::new(df)) } From fb8096b6d43407e0bb938158610f105d0d91e70f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 8 Jul 2025 07:37:16 -0400 Subject: [PATCH 2/3] Suppress expected warning --- python/tests/test_sql.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index c383edc6..72907373 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -157,6 +157,9 @@ def test_register_parquet(ctx, tmp_path): assert result.to_pydict() == {"cnt": [100]} +@pytest.mark.filterwarnings( + "ignore:using literals for table_partition_cols data types:DeprecationWarning" +) @pytest.mark.parametrize( ("path_to_str", "legacy_data_type"), [(True, False), (False, False), (False, True)] ) From 8e1ed6753876c99d92490acb9e4357fa5ce2af5f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 8 Jul 2025 07:37:43 -0400 Subject: [PATCH 3/3] Minor: small tables rendered way too large --- python/datafusion/dataframe_formatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/datafusion/dataframe_formatter.py b/python/datafusion/dataframe_formatter.py index 2323224b..ffafde53 100644 --- a/python/datafusion/dataframe_formatter.py +++ b/python/datafusion/dataframe_formatter.py @@ -368,7 +368,7 @@ def _build_table_container_start(self) -> list[str]: f"max-height: {self.max_height}px; overflow: auto; border: " '1px solid #ccc;">' ) - html.append('') + html.append('
') return html def _build_table_header(self, schema: Any) -> list[str]: