Skip to content

Commit fdee7cb

Browse files
xinrong-meng authored and asl3 committed
[SPARK-52701][PS] Fix float32 type widening in mod with bool under ANSI
### What changes were proposed in this pull request?
Fix float32 type widening in `mod` with bool under ANSI.

### Why are the changes needed?
Ensure pandas on Spark works well with ANSI mode on. Part of https://issues.apache.org/jira/browse/SPARK-52700.

### Does this PR introduce _any_ user-facing change?
Yes. `mod` under ANSI now behaves the same as in pandas.

```py
(dev3.11) spark (mod_dtype) % SPARK_ANSI_SQL_MODE=False ./python/run-tests --python-executables=python3.11 --testnames "pyspark.pandas.tests.data_type_ops.test_num_mod NumModTests.test_mod"
...
Tests passed in 8 seconds
(dev3.11) spark (mod_dtype) % SPARK_ANSI_SQL_MODE=True ./python/run-tests --python-executables=python3.11 --testnames "pyspark.pandas.tests.data_type_ops.test_num_mod NumModTests.test_mod"
...
Tests passed in 7 seconds
```

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#51394 from xinrong-meng/mod_dtype.

Authored-by: Xinrong Meng <xinrong@apache.org>
Signed-off-by: Xinrong Meng <xinrong@apache.org>
1 parent a731384 commit fdee7cb

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

python/pyspark/pandas/data_type_ops/num_ops.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pandas.api.types import ( # type: ignore[attr-defined]
2424
is_bool_dtype,
2525
is_integer_dtype,
26+
is_float_dtype,
2627
CategoricalDtype,
2728
is_list_like,
2829
)
@@ -42,7 +43,7 @@
4243
_is_valid_for_logical_operator,
4344
_is_boolean_type,
4445
)
45-
from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type
46+
from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type, as_spark_type
4647
from pyspark.pandas.utils import is_ansi_mode_enabled
4748
from pyspark.sql import functions as F, Column as PySparkColumn
4849
from pyspark.sql.types import (
@@ -69,6 +70,26 @@ def _non_fractional_astype(
6970
return _as_other_type(index_ops, dtype, spark_type)
7071

7172

73+
def _cast_back_float(
    expr: PySparkColumn, left_dtype: Union[str, type, Dtype], right: Any
) -> PySparkColumn:
    """
    Restore the left operand's float dtype on a float-with-boolean result.

    pandas keeps float32 when a float32 operand is combined with a boolean,
    whereas under ANSI mode Spark implicitly widens float32 to float64. Casting
    the result expression back to the Spark type of ``left_dtype`` keeps pandas
    on Spark aligned with pandas.

    This helper does not check whether ANSI mode is enabled; that decision is
    left to the caller.

    Parameters
    ----------
    expr : PySparkColumn
        The Spark column expression holding the arithmetic result.
    left_dtype : str, type or Dtype
        The dtype of the left operand; the cast target when a cast is applied.
    right : Any
        The right operand as given by the user: a Python ``bool``, or any
        object exposing a ``dtype`` attribute.

    Returns
    -------
    PySparkColumn
        ``expr`` cast to the Spark type of ``left_dtype`` when the left operand
        is a float and the right operand is boolean; otherwise ``expr``
        unchanged.
    """
    left_is_float = is_float_dtype(left_dtype)

    # A plain Python bool has no ``dtype``; anything else qualifies only when
    # it exposes a boolean ``dtype``.
    if isinstance(right, bool):
        right_is_bool = True
    else:
        right_is_bool = hasattr(right, "dtype") and is_bool_dtype(right.dtype)

    if not (left_is_float and right_is_bool):
        return expr
    return expr.cast(as_spark_type(left_dtype))
91+
92+
7293
class NumericOps(DataTypeOps):
7394
"""The class for binary operations of numeric pandas-on-Spark objects."""
7495

@@ -98,16 +119,19 @@ def mod(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
98119
raise TypeError("Modulo can not be applied to given types.")
99120
spark_session = left._internal.spark_frame.sparkSession
100121

101-
def mod(left: PySparkColumn, right: Any) -> PySparkColumn:
122+
def mod(left_op: PySparkColumn, right_op: Any) -> PySparkColumn:
102123
if is_ansi_mode_enabled(spark_session):
103-
return F.when(F.lit(right == 0), F.lit(None)).otherwise(
104-
((left % right) + right) % right
124+
expr = F.when(F.lit(right_op == 0), F.lit(None)).otherwise(
125+
((left_op % right_op) + right_op) % right_op
105126
)
127+
expr = _cast_back_float(expr, left.dtype, right)
106128
else:
107-
return ((left % right) + right) % right
129+
expr = ((left_op % right_op) + right_op) % right_op
130+
return expr
108131

109-
right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
110-
return column_op(mod)(left, right)
132+
new_right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
133+
134+
return column_op(mod)(left, new_right)
111135

112136
def pow(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
113137
_sanitize_list_like(right)

python/pyspark/pandas/tests/data_type_ops/test_num_mod.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def float_pser(self):
3535
def float_psser(self):
3636
return ps.from_pandas(self.float_pser)
3737

38-
@unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
3938
def test_mod(self):
4039
pdf, psdf = self.pdf, self.psdf
4140
for col in self.numeric_df_cols:

0 commit comments

Comments
 (0)