From 4d789c9d3d74cc752fde034fa6d2d5402c5d5330 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 3 Jul 2025 17:28:48 -0700 Subject: [PATCH 01/10] eq --- python/pyspark/pandas/data_type_ops/base.py | 25 ++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index b4a6b1abbcaf9..5cac08886b032 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -35,6 +35,7 @@ DecimalType, FractionalType, IntegralType, + LongType, MapType, NullType, NumericType, @@ -52,6 +53,7 @@ extension_object_dtypes_available, spark_type_to_pandas_dtype, ) +from pyspark.pandas.utils import is_ansi_mode_enabled if extension_dtypes_available: from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype @@ -392,10 +394,12 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: raise TypeError(">= can not be applied to %s." % self.pretty_name) def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: + from pyspark.pandas.internal import InternalField + if isinstance(right, (list, tuple)): from pyspark.pandas.series import first_series, scol_for from pyspark.pandas.frame import DataFrame - from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, InternalField + from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME if len(left) != len(right): raise ValueError("Lengths must be equal") @@ -482,6 +486,25 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: else: from pyspark.pandas.base import column_op + if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): + from pyspark.pandas.base import IndexOpsMixin + + def are_both_numeric(left_dtype, right_dtype): + return pd.api.types.is_numeric_dtype( + left_dtype + ) and pd.api.types.is_numeric_dtype(right_dtype) + + left = transform_boolean_operand_to_numeric(left, spark_type=LongType()) + right = transform_boolean_operand_to_numeric(right, spark_type=LongType()) + left_dtype = left.dtype + if isinstance(right, (IndexOpsMixin, np.ndarray)): + right_dtype = right.dtype + else: + right_dtype = pd.Series([right]).dtype + + if left_dtype != right_dtype and not are_both_numeric(left_dtype, right_dtype): + return left._with_new_scol(F.lit(False)) + return column_op(PySparkColumn.__eq__)(left, right) def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: From 716017f87e08716607f0bcb69bbb151b8981ec6f Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 21 Jul 2025 17:07:20 -0700 Subject: [PATCH 02/10] num <-> str --- python/pyspark/pandas/data_type_ops/base.py | 54 ++++++++++++------- .../pyspark/pandas/data_type_ops/num_ops.py | 8 ++- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index 5cac08886b032..146a0d26c6daf 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -110,6 +110,37 @@ def transform_boolean_operand_to_numeric( return operand +def _should_return_all_false(left: IndexOpsLike, right: Any) -> bool: + """ + Determine if binary comparison should short-circuit to all False, + based on incompatible dtypes. + + This function is used to mimic pandas behavior when comparing operands + with non-matching dtypes that cannot be reasonably coerced, such as + comparing floats with strings. + + It internally transforms boolean operands to numeric (long) and checks + whether both operands are numeric or not. If they are not, and their + dtypes differ, the comparison result is considered to be all False. + """ + from pandas.api.types import is_numeric_dtype + from pyspark.pandas.base import IndexOpsMixin + + def are_both_numeric(left_dtype, right_dtype) -> bool: + return is_numeric_dtype(left_dtype) and is_numeric_dtype(right_dtype) + + left_dtype = left.dtype + + if isinstance(right, (IndexOpsMixin, np.ndarray, pd.Series)): + right_dtype = right.dtype + elif isinstance(right, (list, tuple)): + right_dtype = pd.Series(right).dtype + else: + right_dtype = pd.Series([right]).dtype + + return left_dtype != right_dtype and not are_both_numeric(left_dtype, right_dtype) + + def _as_categorical_type( index_ops: IndexOpsLike, dtype: CategoricalDtype, spark_type: DataType ) -> IndexOpsLike: @@ -396,6 +427,10 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: from pyspark.pandas.internal import InternalField + if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): + if _should_return_all_false(left, right): + return left._with_new_scol(F.lit(False)) + if isinstance(right, (list, tuple)): from pyspark.pandas.series import first_series, scol_for from pyspark.pandas.frame import DataFrame @@ -486,25 +521,6 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: else: from pyspark.pandas.base import column_op - if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): - from pyspark.pandas.base import IndexOpsMixin - - def are_both_numeric(left_dtype, right_dtype): - return pd.api.types.is_numeric_dtype( - left_dtype - ) and pd.api.types.is_numeric_dtype(right_dtype) - - left = transform_boolean_operand_to_numeric(left, spark_type=LongType()) - right = transform_boolean_operand_to_numeric(right, spark_type=LongType()) - left_dtype = left.dtype - if isinstance(right, (IndexOpsMixin, np.ndarray)): - right_dtype = right.dtype - else: - right_dtype = pd.Series([right]).dtype - - if left_dtype != right_dtype and not are_both_numeric(left_dtype, right_dtype): - return left._with_new_scol(F.lit(False)) - return column_op(PySparkColumn.__eq__)(left, right) def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index 508b1f4984ba6..b99900cd61518 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -41,6 +41,7 @@ _sanitize_list_like, _is_valid_for_logical_operator, _is_boolean_type, + _should_return_all_false, ) from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type from pyspark.pandas.utils import is_ansi_mode_enabled @@ -177,9 +178,14 @@ def abs(self, operand: IndexOpsLike) -> IndexOpsLike: def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: # We can directly use `super().eq` when given object is list, tuple, dict or set. + if not isinstance(right, IndexOpsMixin) and is_list_like(right): return super().eq(left, right) - return pyspark_column_op("__eq__", left, right, fillna=False) + else: + if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): + if _should_return_all_false(left, right): + return left._with_new_scol(F.lit(False)) + return pyspark_column_op("__eq__", left, right, fillna=False) def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) From 9eca812cd27635fb13c73d79dd048e9b6b099977 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 21 Jul 2025 17:16:21 -0700 Subject: [PATCH 03/10] num <-> bool --- python/pyspark/pandas/data_type_ops/base.py | 7 +++++++ python/pyspark/pandas/data_type_ops/num_ops.py | 6 +++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index 146a0d26c6daf..8247d936beb06 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -52,6 +52,7 @@ extension_float_dtypes_available, extension_object_dtypes_available, spark_type_to_pandas_dtype, + as_spark_type, ) from pyspark.pandas.utils import is_ansi_mode_enabled @@ -521,6 +522,12 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: else: from pyspark.pandas.base import column_op + if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): + if _is_boolean_type(left): + left = transform_boolean_operand_to_numeric( + left, spark_type=as_spark_type(right.dtype) + ) + return column_op(PySparkColumn.__eq__)(left, right) def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index b99900cd61518..dca880c7115c3 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -43,7 +43,7 @@ _is_boolean_type, _should_return_all_false, ) -from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type +from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type, as_spark_type from pyspark.pandas.utils import is_ansi_mode_enabled from pyspark.sql import functions as F, Column as PySparkColumn from pyspark.sql.types import ( @@ -185,6 +185,10 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): if _should_return_all_false(left, right): return left._with_new_scol(F.lit(False)) + if _is_boolean_type(right): + right = transform_boolean_operand_to_numeric( + right, spark_type=as_spark_type(left.dtype) + ) return pyspark_column_op("__eq__", left, right, fillna=False) def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: From 1e1b190e6211e200aa685be1e715155d09f4ef82 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 22 Jul 2025 11:16:57 -0700 Subject: [PATCH 04/10] fix lint --- python/pyspark/pandas/data_type_ops/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index 8247d936beb06..237f35c474a52 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -35,7 +35,6 @@ DecimalType, FractionalType, IntegralType, - LongType, MapType, NullType, NumericType, From 226312f478862587adb75ab7b06f879320a5f797 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 22 Jul 2025 11:21:10 -0700 Subject: [PATCH 05/10] docstr --- python/pyspark/pandas/data_type_ops/base.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index 237f35c474a52..b040bd7885085 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -113,15 +113,7 @@ def transform_boolean_operand_to_numeric( def _should_return_all_false(left: IndexOpsLike, right: Any) -> bool: """ Determine if binary comparison should short-circuit to all False, - based on incompatible dtypes. - - This function is used to mimic pandas behavior when comparing operands - with non-matching dtypes that cannot be reasonably coerced, such as - comparing floats with strings. - - It internally transforms boolean operands to numeric (long) and checks - whether both operands are numeric or not. If they are not, and their - dtypes differ, the comparison result is considered to be all False. + based on incompatible dtypes: non-numeric vs. numeric (including bools). """ from pandas.api.types import is_numeric_dtype from pyspark.pandas.base import IndexOpsMixin From eac2321e11cee4119c8b0f00a4db112fbcf05a29 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 22 Jul 2025 11:34:42 -0700 Subject: [PATCH 06/10] fix + test --- python/pyspark/pandas/data_type_ops/base.py | 6 ++---- python/pyspark/pandas/data_type_ops/num_ops.py | 2 +- .../pandas/tests/data_type_ops/test_num_ops.py | 13 +++++++++++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index b040bd7885085..ab3638cd24a68 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -417,16 +417,14 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: raise TypeError(">= can not be applied to %s." % self.pretty_name) def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: - from pyspark.pandas.internal import InternalField - if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): if _should_return_all_false(left, right): - return left._with_new_scol(F.lit(False)) + return left._with_new_scol(F.lit(False)).rename(None) if isinstance(right, (list, tuple)): from pyspark.pandas.series import first_series, scol_for from pyspark.pandas.frame import DataFrame - from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME + from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, InternalField if len(left) != len(right): raise ValueError("Lengths must be equal") diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index dca880c7115c3..94ce01b361674 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -184,7 +184,7 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: else: if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): if _should_return_all_false(left, right): - return left._with_new_scol(F.lit(False)) + return left._with_new_scol(F.lit(False)).rename(None) if _is_boolean_type(right): right = transform_boolean_operand_to_numeric( right, spark_type=as_spark_type(left.dtype) diff --git a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py index 03a794771a910..00fc04e362312 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py @@ -23,6 +23,7 @@ from pyspark import pandas as ps from pyspark.pandas.config import option_context from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.utils import is_ansi_mode_test from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase from pyspark.pandas.typedef.typehints import ( extension_dtypes_available, @@ -128,6 +129,18 @@ def test_invert(self): else: self.assertRaises(TypeError, lambda: ~psser) + def test_comparison_dtype_compatibility(self): + pdf = pd.DataFrame( + {"int": [1, 2], "bool": [True, False], "float": [0.1, 0.2], "str": ["1", "2"]} + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf["int"] == pdf["bool"], psdf["int"] == psdf["bool"]) + self.assert_eq(pdf["bool"] == pdf["int"], psdf["bool"] == psdf["int"]) + self.assert_eq(pdf["int"] == pdf["float"], psdf["int"] == psdf["float"]) + if is_ansi_mode_test: # TODO: match non-ansi behavior with pandas + self.assert_eq(pdf["int"] == pdf["str"], psdf["int"] == psdf["str"]) + self.assert_eq(pdf["float"] == pdf["bool"], psdf["float"] == psdf["bool"]) + def test_eq(self): pdf, psdf = self.pdf, self.psdf for col in self.numeric_df_cols: From 69d9532c44e3ee30fafaaf17995471ac84f050dc Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 22 Jul 2025 16:59:22 -0700 Subject: [PATCH 07/10] lint --- python/pyspark/pandas/data_type_ops/base.py | 6 +++--- python/pyspark/pandas/data_type_ops/num_ops.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index ab3638cd24a68..88d17ba9ae103 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -115,10 +115,10 @@ def _should_return_all_false(left: IndexOpsLike, right: Any) -> bool: Determine if binary comparison should short-circuit to all False, based on incompatible dtypes: non-numeric vs. numeric (including bools). """ - from pandas.api.types import is_numeric_dtype + from pandas.core.dtypes.common import is_numeric_dtype from pyspark.pandas.base import IndexOpsMixin - def are_both_numeric(left_dtype, right_dtype) -> bool: + def are_both_numeric(left_dtype: Dtype, right_dtype: Dtype) -> bool: return is_numeric_dtype(left_dtype) and is_numeric_dtype(right_dtype) left_dtype = left.dtype @@ -419,7 +419,7 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): if _should_return_all_false(left, right): - return left._with_new_scol(F.lit(False)).rename(None) + return left._with_new_scol(F.lit(False)).rename(None) # type: ignore[attr-defined] if isinstance(right, (list, tuple)): from pyspark.pandas.series import first_series, scol_for diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index 94ce01b361674..f68515984782a 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -184,7 +184,7 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: else: if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): if _should_return_all_false(left, right): - return left._with_new_scol(F.lit(False)).rename(None) + return left._with_new_scol(F.lit(False)).rename(None) # type: ignore[attr-defined] if _is_boolean_type(right): right = transform_boolean_operand_to_numeric( right, spark_type=as_spark_type(left.dtype) From fd32e05bf5bc092d8e2562c87b3f1d0fea466e14 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 23 Jul 2025 10:56:21 -0700 Subject: [PATCH 08/10] lint --- python/pyspark/pandas/data_type_ops/num_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index f68515984782a..5dbde99611f98 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -184,7 +184,8 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: else: if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): if _should_return_all_false(left, right): - return left._with_new_scol(F.lit(False)).rename(None) # type: ignore[attr-defined] + left_scol = left._with_new_scol(F.lit(False)) + return left_scol.rename(None) # type: ignore[attr-defined] if _is_boolean_type(right): right = transform_boolean_operand_to_numeric( right, spark_type=as_spark_type(left.dtype) From b1fef4636706100ed644e76293d17d56b75549fd Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 23 Jul 2025 11:33:20 -0700 Subject: [PATCH 09/10] ansi transform_boolean_operand_to_numeric for bool == non-bool numeric --- python/pyspark/pandas/data_type_ops/base.py | 19 ++++++++++++++++--- .../pyspark/pandas/data_type_ops/num_ops.py | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index 88d17ba9ae103..5d003e706ddfb 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -23,6 +23,7 @@ import numpy as np import pandas as pd from pandas.api.types import CategoricalDtype +from pandas.core.dtypes.common import is_numeric_dtype from pyspark.sql import functions as F, Column as PySparkColumn from pyspark.sql.types import ( @@ -100,8 +101,18 @@ def transform_boolean_operand_to_numeric( dtype = spark_type_to_pandas_dtype( spark_type, use_extension_dtypes=operand._internal.data_fields[0].is_extension_dtype ) + + if is_ansi_mode_enabled(operand._internal.spark_frame.sparkSession): + casted = ( + F.when(operand.spark.column.isNull(), None) + .otherwise(F.when(operand.spark.column, F.lit(1)).otherwise(F.lit(0))) + .cast(spark_type) + ) + else: + casted = operand.spark.column.cast(spark_type) + return operand._with_new_scol( - operand.spark.column.cast(spark_type), + casted, field=operand._internal.data_fields[0].copy(dtype=dtype, spark_type=spark_type), ) elif isinstance(operand, bool): @@ -115,7 +126,6 @@ def _should_return_all_false(left: IndexOpsLike, right: Any) -> bool: Determine if binary comparison should short-circuit to all False, based on incompatible dtypes: non-numeric vs. numeric (including bools). """ - from pandas.core.dtypes.common import is_numeric_dtype from pyspark.pandas.base import IndexOpsMixin def are_both_numeric(left_dtype: Dtype, right_dtype: Dtype) -> bool: @@ -512,7 +522,10 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: from pyspark.pandas.base import column_op if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession): - if _is_boolean_type(left): + # Handle bool vs. non-bool numeric comparisons + left_is_bool = _is_boolean_type(left) + right_is_non_bool_numeric = is_numeric_dtype(right) and not _is_boolean_type(right) + if left_is_bool and right_is_non_bool_numeric: left = transform_boolean_operand_to_numeric( left, spark_type=as_spark_type(right.dtype) ) diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index 5dbde99611f98..41a4c62d8241f 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -186,7 +186,7 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: if _should_return_all_false(left, right): left_scol = left._with_new_scol(F.lit(False)) return left_scol.rename(None) # type: ignore[attr-defined] - if _is_boolean_type(right): + if _is_boolean_type(right): # numeric vs. bool right = transform_boolean_operand_to_numeric( right, spark_type=as_spark_type(left.dtype) ) From 693e32abe4b2caac4f95aa83ea3b22c273bdbbff Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 23 Jul 2025 14:42:23 -0700 Subject: [PATCH 10/10] fix --- python/pyspark/pandas/data_type_ops/base.py | 23 ++++++++----------- .../pyspark/pandas/data_type_ops/num_ops.py | 2 +- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index 5d003e706ddfb..b133f84f66fb1 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -101,18 +101,8 @@ def transform_boolean_operand_to_numeric( dtype = spark_type_to_pandas_dtype( spark_type, use_extension_dtypes=operand._internal.data_fields[0].is_extension_dtype ) - - if is_ansi_mode_enabled(operand._internal.spark_frame.sparkSession): - casted = ( - F.when(operand.spark.column.isNull(), None) - .otherwise(F.when(operand.spark.column, F.lit(1)).otherwise(F.lit(0))) - .cast(spark_type) - ) - else: - casted = operand.spark.column.cast(spark_type) - return operand._with_new_scol( - casted, + operand.spark.column.cast(spark_type), field=operand._internal.data_fields[0].copy(dtype=dtype, spark_type=spark_type), ) elif isinstance(operand, bool): @@ -526,9 +516,14 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: left_is_bool = _is_boolean_type(left) right_is_non_bool_numeric = is_numeric_dtype(right) and not _is_boolean_type(right) if left_is_bool and right_is_non_bool_numeric: - left = transform_boolean_operand_to_numeric( - left, spark_type=as_spark_type(right.dtype) - ) + if isinstance(right, numbers.Number): + left = transform_boolean_operand_to_numeric( + left, spark_type=as_spark_type(type(right)) + ) + else: + left = transform_boolean_operand_to_numeric( + left, spark_type=right.spark.data_type + ) return column_op(PySparkColumn.__eq__)(left, right) diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index 41a4c62d8241f..4739c1407d9ec 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -188,7 +188,7 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: return left_scol.rename(None) # type: ignore[attr-defined] if _is_boolean_type(right): # numeric vs. bool right = transform_boolean_operand_to_numeric( - right, spark_type=as_spark_type(left.dtype) + right, spark_type=left.spark.data_type ) return pyspark_column_op("__eq__", left, right, fillna=False)