diff --git a/docs/source/user-guide/common-operations/joins.rst b/docs/source/user-guide/common-operations/joins.rst index 40d92215..b5c111af 100644 --- a/docs/source/user-guide/common-operations/joins.rst +++ b/docs/source/user-guide/common-operations/joins.rst @@ -101,4 +101,48 @@ the right table. .. ipython:: python - left.join(right, left_on="customer_id", right_on="id", how="anti") \ No newline at end of file + left.join(right, left_on="customer_id", right_on="id", how="anti") + +Disambiguating Columns +---------------------- + +When the join key exists in both DataFrames under the same name, the result contains two columns with that name. Assign a name to each DataFrame to use as a prefix and avoid ambiguity. + +When you create a DataFrame with a ``name`` argument, that name is used as a prefix in ``col("name.column")`` to reference specific columns. + +.. ipython:: python + + from datafusion import col, SessionContext + ctx = SessionContext() + left = ctx.from_pydict({"id": [1, 2]}, name="l") + right = ctx.from_pydict({"id": [2, 3]}, name="r") + joined = left.join(right, on="id") + joined.select(col("l.id"), col("r.id")) + +Note that the columns in the result appear in the same order as specified in the ``select()`` call. + +You can remove the duplicate column after joining. Note that ``drop()`` returns a new DataFrame (DataFusion's API is immutable). + +.. ipython:: python + + joined.drop("r.id") + +Automatic Deduplication +----------------------- + +Use the ``deduplicate`` argument of :py:meth:`DataFrame.join` to automatically +drop the duplicate join column from the right DataFrame, similar to PySpark's ``on`` behavior. +Internally, the right-hand join columns are temporarily renamed with a ``__right_`` prefix and dropped after the join. + +.. ipython:: python + + left.join(right, on="id", deduplicate=True) + +After deduplication, you can select the join column (which comes from the left DataFrame) and other columns as usual: + +..
ipython:: python + + # Select the id column and other columns from both DataFrames + joined_dedup = left.join(right, on="id", deduplicate=True) + joined_dedup.select("id", "customer", "name") + diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 61cb0943..94a68a44 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -21,7 +21,9 @@ from __future__ import annotations +import uuid import warnings +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -44,6 +46,8 @@ from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.record_batch import RecordBatchStream +from .functions import coalesce, col + if TYPE_CHECKING: import pathlib from typing import Callable, Sequence @@ -57,6 +61,49 @@ from enum import Enum +@dataclass +class JoinKeys: + """Represents the resolved join keys for a DataFrame join operation.""" + + on: str | Sequence[str] | None + left_names: list[str] + right_names: list[str] + + +@dataclass +class JoinPreparation: + """Represents the complete preparation for a DataFrame join operation.""" + + join_keys: JoinKeys + modified_right: DataFrame + drop_cols: list[str] + + +def _deduplicate_right( + right: DataFrame, columns: Sequence[str] +) -> tuple[DataFrame, list[str]]: + """Rename join columns on the right DataFrame for deduplication.""" + existing_columns = set(right.schema().names) + modified = right + aliases: list[str] = [] + + for col_name in columns: + base_alias = f"__right_{col_name}" + alias = base_alias + counter = 0 + while alias in existing_columns: + counter += 1 + alias = f"{base_alias}_{counter}" + if alias in existing_columns: + alias = f"__temp_{uuid.uuid4().hex[:8]}_{col_name}" + + modified = modified.with_column_renamed(col_name, alias) + aliases.append(alias) + existing_columns.add(alias) + + return modified, aliases + + # excerpt from deltalake # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163 class 
Compression(Enum): @@ -678,6 +725,7 @@ def join( left_on: str | Sequence[str] | None = None, right_on: str | Sequence[str] | None = None, join_keys: tuple[list[str], list[str]] | None = None, + deduplicate: bool = False, ) -> DataFrame: """Join this :py:class:`DataFrame` with another :py:class:`DataFrame`. @@ -691,13 +739,72 @@ def join( left_on: Join column of the left dataframe. right_on: Join column of the right dataframe. join_keys: Tuple of two lists of column names to join on. [Deprecated] + deduplicate: If ``True``, drop duplicate join columns from the + right DataFrame similar to PySpark's ``on`` behavior. Returns: DataFrame after join. """ - # This check is to prevent breaking API changes where users prior to - # DF 43.0.0 would pass the join_keys as a positional argument instead - # of a keyword argument. + join_preparation = self._prepare_join( + right, on, left_on, right_on, join_keys, deduplicate + ) + + result = DataFrame( + self.df.join( + join_preparation.modified_right.df, + how, + join_preparation.join_keys.left_names, + join_preparation.join_keys.right_names, + ) + ) + + if ( + deduplicate + and how in ("right", "full") + and join_preparation.join_keys.on is not None + ): + for left_name, right_alias in zip( + join_preparation.join_keys.left_names, + join_preparation.drop_cols, + ): + result = result.with_column( + left_name, coalesce(col(left_name), col(right_alias)) + ) + + if join_preparation.drop_cols: + result = result.drop(*join_preparation.drop_cols) + + return result + + def _prepare_join( + self, + right: DataFrame, + on: str | Sequence[str] | tuple[list[str], list[str]] | None, + left_on: str | Sequence[str] | None, + right_on: str | Sequence[str] | None, + join_keys: tuple[list[str], list[str]] | None, + deduplicate: bool, + ) -> JoinPreparation: + """Prepare join keys and handle deduplication if requested. 
+ + This method combines join key resolution and deduplication preparation + to avoid parameter handling duplication and provide a unified interface. + + Args: + right: The right DataFrame to join with. + on: Column names to join on in both dataframes. + left_on: Join column of the left dataframe. + right_on: Join column of the right dataframe. + join_keys: Tuple of two lists of column names to join on. [Deprecated] + deduplicate: If True, prepare right DataFrame for column deduplication. + + Returns: + JoinPreparation containing resolved join keys, modified right DataFrame, + and columns to drop after joining. + """ + # Step 1: Resolve join keys + # Handle the special case where on is a tuple of lists (legacy format) + resolved_on: str | Sequence[str] | None if ( isinstance(on, tuple) and len(on) == 2 @@ -706,7 +813,9 @@ def join( ): # We know this is safe because we've checked the types join_keys = on # type: ignore[assignment] - on = None + resolved_on = None + else: + resolved_on = on # type: ignore[assignment] if join_keys is not None: warnings.warn( @@ -717,12 +826,12 @@ def join( left_on = join_keys[0] right_on = join_keys[1] - if on is not None: + if resolved_on is not None: if left_on is not None or right_on is not None: error_msg = "`left_on` or `right_on` should not provided with `on`" raise ValueError(error_msg) - left_on = on - right_on = on + left_on = resolved_on + right_on = resolved_on elif left_on is not None or right_on is not None: if left_on is None or right_on is None: error_msg = "`left_on` and `right_on` should both be provided." @@ -730,12 +839,35 @@ def join( else: error_msg = "either `on` or `left_on` and `right_on` should be provided." 
raise ValueError(error_msg) - if isinstance(left_on, str): - left_on = [left_on] - if isinstance(right_on, str): - right_on = [right_on] - return DataFrame(self.df.join(right.df, how, left_on, right_on)) + # At this point, left_on and right_on are guaranteed to be non-None + if left_on is None or right_on is None: # pragma: no cover - sanity check + msg = "join keys resolved to None" + raise ValueError(msg) + + left_names = [left_on] if isinstance(left_on, str) else list(left_on) + right_names = [right_on] if isinstance(right_on, str) else list(right_on) + + drop_cols: list[str] = [] + modified_right = right + + if deduplicate and resolved_on is not None: + on_cols = ( + [resolved_on] if isinstance(resolved_on, str) else list(resolved_on) + ) + modified_right, aliases = _deduplicate_right(right, on_cols) + drop_cols.extend(aliases) + right_names = aliases.copy() + + join_keys_resolved = JoinKeys( + on=resolved_on, left_names=left_names, right_names=right_names + ) + + return JoinPreparation( + join_keys=join_keys_resolved, + modified_right=modified_right, + drop_cols=drop_cols, + ) def join_on( self, diff --git a/python/datafusion/dataframe_formatter.py b/python/datafusion/dataframe_formatter.py index 27f00f9c..2323224b 100644 --- a/python/datafusion/dataframe_formatter.py +++ b/python/datafusion/dataframe_formatter.py @@ -135,9 +135,6 @@ class DataFrameHtmlFormatter: session """ - # Class variable to track if styles have been loaded in the notebook - _styles_loaded = False - def __init__( self, max_cell_length: int = 25, @@ -260,23 +257,6 @@ def set_custom_header_builder(self, builder: Callable[[Any], str]) -> None: """ self._custom_header_builder = builder - @classmethod - def is_styles_loaded(cls) -> bool: - """Check if HTML styles have been loaded in the current session. - - This method is primarily intended for debugging UI rendering issues - related to style loading. 
- - Returns: - True if styles have been loaded, False otherwise - - Example: - >>> from datafusion.dataframe_formatter import DataFrameHtmlFormatter - >>> DataFrameHtmlFormatter.is_styles_loaded() - False - """ - return cls._styles_loaded - def format_html( self, batches: list, @@ -315,18 +295,7 @@ def format_html( # Build HTML components html = [] - # Only include styles and scripts if: - # 1. Not using shared styles, OR - # 2. Using shared styles but they haven't been loaded yet - include_styles = ( - not self.use_shared_styles or not DataFrameHtmlFormatter._styles_loaded - ) - - if include_styles: - html.extend(self._build_html_header()) - # If we're using shared styles, mark them as loaded - if self.use_shared_styles: - DataFrameHtmlFormatter._styles_loaded = True + html.extend(self._build_html_header()) html.extend(self._build_table_container_start()) @@ -338,7 +307,7 @@ def format_html( html.append("") # Add footer (JavaScript and messages) - if include_styles and self.enable_cell_expansion: + if self.enable_cell_expansion: html.append(self._get_javascript()) # Always add truncation message if needed (independent of styles) @@ -375,14 +344,20 @@ def format_str( def _build_html_header(self) -> list[str]: """Build the HTML header with CSS styles.""" - html = [] - html.append("") + html.append(f"") return html def _build_table_container_start(self) -> list[str]: @@ -570,28 +545,31 @@ def _get_default_css(self) -> str: def _get_javascript(self) -> str: """Get JavaScript code for interactive elements.""" return """ - - """ + +""" class FormatterManager: @@ -712,24 +690,9 @@ def reset_formatter() -> None: >>> reset_formatter() # Reset formatter to default settings """ formatter = DataFrameHtmlFormatter() - # Reset the styles_loaded flag to ensure styles will be reloaded - DataFrameHtmlFormatter._styles_loaded = False set_formatter(formatter) -def reset_styles_loaded_state() -> None: - """Reset the styles loaded state to force reloading of styles. 
- - This can be useful when switching between notebook sessions or - when styles need to be refreshed. - - Example: - >>> from datafusion.html_formatter import reset_styles_loaded_state - >>> reset_styles_loaded_state() # Force styles to reload in next render - """ - DataFrameHtmlFormatter._styles_loaded = False - - def _refresh_formatter_reference() -> None: """Refresh formatter reference in any modules using it. diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index c9ae38d8..76dcf54c 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -42,7 +42,6 @@ configure_formatter, get_formatter, reset_formatter, - reset_styles_loaded_state, ) from datafusion.expr import Window from pyarrow.csv import write_csv @@ -520,6 +519,52 @@ def test_join_on(): assert table.to_pydict() == expected +def test_join_deduplicate(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array(["l1", "l2"])], + names=["id", "left_val"], + ) + left = ctx.create_dataframe([[batch]], "l") + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array(["r1", "r2"])], + names=["id", "right_val"], + ) + right = ctx.create_dataframe([[batch]], "r") + + joined = left.join(right, on="id", deduplicate=True) + joined = joined.sort(column("id")) + table = pa.Table.from_batches(joined.collect()) + + expected = {"id": [1, 2], "right_val": ["r1", "r2"], "left_val": ["l1", "l2"]} + assert table.to_pydict() == expected + + +def test_join_deduplicate_multi(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([3, 4]), pa.array(["x", "y"])], + names=["a", "b", "l"], + ) + left = ctx.create_dataframe([[batch]], "l") + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([3, 4]), pa.array(["u", "v"])], + names=["a", "b", "r"], + ) + right = ctx.create_dataframe([[batch]], "r") + + joined = left.join(right, on=["a", "b"], deduplicate=True) + joined = 
joined.sort(column("a"), column("b")) + table = pa.Table.from_batches(joined.collect()) + + expected = {"a": [1, 2], "b": [3, 4], "r": ["u", "v"], "l": ["x", "y"]} + assert table.to_pydict() == expected + + def test_distinct(): ctx = SessionContext() @@ -2177,27 +2222,15 @@ def test_html_formatter_shared_styles(df, clean_formatter_state): # First, ensure we're using shared styles configure_formatter(use_shared_styles=True) - # Get HTML output for first table - should include styles html_first = df._repr_html_() - - # Verify styles are included in first render - assert "