From 243d535cc5bc8a804710959dbe24e8a6e7e1559f Mon Sep 17 00:00:00 2001
From: Tai Le Manh
Date: Wed, 21 May 2025 08:46:41 +0700
Subject: [PATCH 1/3] Implement Spark-compatible factorial function

Signed-off-by: Tai Le Manh
---
 README.md | 7 +-
 .../spark/src/function/math/factorial.rs | 199 ++++++++++++++++++
 datafusion/spark/src/function/math/mod.rs | 9 +-
 .../test_files/spark/math/factorial.slt | 29 +++
 4 files changed, 239 insertions(+), 5 deletions(-)
 create mode 100644 datafusion/spark/src/function/math/factorial.rs
 create mode 100644 datafusion/sqllogictest/test_files/spark/math/factorial.slt

diff --git a/README.md b/README.md
index c142d8f366b2..3ad11176cc28 100644
--- a/README.md
+++ b/README.md
@@ -65,8 +65,7 @@ See [use cases] for examples. The following related subprojects target end users
 
 - [DataFusion Comet](https://github.com/apache/datafusion-comet/) is an accelerator for Apache Spark based on DataFusion.
 
-"Out of the box,"
-DataFusion offers [SQL] and [`Dataframe`] APIs, excellent [performance],
+"Out of the box," DataFusion offers [`SQL`] and [`Dataframe`] APIs, excellent [performance],
 built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great
 community.
 
@@ -96,7 +95,7 @@ Here are links to some important information
 
 ## What can you do with this crate?
 
-DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more.
+DataFusion is great for building projects such as domain-specific query engines, new database platforms and data pipelines, query languages and more.
 It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://datafusion.apache.org/user-guide/introduction.html#known-users) to see a list known users.
 
 ## Contributing to DataFusion
@@ -123,7 +122,7 @@ Default features:
 - `regex_expressions`: regular expression functions, such as `regexp_match`
 - `unicode_expressions`: Include unicode aware functions such as `character_length`
 - `unparser`: enables support to reverse LogicalPlans back into SQL
-- `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection.
+- `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection
 
 Optional features:
 
diff --git a/datafusion/spark/src/function/math/factorial.rs b/datafusion/spark/src/function/math/factorial.rs
new file mode 100644
index 000000000000..80648387e893
--- /dev/null
+++ b/datafusion/spark/src/function/math/factorial.rs
@@ -0,0 +1,199 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use crate::function::error_utils::{
+    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
+};
+use arrow::array::{Array, Int64Array};
+use arrow::datatypes::DataType;
+use arrow::datatypes::DataType::Int64;
+use datafusion_common::{
+    cast::as_int64_array, exec_err, DataFusionError, Result, ScalarValue,
+};
+use datafusion_expr::Signature;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
+
+/// Returns the factorial of expr. expr is [0..20]. Otherwise, null.
+#[derive(Debug)]
+pub struct SparkFactorial {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for SparkFactorial {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkFactorial {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::uniform(1, vec![Int64], Volatility::Immutable),
+            aliases: vec![],
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkFactorial {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "factorial"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(Int64)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        spark_factorial(&args.args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        if arg_types.len() != 1 {
+            return Err(invalid_arg_count_exec_err(
+                "factorial",
+                (1, 1),
+                arg_types.len(),
+            ));
+        }
+        match &arg_types[0] {
+            Int64 => Ok(vec![arg_types[0].clone()]),
+            _ => Err(unsupported_data_type_exec_err(
+                "factorial",
+                "Integer",
+                &arg_types[0],
+            )),
+        }
+    }
+}
+
+const FACTORIALS: [i64; 21] = [
+    1,
+    1,
+    2,
+    6,
+    24,
+    120,
+    720,
+    5040,
+    40320,
+    362880,
+    3628800,
+    39916800,
+    479001600,
+    6227020800,
+    87178291200,
+    1307674368000,
+    20922789888000,
+    355687428096000,
+    6402373705728000,
+    121645100408832000,
+    2432902008176640000,
+];
+
+pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 1 {
+        return Err(DataFusionError::Internal(
+            "`factorial` expects exactly one argument".to_string(),
+        ));
+    }
+
+    match &args[0] {
+        ColumnarValue::Scalar(ScalarValue::Int64(value)) => {
+            let result = compute_factorial(*value);
+            Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
+        }
+        ColumnarValue::Scalar(other) => {
+            exec_err!("`factorial` got an unexpected scalar type: {:?}", other)
+        }
+        ColumnarValue::Array(array) => match array.data_type() {
+            Int64 => {
+                let array = as_int64_array(array)?;
+
+                let result: Int64Array = array.iter().map(compute_factorial).collect();
+
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            other => {
+                exec_err!("`factorial` got an unexpected argument type: {:?}", other)
+            }
+        },
+    }
+}
+
+#[inline]
+fn compute_factorial(num: Option<i64>) -> Option<i64> {
+    num.filter(|&v| v >= 0 && v <= 20)
+        .map(|v| FACTORIALS[v as usize])
+}
+
+#[cfg(test)]
+mod test {
+    use crate::function::math::factorial::spark_factorial;
+    use arrow::array::Int64Array;
+    use datafusion_common::cast::as_int64_array;
+    use datafusion_expr::ColumnarValue;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_spark_factorial() {
+        let input = Int64Array::from(vec![
+            Some(0),
+            Some(1),
+            Some(2),
+            Some(4),
+            Some(20),
+            Some(21),
+            None,
+        ]);
+
+        let args = ColumnarValue::Array(Arc::new(input));
+        let result = spark_factorial(&[args]).unwrap();
+        let result = match result {
+            ColumnarValue::Array(array) => array,
+            _ => panic!("Expected array"),
+        };
+
+        let actual = as_int64_array(&result).unwrap();
+        let expected = Int64Array::from(vec![
+            Some(1),
+            Some(1),
+            Some(2),
+            Some(24),
+            Some(2432902008176640000),
+            None,
+            None,
+        ]);
+
+        assert_eq!(actual, &expected);
+    }
+}
diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs
index 80bcdc39a41d..1f2a3b1d67f6 100644
--- a/datafusion/spark/src/function/math/mod.rs
+++ b/datafusion/spark/src/function/math/mod.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 pub mod expm1;
+pub mod factorial;
 pub mod hex;
 
 use datafusion_expr::ScalarUDF;
@@ -23,15 +24,21 @@ use datafusion_functions::make_udf_function;
 use std::sync::Arc;
 
 make_udf_function!(expm1::SparkExpm1, expm1);
+make_udf_function!(factorial::SparkFactorial, factorial);
 make_udf_function!(hex::SparkHex, hex);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
 
     export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1));
+    export_functions!((
+        factorial,
+        "Returns the factorial of expr. expr is [0..20]. Otherwise, null.",
+        arg1
+    ));
     export_functions!((hex, "Computes hex value of the given column.", arg1));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![expm1(), hex()]
+    vec![expm1(), factorial(), hex()]
 }
diff --git a/datafusion/sqllogictest/test_files/spark/math/factorial.slt b/datafusion/sqllogictest/test_files/spark/math/factorial.slt
new file mode 100644
index 000000000000..5619d0a11096
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/math/factorial.slt
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+query I
+SELECT factorial(5::INT);
+----
+120
+
+query I
+SELECT factorial(a) from VALUES (0::INT), (20::INT), (21::INT), (NULL) AS t(a);
+----
+1
+2432902008176640000
+NULL
+NULL

From 61cbbd9c84b44116bfdf4ecf35540f8474d3c539 Mon Sep 17 00:00:00 2001
From: Tai Le Manh
Date: Wed, 21 May 2025 09:38:18 +0700
Subject: [PATCH 2/3] Remove clippy warning

---
 README.md | 7 ++++---
 datafusion/spark/src/function/math/factorial.rs | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 3ad11176cc28..c142d8f366b2 100644
--- a/README.md
+++ b/README.md
@@ -65,7 +65,8 @@ See [use cases] for examples. The following related subprojects target end users
 
 - [DataFusion Comet](https://github.com/apache/datafusion-comet/) is an accelerator for Apache Spark based on DataFusion.
 
-"Out of the box," DataFusion offers [`SQL`] and [`Dataframe`] APIs, excellent [performance],
+"Out of the box,"
+DataFusion offers [SQL] and [`Dataframe`] APIs, excellent [performance],
 built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great
 community.
 
@@ -96,7 +96,7 @@ Here are links to some important information
 
 ## What can you do with this crate?
 
-DataFusion is great for building projects such as domain-specific query engines, new database platforms and data pipelines, query languages and more.
+DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more.
 It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://datafusion.apache.org/user-guide/introduction.html#known-users) to see a list known users.
 
 ## Contributing to DataFusion
@@ -122,7 +123,7 @@ Default features:
 - `regex_expressions`: regular expression functions, such as `regexp_match`
 - `unicode_expressions`: Include unicode aware functions such as `character_length`
 - `unparser`: enables support to reverse LogicalPlans back into SQL
-- `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection
+- `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection.
 
 Optional features:
 
diff --git a/datafusion/spark/src/function/math/factorial.rs b/datafusion/spark/src/function/math/factorial.rs
index 80648387e893..623edbe997b1 100644
--- a/datafusion/spark/src/function/math/factorial.rs
+++ b/datafusion/spark/src/function/math/factorial.rs
@@ -152,7 +152,7 @@ pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
 
 #[inline]
 fn compute_factorial(num: Option<i64>) -> Option<i64> {
-    num.filter(|&v| v >= 0 && v <= 20)
+    num.filter(|&v| (0..=20).contains(&v))
         .map(|v| FACTORIALS[v as usize])
 }
 

From 574084dea41eaa9c5adfc5f3d6c095b9f5db0a8f Mon Sep 17 00:00:00 2001
From: Tai Le Manh
Date: Sat, 21 Jun 2025 10:33:27 +0700
Subject: [PATCH 3/3] Add unit tests

---
 .../spark/src/function/math/factorial.rs | 63 +++++++++----------
 .../test_files/spark/math/factorial.slt | 42 ++++++++++++-
 2 files changed, 70 insertions(+), 35 deletions(-)

diff --git a/datafusion/spark/src/function/math/factorial.rs b/datafusion/spark/src/function/math/factorial.rs
index 623edbe997b1..10f7f0696469 100644
--- a/datafusion/spark/src/function/math/factorial.rs
+++ b/datafusion/spark/src/function/math/factorial.rs
@@ -18,15 +18,11 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use crate::function::error_utils::{
-    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
-};
 use arrow::array::{Array, Int64Array};
 use arrow::datatypes::DataType;
-use arrow::datatypes::DataType::Int64;
-use datafusion_common::{
-    cast::as_int64_array, exec_err, DataFusionError, Result, ScalarValue,
-};
+use arrow::datatypes::DataType::{Int32, Int64};
+use datafusion_common::cast::as_int32_array;
+use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
 use datafusion_expr::Signature;
 use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
 
@@ -46,7 +42,7 @@ impl SparkFactorial {
 impl SparkFactorial {
     pub fn new() -> Self {
         Self {
-            signature: Signature::uniform(1, vec![Int64], Volatility::Immutable),
+            signature: Signature::exact(vec![Int32], Volatility::Immutable),
             aliases: vec![],
         }
     }
@@ -76,24 +72,6 @@ impl ScalarUDFImpl for SparkFactorial {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
-
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        if arg_types.len() != 1 {
-            return Err(invalid_arg_count_exec_err(
-                "factorial",
-                (1, 1),
-                arg_types.len(),
-            ));
-        }
-        match &arg_types[0] {
-            Int64 => Ok(vec![arg_types[0].clone()]),
-            _ => Err(unsupported_data_type_exec_err(
-                "factorial",
-                "Integer",
-                &arg_types[0],
-            )),
-        }
-    }
 }
 
 const FACTORIALS: [i64; 21] = [
@@ -128,7 +106,7 @@ pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
     }
 
     match &args[0] {
-        ColumnarValue::Scalar(ScalarValue::Int64(value)) => {
+        ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
             let result = compute_factorial(*value);
             Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
         }
@@ -136,8 +114,8 @@ pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
             exec_err!("`factorial` got an unexpected scalar type: {:?}", other)
         }
         ColumnarValue::Array(array) => match array.data_type() {
-            Int64 => {
-                let array = as_int64_array(array)?;
+            Int32 => {
+                let array = as_int32_array(array)?;
 
                 let result: Int64Array = array.iter().map(compute_factorial).collect();
 
@@ -151,7 +129,7 @@ pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
 }
 
 #[inline]
-fn compute_factorial(num: Option<i64>) -> Option<i64> {
+fn compute_factorial(num: Option<i32>) -> Option<i64> {
     num.filter(|&v| (0..=20).contains(&v))
         .map(|v| FACTORIALS[v as usize])
 }
@@ -159,14 +137,16 @@ fn compute_factorial(num: Option<i64>) -> Option<i64> {
 #[cfg(test)]
 mod test {
     use crate::function::math::factorial::spark_factorial;
-    use arrow::array::Int64Array;
+    use arrow::array::{Int32Array, Int64Array};
     use datafusion_common::cast::as_int64_array;
+    use datafusion_common::ScalarValue;
     use datafusion_expr::ColumnarValue;
     use std::sync::Arc;
 
     #[test]
-    fn test_spark_factorial() {
-        let input = Int64Array::from(vec![
+    fn test_spark_factorial_array() {
+        let input = Int32Array::from(vec![
+            Some(-1),
             Some(0),
             Some(1),
             Some(2),
@@ -185,6 +165,7 @@ mod test {
 
         let actual = as_int64_array(&result).unwrap();
         let expected = Int64Array::from(vec![
+            None,
             Some(1),
             Some(1),
             Some(2),
@@ -196,4 +177,20 @@ mod test {
 
         assert_eq!(actual, &expected);
     }
+
+    #[test]
+    fn test_spark_factorial_scalar() {
+        let input = ScalarValue::Int32(Some(5));
+
+        let args = ColumnarValue::Scalar(input);
+        let result = spark_factorial(&[args]).unwrap();
+        let result = match result {
+            ColumnarValue::Scalar(ScalarValue::Int64(val)) => val,
+            _ => panic!("Expected scalar"),
+        };
+        let actual = result.unwrap();
+        let expected = 120_i64;
+
+        assert_eq!(actual, expected);
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/spark/math/factorial.slt b/datafusion/sqllogictest/test_files/spark/math/factorial.slt
index eecdfcffc8e4..f8eae5d95ab8 100644
--- a/datafusion/sqllogictest/test_files/spark/math/factorial.slt
+++ b/datafusion/sqllogictest/test_files/spark/math/factorial.slt
@@ -23,6 +23,44 @@
 
 ## Original Query: SELECT factorial(5);
 ## PySpark 3.5.5 Result: {'factorial(5)': 120, 'typeof(factorial(5))': 'bigint', 'typeof(5)': 'int'}
 
-#query
-#SELECT factorial(5::int);
+query I
+SELECT factorial(5::INT);
+----
+120
+
+query I
+SELECT factorial(a)
+FROM VALUES
+    (-1::INT),
+    (0::INT), (1::INT), (2::INT), (3::INT), (4::INT), (5::INT), (6::INT), (7::INT), (8::INT), (9::INT), (10::INT),
+    (11::INT), (12::INT), (13::INT), (14::INT), (15::INT), (16::INT), (17::INT), (18::INT), (19::INT), (20::INT),
+    (21::INT),
+    (NULL) AS t(a);
+----
+NULL
+1
+1
+2
+6
+24
+120
+720
+5040
+40320
+362880
+3628800
+39916800
+479001600
+6227020800
+87178291200
+1307674368000
+20922789888000
+355687428096000
+6402373705728000
+121645100408832000
+2432902008176640000
+NULL
+NULL
+
+query error Error during planning: Failed to coerce arguments to satisfy a call to 'factorial' function
+SELECT factorial(5::BIGINT);
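
Reviewer note (not part of the patch series): a minimal, hypothetical sketch of how the new UDF could be exercised end to end after applying all three patches. It assumes the datafusion, datafusion-spark, and tokio crates are available as dependencies and that SparkFactorial is publicly reachable at the import path below; the exact re-export is not shown in this series, so the path may need adjusting.

use datafusion::error::Result;
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::SessionContext;
// Hypothetical import path; the series only shows the module living at
// datafusion/spark/src/function/math/factorial.rs.
use datafusion_spark::function::math::factorial::SparkFactorial;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Register the Spark-compatible factorial UDF with this session.
    ctx.register_udf(ScalarUDF::new_from_impl(SparkFactorial::new()));

    // After patch 3 the signature accepts Int32 exactly, so the literal is
    // cast, mirroring the sqllogictest cases above; factorial(5) returns 120.
    let df = ctx.sql("SELECT factorial(5::INT) AS five_factorial").await?;
    df.show().await?;

    Ok(())
}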