diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index b4a6b1abbcaf9..88d17ba9ae103 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -51,7 +51,9 @@
     extension_float_dtypes_available,
     extension_object_dtypes_available,
     spark_type_to_pandas_dtype,
+    as_spark_type,
 )
+from pyspark.pandas.utils import is_ansi_mode_enabled
 
 if extension_dtypes_available:
     from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
@@ -108,6 +110,29 @@ def transform_boolean_operand_to_numeric(
     return operand
 
 
+def _should_return_all_false(left: IndexOpsLike, right: Any) -> bool:
+    """
+    Determine if binary comparison should short-circuit to all False,
+    based on incompatible dtypes: non-numeric vs. numeric (including bools).
+    """
+    from pandas.core.dtypes.common import is_numeric_dtype
+    from pyspark.pandas.base import IndexOpsMixin
+
+    def are_both_numeric(left_dtype: Dtype, right_dtype: Dtype) -> bool:
+        return is_numeric_dtype(left_dtype) and is_numeric_dtype(right_dtype)
+
+    left_dtype = left.dtype
+
+    if isinstance(right, (IndexOpsMixin, np.ndarray, pd.Series)):
+        right_dtype = right.dtype
+    elif isinstance(right, (list, tuple)):
+        right_dtype = pd.Series(right).dtype
+    else:
+        right_dtype = pd.Series([right]).dtype
+
+    return left_dtype != right_dtype and not are_both_numeric(left_dtype, right_dtype)
+
+
 def _as_categorical_type(
     index_ops: IndexOpsLike, dtype: CategoricalDtype, spark_type: DataType
 ) -> IndexOpsLike:
@@ -392,6 +417,10 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         raise TypeError(">= can not be applied to %s." % self.pretty_name)
 
     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+        if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession):
+            if _should_return_all_false(left, right):
+                return left._with_new_scol(F.lit(False)).rename(None)  # type: ignore[attr-defined]
+
         if isinstance(right, (list, tuple)):
             from pyspark.pandas.series import first_series, scol_for
             from pyspark.pandas.frame import DataFrame
@@ -482,6 +511,12 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         else:
             from pyspark.pandas.base import column_op
 
+            if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession):
+                if _is_boolean_type(left):
+                    left = transform_boolean_operand_to_numeric(
+                        left, spark_type=as_spark_type(right.dtype)
+                    )
+
             return column_op(PySparkColumn.__eq__)(left, right)
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py
index 508b1f4984ba6..f68515984782a 100644
--- a/python/pyspark/pandas/data_type_ops/num_ops.py
+++ b/python/pyspark/pandas/data_type_ops/num_ops.py
@@ -41,8 +41,9 @@
     _sanitize_list_like,
     _is_valid_for_logical_operator,
     _is_boolean_type,
+    _should_return_all_false,
 )
-from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type
+from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type, as_spark_type
 from pyspark.pandas.utils import is_ansi_mode_enabled
 from pyspark.sql import functions as F, Column as PySparkColumn
 from pyspark.sql.types import (
@@ -177,9 +178,18 @@ def abs(self, operand: IndexOpsLike) -> IndexOpsLike:
 
     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         # We can directly use `super().eq` when given object is list, tuple, dict or set.
         if not isinstance(right, IndexOpsMixin) and is_list_like(right):
             return super().eq(left, right)
-        return pyspark_column_op("__eq__", left, right, fillna=False)
+        else:
+            if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession):
+                if _should_return_all_false(left, right):
+                    return left._with_new_scol(F.lit(False)).rename(None)  # type: ignore[attr-defined]
+                if _is_boolean_type(right):
+                    right = transform_boolean_operand_to_numeric(
+                        right, spark_type=as_spark_type(left.dtype)
+                    )
+            return pyspark_column_op("__eq__", left, right, fillna=False)
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
index 03a794771a910..00fc04e362312 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
@@ -23,6 +23,7 @@
 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.utils import is_ansi_mode_test
 from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase
 from pyspark.pandas.typedef.typehints import (
     extension_dtypes_available,
@@ -128,6 +129,18 @@ def test_invert(self):
         else:
             self.assertRaises(TypeError, lambda: ~psser)
 
+    def test_comparison_dtype_compatibility(self):
+        pdf = pd.DataFrame(
+            {"int": [1, 2], "bool": [True, False], "float": [0.1, 0.2], "str": ["1", "2"]}
+        )
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(pdf["int"] == pdf["bool"], psdf["int"] == psdf["bool"])
+        self.assert_eq(pdf["bool"] == pdf["int"], psdf["bool"] == psdf["int"])
+        self.assert_eq(pdf["int"] == pdf["float"], psdf["int"] == psdf["float"])
+        if is_ansi_mode_test:  # TODO: match non-ansi behavior with pandas
+            self.assert_eq(pdf["int"] == pdf["str"], psdf["int"] == psdf["str"])
+            self.assert_eq(pdf["float"] == pdf["bool"], psdf["float"] == psdf["bool"])
+
     def test_eq(self):
         pdf, psdf = self.pdf, self.psdf
         for col in self.numeric_df_cols:
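The dtype-compatibility rule that `_should_return_all_false` encodes can be tried in plain pandas, without Spark. The sketch below is hypothetical and only for illustration (the names `should_return_all_false` and `both_numeric` are not part of the patch, and it mirrors only the dtype check, not the Spark column plumbing); it uses the public `pandas.api.types.is_numeric_dtype` rather than the private import the patch uses.

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype

def should_return_all_false(left: pd.Series, right: pd.Series) -> bool:
    # A comparison short-circuits to all False only when the dtypes differ
    # AND the two sides are not both numeric (bool counts as numeric here).
    both_numeric = is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype)
    return left.dtype != right.dtype and not both_numeric

print(should_return_all_false(pd.Series([1, 2]), pd.Series(["1", "2"])))     # True: int vs. str
print(should_return_all_false(pd.Series([1, 2]), pd.Series([True, False])))  # False: both numeric
print(should_return_all_false(pd.Series([1, 2]), pd.Series([0.1, 0.2])))     # False: both numeric
```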
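For the user-visible effect, a rough end-to-end sketch mirroring the new `test_comparison_dtype_compatibility` test, assuming a Spark session with `spark.sql.ansi.enabled` set to `true` (exact pre-patch failure modes vary by Spark version):

```python
import pandas as pd
from pyspark import pandas as ps

pdf = pd.DataFrame({"int": [1, 2], "bool": [True, False], "str": ["1", "2"]})
psdf = ps.from_pandas(pdf)

# Numeric vs. bool: both sides count as numeric, so the bool operand is
# coerced to the numeric type and compared by value, matching pandas.
print((psdf["int"] == psdf["bool"]).to_pandas())  # [True, False]

# Numeric vs. string: incompatible dtypes short-circuit to all False under
# ANSI mode, matching pandas, instead of attempting a string-to-int cast.
print((psdf["int"] == psdf["str"]).to_pandas())   # [False, False]
```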