[SPARK-52646][PS] Avoid CAST_INVALID_INPUT of __eq__ in ANSI mode #51370

Status: Open · wants to merge 14 commits into master
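For context, a minimal sketch of the behavior this PR targets (an assumed repro, not taken from the PR description): under ANSI mode, `__eq__` between a numeric pandas-on-Spark column and a string column could surface CAST_INVALID_INPUT when Spark attempted to cast the strings; with this change, dtype-incompatible comparisons short-circuit to all-False, matching pandas semantics.

    # Assumed repro: requires a Spark session with spark.sql.ansi.enabled=true.
    import pyspark.pandas as ps

    psdf = ps.DataFrame({"int": [1, 2], "str": ["1", "2"]})

    # pandas returns all False for int-vs-string equality; after this PR the
    # pandas-on-Spark comparison short-circuits to the same result instead of
    # attempting an ANSI cast of the strings.
    print((psdf["int"] == psdf["str"]).to_list())  # [False, False]
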
35 changes: 35 additions & 0 deletions python/pyspark/pandas/data_type_ops/base.py
@@ -51,7 +51,9 @@
     extension_float_dtypes_available,
     extension_object_dtypes_available,
     spark_type_to_pandas_dtype,
+    as_spark_type,
 )
+from pyspark.pandas.utils import is_ansi_mode_enabled

 if extension_dtypes_available:
     from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
@@ -108,6 +110,29 @@ def transform_boolean_operand_to_numeric(
     return operand


+def _should_return_all_false(left: IndexOpsLike, right: Any) -> bool:
+    """
+    Determine if binary comparison should short-circuit to all False,
+    based on incompatible dtypes: non-numeric vs. numeric (including bools).
+    """
+    from pandas.api.types import is_numeric_dtype
+    from pyspark.pandas.base import IndexOpsMixin
+
+    def are_both_numeric(left_dtype, right_dtype) -> bool:
+        return is_numeric_dtype(left_dtype) and is_numeric_dtype(right_dtype)
+
+    left_dtype = left.dtype
+
+    if isinstance(right, (IndexOpsMixin, np.ndarray, pd.Series)):
+        right_dtype = right.dtype
+    elif isinstance(right, (list, tuple)):
+        right_dtype = pd.Series(right).dtype
+    else:
+        right_dtype = pd.Series([right]).dtype
+
+    return left_dtype != right_dtype and not are_both_numeric(left_dtype, right_dtype)
+
+
 def _as_categorical_type(
     index_ops: IndexOpsLike, dtype: CategoricalDtype, spark_type: DataType
 ) -> IndexOpsLike:
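A standalone sketch of the dtype rule the new helper encodes (illustrative only; variable names here are mine): two dtypes are comparison-compatible when they are equal or both numeric, and pandas counts bool as numeric.

    import pandas as pd
    from pandas.api.types import is_numeric_dtype

    pairs = [
        (pd.Series([1]).dtype, pd.Series([True]).dtype),  # int64 vs bool    -> both numeric
        (pd.Series([1]).dtype, pd.Series(["1"]).dtype),   # int64 vs object  -> incompatible
        (pd.Series([0.1]).dtype, pd.Series([1]).dtype),   # float64 vs int64 -> both numeric
    ]
    for left, right in pairs:
        # Same short-circuit condition as _should_return_all_false:
        short_circuit = left != right and not (is_numeric_dtype(left) and is_numeric_dtype(right))
        print(left, right, "-> all False" if short_circuit else "-> regular comparison")
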
@@ -392,6 +417,10 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         raise TypeError(">= can not be applied to %s." % self.pretty_name)

     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+        if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession):
+            if _should_return_all_false(left, right):
+                return left._with_new_scol(F.lit(False)).rename(None)
+
         if isinstance(right, (list, tuple)):
             from pyspark.pandas.series import first_series, scol_for
             from pyspark.pandas.frame import DataFrame
@@ -482,6 +511,12 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         else:
             from pyspark.pandas.base import column_op

+            if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession):
+                if _is_boolean_type(left):
+                    left = transform_boolean_operand_to_numeric(
+                        left, spark_type=as_spark_type(right.dtype)
+                    )
+
             return column_op(PySparkColumn.__eq__)(left, right)

     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
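The second hunk pre-casts a boolean left operand to the right side's numeric Spark type before delegating to `Column.__eq__`. A rough standalone illustration of why the cast keeps the comparison well-defined (hypothetical snippet, not PR code):

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(2).select(
        (F.col("id") == 0).alias("b"),  # boolean column: [true, false]
        F.col("id").alias("n"),         # bigint column:  [0, 1]
    )
    # Casting the boolean side to the numeric side's type mirrors what
    # transform_boolean_operand_to_numeric does: true -> 1, false -> 0.
    df.select((F.col("b").cast("long") == F.col("n")).alias("eq")).show()
    # eq: [false, false]  (1 == 0, 0 == 1)
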
14 changes: 12 additions & 2 deletions python/pyspark/pandas/data_type_ops/num_ops.py
@@ -41,8 +41,9 @@
     _sanitize_list_like,
     _is_valid_for_logical_operator,
     _is_boolean_type,
+    _should_return_all_false,
 )
-from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type
+from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type, as_spark_type
 from pyspark.pandas.utils import is_ansi_mode_enabled
 from pyspark.sql import functions as F, Column as PySparkColumn
 from pyspark.sql.types import (
@@ -177,9 +178,18 @@ def abs(self, operand: IndexOpsLike) -> IndexOpsLike:

     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         # We can directly use `super().eq` when given object is list, tuple, dict or set.
+
         if not isinstance(right, IndexOpsMixin) and is_list_like(right):
             return super().eq(left, right)
-        return pyspark_column_op("__eq__", left, right, fillna=False)
+        else:
+            if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession):
+                if _should_return_all_false(left, right):
+                    return left._with_new_scol(F.lit(False)).rename(None)
+                if _is_boolean_type(right):
+                    right = transform_boolean_operand_to_numeric(
+                        right, spark_type=as_spark_type(left.dtype)
+                    )
+            return pyspark_column_op("__eq__", left, right, fillna=False)

     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
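A hedged sanity check for this hunk (assumes ANSI mode is enabled): with the boolean right-hand side cast to the numeric left's type, results line up with pandas.

    import pyspark.pandas as ps

    psdf = ps.DataFrame({"float": [1.0, 0.0], "bool": [True, False]})
    # bool is cast to double (True -> 1.0, False -> 0.0) before comparing,
    # so this matches pandas: 1.0 == True and 0.0 == False are both True.
    print((psdf["float"] == psdf["bool"]).to_list())  # [True, True]
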
13 changes: 13 additions & 0 deletions python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
@@ -23,6 +23,7 @@
 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.utils import is_ansi_mode_test
 from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase
 from pyspark.pandas.typedef.typehints import (
     extension_dtypes_available,
@@ -128,6 +129,18 @@ def test_invert(self):
         else:
             self.assertRaises(TypeError, lambda: ~psser)

+    def test_comparison_dtype_compatibility(self):
+        pdf = pd.DataFrame(
+            {"int": [1, 2], "bool": [True, False], "float": [0.1, 0.2], "str": ["1", "2"]}
+        )
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(pdf["int"] == pdf["bool"], psdf["int"] == psdf["bool"])
+        self.assert_eq(pdf["bool"] == pdf["int"], psdf["bool"] == psdf["int"])
+        self.assert_eq(pdf["int"] == pdf["float"], psdf["int"] == psdf["float"])
+        if is_ansi_mode_test:  # TODO: match non-ansi behavior with pandas
self.assert_eq(pdf["int"] == pdf["str"], psdf["int"] == psdf["str"])
self.assert_eq(pdf["float"] == pdf["bool"], psdf["float"] == psdf["bool"])

def test_eq(self):
pdf, psdf = self.pdf, self.psdf
for col in self.numeric_df_cols:
Expand Down
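To exercise the new test locally, something like `SPARK_ANSI_SQL_MODE=true python/run-tests --testnames 'pyspark.pandas.tests.data_type_ops.test_num_ops'` should work with Spark's standard Python test runner; the exact environment variable and flags may vary by branch, so treat this as a starting point rather than the canonical invocation.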