Skip to content

[WIP][SPARK-52622][PS] Avoid CAST_INVALID_INPUT of DataFrame.melt in ANSI mode #51326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion python/pyspark/pandas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
TimestampType,
TimestampNTZType,
NullType,
LongType,
)
from pyspark.sql.window import Window
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
Expand Down Expand Up @@ -126,6 +127,7 @@
validate_mode,
verify_temp_column_name,
log_advice,
infer_common_type,
)
from pyspark.pandas.generic import Frame
from pyspark.pandas.internal import (
Expand Down Expand Up @@ -10617,12 +10619,27 @@ def melt(
else:
var_name = [var_name] # type: ignore[list-item]

value_col_names = [
name_like_string(label) for label in column_labels if label in value_vars
]
value_col_types = [
self._internal.spark_frame.schema[col].dataType for col in value_col_names
]
common_type = infer_common_type(value_col_types)
use_cast = is_ansi_mode_enabled(self._internal.spark_frame.sql_ctx.sparkSession)

pairs = F.explode(
F.array(
*[
F.struct(
*[F.lit(c).alias(name) for c, name in zip(label, var_name)],
*[self._internal.spark_column_for(label).alias(value_name)],
*[
(
self._internal.spark_column_for(label).cast(common_type)
if use_cast
else self._internal.spark_column_for(label)
).alias(value_name)
],
)
for label in column_labels
if label in value_vars
Expand Down
26 changes: 25 additions & 1 deletion python/pyspark/pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from pandas.api.types import is_list_like # type: ignore[attr-defined]

from pyspark.sql import functions as F, Column, DataFrame as PySparkDataFrame, SparkSession
from pyspark.sql.types import DoubleType
from pyspark.sql.types import *
from pyspark.sql.utils import is_remote
from pyspark.errors import PySparkTypeError, UnsupportedOperationException
from pyspark import pandas as ps # noqa: F401
Expand Down Expand Up @@ -1096,6 +1096,30 @@ def is_ansi_mode_enabled(spark: SparkSession) -> bool:
)


def infer_common_type(types: list[DataType]) -> DataType:
    """
    Infer a single Spark type that every type in ``types`` can be safely cast to.

    Intended for ``DataFrame.melt`` under ANSI mode: it picks an explicit cast
    target for the value column so the ANSI runtime does not raise
    ``CAST_INVALID_INPUT`` on an implicit coercion.

    ``NullType`` entries are ignored when computing the common type, since a
    null value can be cast to any type.

    Parameters
    ----------
    types : list of DataType
        Spark data types of the columns being combined.

    Returns
    -------
    DataType
        * ``NullType`` if every input is ``NullType`` (or ``types`` is empty).
        * ``BooleanType`` if every non-null input is boolean.
        * The widest numeric type, following Spark SQL's numeric precedence
          ``Byte < Short < Integer < Long < Float < Double``, if every
          non-null input is numeric.
        * ``StringType`` otherwise — every value has a textual representation,
          so this fallback never fails under ANSI casting.
    """
    # Nulls are compatible with any target, so exclude them up front instead
    # of special-casing NullType inside every check below.
    non_null = [t for t in types if not isinstance(t, NullType)]
    if not non_null:
        # All-null (or empty) input: the common type is null itself.
        return NullType()

    if all(isinstance(t, BooleanType) for t in non_null):
        return BooleanType()

    # Spark SQL's numeric precedence order (TypeCoercion.numericPrecedence):
    # the common type of mixed numerics is the widest member, i.e. the one
    # appearing latest in this sequence. This fixes two former fall-through
    # bugs: Byte+Short no longer degrades to String, and Int+Long widens to
    # Long rather than jumping to Double.
    numeric_precedence = (ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
    ranks = []
    for t in non_null:
        rank = next(
            (i for i, cls in enumerate(numeric_precedence) if isinstance(t, cls)), None
        )
        if rank is None:
            break  # non-numeric type present; use the string fallback
        ranks.append(rank)
    else:
        return numeric_precedence[max(ranks)]()

    # Mixed or unsupported types (dates, structs, ...): fall back to string,
    # which can represent any value without a cast error.
    return StringType()


def _test() -> None:
import os
import doctest
Expand Down