diff --git a/python/docs/source/reference/pyspark.sql/column.rst b/python/docs/source/reference/pyspark.sql/column.rst
index 208e7de1e26b6..7a3b42140d5df 100644
--- a/python/docs/source/reference/pyspark.sql/column.rst
+++ b/python/docs/source/reference/pyspark.sql/column.rst
@@ -58,6 +58,7 @@ Column
     Column.rlike
     Column.startswith
     Column.substr
+    Column.transform
     Column.try_cast
     Column.when
     Column.withField
diff --git a/python/pyspark/sql/classic/column.py b/python/pyspark/sql/classic/column.py
index fef65bcb5d54e..a0e1326f45e99 100644
--- a/python/pyspark/sql/classic/column.py
+++ b/python/pyspark/sql/classic/column.py
@@ -612,6 +612,9 @@ def over(self, window: "WindowSpec") -> ParentColumn:
         jc = self._jc.over(window._jspec)
         return Column(jc)
 
+    def transform(self, f: Callable[[ParentColumn], ParentColumn]) -> ParentColumn:
+        return f(self)
+
     def outer(self) -> ParentColumn:
         jc = self._jc.outer()
         return Column(jc)
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 0aa8d359308dd..628fd9923e7f0 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -21,6 +21,7 @@
 from typing import (
     overload,
     Any,
+    Callable,
     TYPE_CHECKING,
     Union,
 )
@@ -1538,6 +1539,58 @@ def over(self, window: "WindowSpec") -> "Column":
         """
         ...
 
+    @dispatch_col_method
+    def transform(self, f: Callable[["Column"], "Column"]) -> "Column":
+        """
+        Applies a transformation function to this column.
+
+        This method allows you to apply a function that takes a Column and returns a Column,
+        enabling method chaining and functional transformations.
+
+        .. versionadded:: 4.1.0
+
+        Parameters
+        ----------
+        f : callable
+            A function that takes a :class:`Column` and returns a :class:`Column`.
+
+        Returns
+        -------
+        :class:`Column`
+            The result of applying the function to this column.
+
+        Examples
+        --------
+        Example 1: Chain built-in functions
+
+        >>> from pyspark.sql.functions import trim, upper
+        >>> df = spark.createDataFrame([(" hello ",), (" world ",)], ["text"])
+        >>> df.select(df.text.transform(trim).transform(upper).alias("result")).show()
+        +------+
+        |result|
+        +------+
+        | HELLO|
+        | WORLD|
+        +------+
+
+        Example 2: Use lambda functions
+
+        >>> df = spark.createDataFrame([(10,), (20,), (30,)], ["value"])
+        >>> df.select(
+        ...     df.value.transform(lambda c: c + 5)
+        ...     .transform(lambda c: c * 2)
+        ...     .transform(lambda c: c - 10).alias("result")
+        ... ).show()
+        +------+
+        |result|
+        +------+
+        |    20|
+        |    40|
+        |    60|
+        +------+
+        """
+        ...
+
     @dispatch_col_method
     def outer(self) -> "Column":
         """
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 78960d9795220..93c85e1b095d1 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -25,6 +25,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     Union,
     Optional,
     Tuple,
@@ -469,6 +470,9 @@ def over(self, window: "WindowSpec") -> ParentColumn:  # type: ignore[override]
 
         return Column(WindowExpression(windowFunction=self._expr, windowSpec=window))
 
+    def transform(self, f: Callable[[ParentColumn], ParentColumn]) -> ParentColumn:
+        return f(self)
+
     def outer(self) -> ParentColumn:
         return Column(self._expr)
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 8983d45d42d14..763d42b104829 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -1044,6 +1044,33 @@ def test_cast_default_column_name(self):
         )
         self.assertEqual(cdf.columns, sdf.columns)
 
+    def test_transform(self):
+        # Test with built-in functions
+        cdf = self.connect.createDataFrame([(" hello ",), (" world ",)], ["text"])
+        sdf = self.spark.createDataFrame([(" hello ",), (" world ",)], ["text"])
+
+        self.assert_eq(
+            cdf.select(cdf.text.transform(CF.trim).transform(CF.upper)).toPandas(),
+            sdf.select(sdf.text.transform(SF.trim).transform(SF.upper)).toPandas(),
+        )
+
+        # Test with lambda functions
+        cdf = self.connect.createDataFrame([(10,), (20,), (30,)], ["value"])
+        sdf = self.spark.createDataFrame([(10,), (20,), (30,)], ["value"])
+
+        self.assert_eq(
+            cdf.select(
+                cdf.value.transform(lambda c: c + 5)
+                .transform(lambda c: c * 2)
+                .transform(lambda c: c - 10)
+            ).toPandas(),
+            sdf.select(
+                sdf.value.transform(lambda c: c + 5)
+                .transform(lambda c: c * 2)
+                .transform(lambda c: c - 10)
+            ).toPandas(),
+        )
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index ae9010fbc6d4f..ad1102374c923 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -468,6 +468,24 @@ class StrEnum(Enum):
         for r, c, e in zip(result, cols, expected):
             self.assertEqual(r, e, str(c))
 
+    def test_transform(self):
+        # Test with built-in functions
+        df = self.spark.createDataFrame([(" hello ",), (" world ",)], ["text"])
+        result = df.select(df.text.transform(sf.trim).transform(sf.upper)).collect()
+        self.assertEqual(result[0][0], "HELLO")
+        self.assertEqual(result[1][0], "WORLD")
+
+        # Test with lambda functions
+        df = self.spark.createDataFrame([(10,), (20,), (30,)], ["value"])
+        result = df.select(
+            df.value.transform(lambda c: c + 5)
+            .transform(lambda c: c * 2)
+            .transform(lambda c: c - 10)
+        ).collect()
+        self.assertEqual(result[0][0], 20)
+        self.assertEqual(result[1][0], 40)
+        self.assertEqual(result[2][0], 60)
+
 
 class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase):
     pass
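
Both implementations reduce `transform` to plain function application (`return f(self)`), so the chained style is pure sugar for nested calls: both spellings build the same `Column` expression and add no runtime overhead. Below is a minimal sketch of that equivalence, assuming a build of Spark that contains this change and a local `SparkSession`; the session setup is illustrative, not part of the patch.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import trim, upper

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(" hello ",)], ["text"])

# Chained form introduced by this patch: reads left to right.
chained = df.text.transform(trim).transform(upper)

# Equivalent nested form: transform(f) simply returns f(self),
# so both produce the same Column expression.
nested = upper(trim(df.text))

df.select(chained.alias("a"), nested.alias("b")).show()
# Expected:
# +-----+-----+
# |    a|    b|
# +-----+-----+
# |HELLO|HELLO|
# +-----+-----+
```

This mirrors the pipe-style `DataFrame.transform` that PySpark already ships for whole-DataFrame pipelines, extending the same chaining idiom to individual columns.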