Skip to content

Commit eac2321

Browse files
committed
fix + test
1 parent 226312f commit eac2321

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

python/pyspark/pandas/data_type_ops/base.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -417,16 +417,14 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
417417
raise TypeError(">= can not be applied to %s." % self.pretty_name)
418418

419419
def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
420-
from pyspark.pandas.internal import InternalField
421-
422420
if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession):
423421
if _should_return_all_false(left, right):
424-
return left._with_new_scol(F.lit(False))
422+
return left._with_new_scol(F.lit(False)).rename(None)
425423

426424
if isinstance(right, (list, tuple)):
427425
from pyspark.pandas.series import first_series, scol_for
428426
from pyspark.pandas.frame import DataFrame
429-
from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME
427+
from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, InternalField
430428

431429
if len(left) != len(right):
432430
raise ValueError("Lengths must be equal")

python/pyspark/pandas/data_type_ops/num_ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
184184
else:
185185
if is_ansi_mode_enabled(left._internal.spark_frame.sparkSession):
186186
if _should_return_all_false(left, right):
187-
return left._with_new_scol(F.lit(False))
187+
return left._with_new_scol(F.lit(False)).rename(None)
188188
if _is_boolean_type(right):
189189
right = transform_boolean_operand_to_numeric(
190190
right, spark_type=as_spark_type(left.dtype)

python/pyspark/pandas/tests/data_type_ops/test_num_ops.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from pyspark import pandas as ps
2424
from pyspark.pandas.config import option_context
2525
from pyspark.testing.pandasutils import PandasOnSparkTestCase
26+
from pyspark.testing.utils import is_ansi_mode_test
2627
from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase
2728
from pyspark.pandas.typedef.typehints import (
2829
extension_dtypes_available,
@@ -128,6 +129,18 @@ def test_invert(self):
128129
else:
129130
self.assertRaises(TypeError, lambda: ~psser)
130131

132+
def test_comparison_dtype_compatibility(self):
133+
pdf = pd.DataFrame(
134+
{"int": [1, 2], "bool": [True, False], "float": [0.1, 0.2], "str": ["1", "2"]}
135+
)
136+
psdf = ps.from_pandas(pdf)
137+
self.assert_eq(pdf["int"] == pdf["bool"], psdf["int"] == psdf["bool"])
138+
self.assert_eq(pdf["bool"] == pdf["int"], psdf["bool"] == psdf["int"])
139+
self.assert_eq(pdf["int"] == pdf["float"], psdf["int"] == psdf["float"])
140+
if is_ansi_mode_test: # TODO: match non-ansi behavior with pandas
141+
self.assert_eq(pdf["int"] == pdf["str"], psdf["int"] == psdf["str"])
142+
self.assert_eq(pdf["float"] == pdf["bool"], psdf["float"] == psdf["bool"])
143+
131144
def test_eq(self):
132145
pdf, psdf = self.pdf, self.psdf
133146
for col in self.numeric_df_cols:

0 commit comments

Comments
 (0)