Skip to content

Commit 929027c

Browse files
xinrong-meng and yhuang-db
authored and committed
[SPARK-52288][PS] Avoid INVALID_ARRAY_INDEX in split/rsplit when ANSI mode is on
### What changes were proposed in this pull request? Avoid INVALID_ARRAY_INDEX in `split`/`rsplit` when ANSI mode is on ### Why are the changes needed? Ensure pandas on Spark works well with ANSI mode on. Part of https://issues.apache.org/jira/browse/SPARK-52169. ### Does this PR introduce _any_ user-facing change? Yes. INVALID_ARRAY_INDEX no longer fails `split`/`rsplit` when ANSI mode is on ```py >>> spark.conf.get("spark.sql.ansi.enabled") 'true' >>> import pandas as pd >>> pser = pd.Series(["hello-world", "short"]) >>> psser = ps.from_pandas(pser) ``` FROM ```py >>> psser.str.split("-", n=1, expand=True) 25/05/28 14:52:10 ERROR Executor: Exception in task 10.0 in stage 2.0 (TID 15) org.apache.spark.SparkArrayIndexOutOfBoundsException: [INVALID_ARRAY_INDEX] The index 1 is out of bounds. The array has 1 elements. Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead. SQLSTATE: 22003 == DataFrame == "__getitem__" was called from <stdin>:1 ... ``` TO ```py >>> psser.str.split("-", n=1, expand=True) 0 1 0 hello world 1 short None ``` ### How was this patch tested? Unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes apache#51006 from xinrong-meng/arr_idx_enable. Authored-by: Xinrong Meng <xinrong@apache.org> Signed-off-by: Takuya Ueshin <ueshin@databricks.com>
1 parent 58a57f0 commit 929027c

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

python/pyspark/pandas/strings.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import numpy as np
3333
import pandas as pd
3434

35+
from pyspark.pandas.utils import is_ansi_mode_enabled
3536
from pyspark.sql.types import StringType, BinaryType, ArrayType, LongType, MapType
3637
from pyspark.sql import functions as F
3738
from pyspark.sql.functions import pandas_udf
@@ -2031,7 +2032,13 @@ def pudf(s: pd.Series) -> pd.Series:
20312032
if expand:
20322033
psdf = psser.to_frame()
20332034
scol = psdf._internal.data_spark_columns[0]
2034-
spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
2035+
spark_session = self._data._internal.spark_frame.sparkSession
2036+
if is_ansi_mode_enabled(spark_session):
2037+
spark_columns = [
2038+
F.try_element_at(scol, F.lit(i + 1)).alias(str(i)) for i in range(n + 1)
2039+
]
2040+
else:
2041+
spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
20352042
column_labels = [(i,) for i in range(n + 1)]
20362043
internal = psdf._internal.with_new_columns(
20372044
spark_columns,
@@ -2178,7 +2185,13 @@ def pudf(s: pd.Series) -> pd.Series:
21782185
if expand:
21792186
psdf = psser.to_frame()
21802187
scol = psdf._internal.data_spark_columns[0]
2181-
spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
2188+
spark_session = self._data._internal.spark_frame.sparkSession
2189+
if is_ansi_mode_enabled(spark_session):
2190+
spark_columns = [
2191+
F.try_element_at(scol, F.lit(i + 1)).alias(str(i)) for i in range(n + 1)
2192+
]
2193+
else:
2194+
spark_columns = [scol[i].alias(str(i)) for i in range(n + 1)]
21822195
column_labels = [(i,) for i in range(n + 1)]
21832196
internal = psdf._internal.with_new_columns(
21842197
spark_columns,

python/pyspark/pandas/tests/series/test_string_ops_adv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from pyspark import pandas as ps
2323
from pyspark.testing.pandasutils import PandasOnSparkTestCase
2424
from pyspark.testing.sqlutils import SQLTestUtils
25-
from pyspark.testing.utils import is_ansi_mode_test, ansi_mode_not_supported_message
2625

2726

2827
class SeriesStringOpsAdvMixin:
@@ -174,7 +173,6 @@ def test_string_slice_replace(self):
174173
self.check_func(lambda x: x.str.slice_replace(stop=2, repl="X"))
175174
self.check_func(lambda x: x.str.slice_replace(start=1, stop=3, repl="X"))
176175

177-
@unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
178176
def test_string_split(self):
179177
self.check_func_on_series(lambda x: repr(x.str.split()), self.pser[:-1])
180178
self.check_func_on_series(lambda x: repr(x.str.split(r"p*")), self.pser[:-1])
@@ -185,7 +183,8 @@ def test_string_split(self):
185183
with self.assertRaises(NotImplementedError):
186184
self.check_func(lambda x: x.str.split(expand=True))
187185

188-
@unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
186+
self.check_func_on_series(lambda x: repr(x.str.split("-", n=1, expand=True)), pser)
187+
189188
def test_string_rsplit(self):
190189
self.check_func_on_series(lambda x: repr(x.str.rsplit()), self.pser[:-1])
191190
self.check_func_on_series(lambda x: repr(x.str.rsplit(r"p*")), self.pser[:-1])
@@ -196,6 +195,8 @@ def test_string_rsplit(self):
196195
with self.assertRaises(NotImplementedError):
197196
self.check_func(lambda x: x.str.rsplit(expand=True))
198197

198+
self.check_func_on_series(lambda x: repr(x.str.rsplit("-", n=1, expand=True)), pser)
199+
199200
def test_string_translate(self):
200201
m = str.maketrans({"a": "X", "e": "Y", "i": None})
201202
self.check_func(lambda x: x.str.translate(m))

0 commit comments

Comments (0)