diff --git a/docs/source/user-guide/common-operations/joins.rst b/docs/source/user-guide/common-operations/joins.rst
index 40d92215..b5c111af 100644
--- a/docs/source/user-guide/common-operations/joins.rst
+++ b/docs/source/user-guide/common-operations/joins.rst
@@ -101,4 +101,48 @@ the right table.
.. ipython:: python
- left.join(right, left_on="customer_id", right_on="id", how="anti")
\ No newline at end of file
+ left.join(right, left_on="customer_id", right_on="id", how="anti")
+
+Disambiguating Columns
+----------------------
+
+When the join key exists in both DataFrames under the same name, the result contains two columns with that name. Assign a name to each DataFrame to use as a prefix and avoid ambiguity.
+
+When you create a DataFrame with a ``name`` argument, that name is used as a prefix in ``col("name.column")`` to reference specific columns.
+
+.. ipython:: python
+
+ from datafusion import col, SessionContext
+ ctx = SessionContext()
+ left = ctx.from_pydict({"id": [1, 2]}, name="l")
+ right = ctx.from_pydict({"id": [2, 3]}, name="r")
+ joined = left.join(right, on="id")
+ joined.select(col("l.id"), col("r.id"))
+
+Note that the columns in the result appear in the same order as specified in the ``select()`` call.
+
+You can remove the duplicate column after joining. Note that ``drop()`` returns a new DataFrame (DataFusion's API is immutable).
+
+.. ipython:: python
+
+ joined.drop("r.id")
+
+Automatic Deduplication
+----------------------
+
+Use the ``deduplicate`` argument of :py:meth:`DataFrame.join` to automatically
+drop the duplicate join column from the right DataFrame. Unlike PySpark which uses a ``_`` suffix by default,
+DataFusion uses the ``__right_
`` naming convention for conflicting columns when not using deduplication.
+
+.. ipython:: python
+
+ left.join(right, on="id", deduplicate=True)
+
+After deduplication, you can select the join column (which comes from the left DataFrame) and other columns as usual:
+
+.. ipython:: python
+
+ # Select the id column and other columns from both DataFrames
+ joined_dedup = left.join(right, on="id", deduplicate=True)
+ joined_dedup.select("id", "customer", "name")
+
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 61cb0943..94a68a44 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -21,7 +21,9 @@
from __future__ import annotations
+import uuid
import warnings
+from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
@@ -44,6 +46,8 @@
from datafusion.plan import ExecutionPlan, LogicalPlan
from datafusion.record_batch import RecordBatchStream
+from .functions import coalesce, col
+
if TYPE_CHECKING:
import pathlib
from typing import Callable, Sequence
@@ -57,6 +61,49 @@
from enum import Enum
+@dataclass
+class JoinKeys:
+ """Represents the resolved join keys for a DataFrame join operation."""
+
+ on: str | Sequence[str] | None
+ left_names: list[str]
+ right_names: list[str]
+
+
+@dataclass
+class JoinPreparation:
+ """Represents the complete preparation for a DataFrame join operation."""
+
+ join_keys: JoinKeys
+ modified_right: DataFrame
+ drop_cols: list[str]
+
+
+def _deduplicate_right(
+ right: DataFrame, columns: Sequence[str]
+) -> tuple[DataFrame, list[str]]:
+ """Rename join columns on the right DataFrame for deduplication."""
+ existing_columns = set(right.schema().names)
+ modified = right
+ aliases: list[str] = []
+
+ for col_name in columns:
+ base_alias = f"__right_{col_name}"
+ alias = base_alias
+ counter = 0
+ while alias in existing_columns:
+ counter += 1
+ alias = f"{base_alias}_{counter}"
+ if alias in existing_columns:
+ alias = f"__temp_{uuid.uuid4().hex[:8]}_{col_name}"
+
+ modified = modified.with_column_renamed(col_name, alias)
+ aliases.append(alias)
+ existing_columns.add(alias)
+
+ return modified, aliases
+
+
# excerpt from deltalake
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
class Compression(Enum):
@@ -678,6 +725,7 @@ def join(
left_on: str | Sequence[str] | None = None,
right_on: str | Sequence[str] | None = None,
join_keys: tuple[list[str], list[str]] | None = None,
+ deduplicate: bool = False,
) -> DataFrame:
"""Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.
@@ -691,13 +739,72 @@ def join(
left_on: Join column of the left dataframe.
right_on: Join column of the right dataframe.
join_keys: Tuple of two lists of column names to join on. [Deprecated]
+ deduplicate: If ``True``, drop duplicate join columns from the
+ right DataFrame similar to PySpark's ``on`` behavior.
Returns:
DataFrame after join.
"""
- # This check is to prevent breaking API changes where users prior to
- # DF 43.0.0 would pass the join_keys as a positional argument instead
- # of a keyword argument.
+ join_preparation = self._prepare_join(
+ right, on, left_on, right_on, join_keys, deduplicate
+ )
+
+ result = DataFrame(
+ self.df.join(
+ join_preparation.modified_right.df,
+ how,
+ join_preparation.join_keys.left_names,
+ join_preparation.join_keys.right_names,
+ )
+ )
+
+ if (
+ deduplicate
+ and how in ("right", "full")
+ and join_preparation.join_keys.on is not None
+ ):
+ for left_name, right_alias in zip(
+ join_preparation.join_keys.left_names,
+ join_preparation.drop_cols,
+ ):
+ result = result.with_column(
+ left_name, coalesce(col(left_name), col(right_alias))
+ )
+
+ if join_preparation.drop_cols:
+ result = result.drop(*join_preparation.drop_cols)
+
+ return result
+
+ def _prepare_join(
+ self,
+ right: DataFrame,
+ on: str | Sequence[str] | tuple[list[str], list[str]] | None,
+ left_on: str | Sequence[str] | None,
+ right_on: str | Sequence[str] | None,
+ join_keys: tuple[list[str], list[str]] | None,
+ deduplicate: bool,
+ ) -> JoinPreparation:
+ """Prepare join keys and handle deduplication if requested.
+
+ This method combines join key resolution and deduplication preparation
+ to avoid parameter handling duplication and provide a unified interface.
+
+ Args:
+ right: The right DataFrame to join with.
+ on: Column names to join on in both dataframes.
+ left_on: Join column of the left dataframe.
+ right_on: Join column of the right dataframe.
+ join_keys: Tuple of two lists of column names to join on. [Deprecated]
+ deduplicate: If True, prepare right DataFrame for column deduplication.
+
+ Returns:
+ JoinPreparation containing resolved join keys, modified right DataFrame,
+ and columns to drop after joining.
+ """
+ # Step 1: Resolve join keys
+ # Handle the special case where on is a tuple of lists (legacy format)
+ resolved_on: str | Sequence[str] | None
if (
isinstance(on, tuple)
and len(on) == 2
@@ -706,7 +813,9 @@ def join(
):
# We know this is safe because we've checked the types
join_keys = on # type: ignore[assignment]
- on = None
+ resolved_on = None
+ else:
+ resolved_on = on # type: ignore[assignment]
if join_keys is not None:
warnings.warn(
@@ -717,12 +826,12 @@ def join(
left_on = join_keys[0]
right_on = join_keys[1]
- if on is not None:
+ if resolved_on is not None:
if left_on is not None or right_on is not None:
error_msg = "`left_on` or `right_on` should not provided with `on`"
raise ValueError(error_msg)
- left_on = on
- right_on = on
+ left_on = resolved_on
+ right_on = resolved_on
elif left_on is not None or right_on is not None:
if left_on is None or right_on is None:
error_msg = "`left_on` and `right_on` should both be provided."
@@ -730,12 +839,35 @@ def join(
else:
error_msg = "either `on` or `left_on` and `right_on` should be provided."
raise ValueError(error_msg)
- if isinstance(left_on, str):
- left_on = [left_on]
- if isinstance(right_on, str):
- right_on = [right_on]
- return DataFrame(self.df.join(right.df, how, left_on, right_on))
+ # At this point, left_on and right_on are guaranteed to be non-None
+ if left_on is None or right_on is None: # pragma: no cover - sanity check
+ msg = "join keys resolved to None"
+ raise ValueError(msg)
+
+ left_names = [left_on] if isinstance(left_on, str) else list(left_on)
+ right_names = [right_on] if isinstance(right_on, str) else list(right_on)
+
+ drop_cols: list[str] = []
+ modified_right = right
+
+ if deduplicate and resolved_on is not None:
+ on_cols = (
+ [resolved_on] if isinstance(resolved_on, str) else list(resolved_on)
+ )
+ modified_right, aliases = _deduplicate_right(right, on_cols)
+ drop_cols.extend(aliases)
+ right_names = aliases.copy()
+
+ join_keys_resolved = JoinKeys(
+ on=resolved_on, left_names=left_names, right_names=right_names
+ )
+
+ return JoinPreparation(
+ join_keys=join_keys_resolved,
+ modified_right=modified_right,
+ drop_cols=drop_cols,
+ )
def join_on(
self,
diff --git a/python/datafusion/dataframe_formatter.py b/python/datafusion/dataframe_formatter.py
index 27f00f9c..2323224b 100644
--- a/python/datafusion/dataframe_formatter.py
+++ b/python/datafusion/dataframe_formatter.py
@@ -135,9 +135,6 @@ class DataFrameHtmlFormatter:
session
"""
- # Class variable to track if styles have been loaded in the notebook
- _styles_loaded = False
-
def __init__(
self,
max_cell_length: int = 25,
@@ -260,23 +257,6 @@ def set_custom_header_builder(self, builder: Callable[[Any], str]) -> None:
"""
self._custom_header_builder = builder
- @classmethod
- def is_styles_loaded(cls) -> bool:
- """Check if HTML styles have been loaded in the current session.
-
- This method is primarily intended for debugging UI rendering issues
- related to style loading.
-
- Returns:
- True if styles have been loaded, False otherwise
-
- Example:
- >>> from datafusion.dataframe_formatter import DataFrameHtmlFormatter
- >>> DataFrameHtmlFormatter.is_styles_loaded()
- False
- """
- return cls._styles_loaded
-
def format_html(
self,
batches: list,
@@ -315,18 +295,7 @@ def format_html(
# Build HTML components
html = []
- # Only include styles and scripts if:
- # 1. Not using shared styles, OR
- # 2. Using shared styles but they haven't been loaded yet
- include_styles = (
- not self.use_shared_styles or not DataFrameHtmlFormatter._styles_loaded
- )
-
- if include_styles:
- html.extend(self._build_html_header())
- # If we're using shared styles, mark them as loaded
- if self.use_shared_styles:
- DataFrameHtmlFormatter._styles_loaded = True
+ html.extend(self._build_html_header())
html.extend(self._build_table_container_start())
@@ -338,7 +307,7 @@ def format_html(
html.append("")
# Add footer (JavaScript and messages)
- if include_styles and self.enable_cell_expansion:
+ if self.enable_cell_expansion:
html.append(self._get_javascript())
# Always add truncation message if needed (independent of styles)
@@ -375,14 +344,20 @@ def format_str(
def _build_html_header(self) -> list[str]:
"""Build the HTML header with CSS styles."""
- html = []
- html.append("")
+ html.append(f"")
return html
def _build_table_container_start(self) -> list[str]:
@@ -570,28 +545,31 @@ def _get_default_css(self) -> str:
def _get_javascript(self) -> str:
"""Get JavaScript code for interactive elements."""
return """
-
- """
+
+"""
class FormatterManager:
@@ -712,24 +690,9 @@ def reset_formatter() -> None:
>>> reset_formatter() # Reset formatter to default settings
"""
formatter = DataFrameHtmlFormatter()
- # Reset the styles_loaded flag to ensure styles will be reloaded
- DataFrameHtmlFormatter._styles_loaded = False
set_formatter(formatter)
-def reset_styles_loaded_state() -> None:
- """Reset the styles loaded state to force reloading of styles.
-
- This can be useful when switching between notebook sessions or
- when styles need to be refreshed.
-
- Example:
- >>> from datafusion.html_formatter import reset_styles_loaded_state
- >>> reset_styles_loaded_state() # Force styles to reload in next render
- """
- DataFrameHtmlFormatter._styles_loaded = False
-
-
def _refresh_formatter_reference() -> None:
"""Refresh formatter reference in any modules using it.
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index c9ae38d8..76dcf54c 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -42,7 +42,6 @@
configure_formatter,
get_formatter,
reset_formatter,
- reset_styles_loaded_state,
)
from datafusion.expr import Window
from pyarrow.csv import write_csv
@@ -520,6 +519,52 @@ def test_join_on():
assert table.to_pydict() == expected
+def test_join_deduplicate():
+ ctx = SessionContext()
+
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2]), pa.array(["l1", "l2"])],
+ names=["id", "left_val"],
+ )
+ left = ctx.create_dataframe([[batch]], "l")
+
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2]), pa.array(["r1", "r2"])],
+ names=["id", "right_val"],
+ )
+ right = ctx.create_dataframe([[batch]], "r")
+
+ joined = left.join(right, on="id", deduplicate=True)
+ joined = joined.sort(column("id"))
+ table = pa.Table.from_batches(joined.collect())
+
+ expected = {"id": [1, 2], "right_val": ["r1", "r2"], "left_val": ["l1", "l2"]}
+ assert table.to_pydict() == expected
+
+
+def test_join_deduplicate_multi():
+ ctx = SessionContext()
+
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2]), pa.array([3, 4]), pa.array(["x", "y"])],
+ names=["a", "b", "l"],
+ )
+ left = ctx.create_dataframe([[batch]], "l")
+
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2]), pa.array([3, 4]), pa.array(["u", "v"])],
+ names=["a", "b", "r"],
+ )
+ right = ctx.create_dataframe([[batch]], "r")
+
+ joined = left.join(right, on=["a", "b"], deduplicate=True)
+ joined = joined.sort(column("a"), column("b"))
+ table = pa.Table.from_batches(joined.collect())
+
+ expected = {"a": [1, 2], "b": [3, 4], "r": ["u", "v"], "l": ["x", "y"]}
+ assert table.to_pydict() == expected
+
+
def test_distinct():
ctx = SessionContext()
@@ -2177,27 +2222,15 @@ def test_html_formatter_shared_styles(df, clean_formatter_state):
# First, ensure we're using shared styles
configure_formatter(use_shared_styles=True)
- # Get HTML output for first table - should include styles
html_first = df._repr_html_()
-
- # Verify styles are included in first render
- assert "