diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 003fdc0a00b54..e4175707aecd7 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -299,6 +299,7 @@ Date and Timestamp Functions timestamp_micros timestamp_millis timestamp_seconds + time_diff time_trunc to_date to_time diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index aee4a7572a35b..2668b7a526fdd 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -3650,6 +3650,13 @@ def timestamp_seconds(col: "ColumnOrName") -> Column: timestamp_seconds.__doc__ = pysparkfuncs.timestamp_seconds.__doc__ +def time_diff(unit: "ColumnOrName", start: "ColumnOrName", end: "ColumnOrName") -> Column: + return _invoke_function_over_columns("time_diff", unit, start, end) + + +time_diff.__doc__ = pysparkfuncs.time_diff.__doc__ + + def time_trunc(unit: "ColumnOrName", time: "ColumnOrName") -> Column: return _invoke_function_over_columns("time_trunc", unit, time) diff --git a/python/pyspark/sql/functions/__init__.py b/python/pyspark/sql/functions/__init__.py index 7c3f4cbc1a4ff..e1b320c98f7fe 100644 --- a/python/pyspark/sql/functions/__init__.py +++ b/python/pyspark/sql/functions/__init__.py @@ -248,6 +248,7 @@ "timestamp_micros", "timestamp_millis", "timestamp_seconds", + "time_diff", "time_trunc", "to_date", "to_time", diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 0dd0aea7bced7..24baace54621b 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -12710,6 +12710,49 @@ def timestamp_seconds(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("timestamp_seconds", col) +@_try_remote_functions +def time_diff(unit: "ColumnOrName", start: 
"ColumnOrName", end: "ColumnOrName") -> Column: + """ + Returns the difference between two times, measured in specified units. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + unit : :class:`~pyspark.sql.Column` or column name + The unit to measure the difference in. Supported units are: "HOUR", "MINUTE", "SECOND", + "MILLISECOND", and "MICROSECOND". The unit is case-insensitive. + start : :class:`~pyspark.sql.Column` or column name + A starting time. + end : :class:`~pyspark.sql.Column` or column name + An ending time. + + Returns + ------- + :class:`~pyspark.sql.Column` + The difference between two times, in the specified units. + + See Also + -------- + :meth:`pyspark.sql.functions.date_diff` + :meth:`pyspark.sql.functions.timestamp_diff` + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame( + ... [("HOUR", "13:08:15", "21:30:28")], ['unit', 'start', 'end']).withColumn("start", + ... sf.col("start").cast("time")).withColumn("end", sf.col("end").cast("time")) + >>> df.select('*', sf.time_diff('unit', 'start', 'end')).show() + +----+--------+--------+---------------------------+ + |unit| start| end|time_diff(unit, start, end)| + +----+--------+--------+---------------------------+ + |HOUR|13:08:15|21:30:28| 8| + +----+--------+--------+---------------------------+ + """ + return _invoke_function_over_columns("time_diff", unit, start, end) + + @_try_remote_functions def time_trunc(unit: "ColumnOrName", time: "ColumnOrName") -> Column: """ diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 91e519c6f8c77..41c07a61eb1e3 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -81,10 +81,7 @@ def test_function_parity(self): missing_in_py = jvm_fn_set.difference(py_fn_set) # Functions that we expect to be missing in python until they are added to pyspark - expected_missing_in_py = set( - # TODO(SPARK-53108):
Implement the time_diff function in Python - ["time_diff"] - ) + expected_missing_in_py = set() self.assertEqual( expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected" @@ -403,6 +400,19 @@ def test_rand_functions(self): rndn2 = df.select("key", F.randn(0)).collect() self.assertEqual(sorted(rndn1), sorted(rndn2)) + def test_time_diff(self): + # SPARK-53111: test the time_diff function. + df = self.spark.range(1).select( + F.lit("hour").alias("unit"), + F.lit(datetime.time(20, 30, 29)).alias("start"), + F.lit(datetime.time(21, 30, 29)).alias("end"), + ) + result = 1 + row_from_col = df.select(F.time_diff(df.unit, df.start, df.end)).first() + self.assertEqual(row_from_col[0], result) + row_from_name = df.select(F.time_diff("unit", "start", "end")).first() + self.assertEqual(row_from_name[0], result) + def test_time_trunc(self): # SPARK-53110: test the time_trunc function. df = self.spark.range(1).select(