diff --git a/datafusion/spark/src/function/math/factorial.rs b/datafusion/spark/src/function/math/factorial.rs
new file mode 100644
index 000000000000..10f7f0696469
--- /dev/null
+++ b/datafusion/spark/src/function/math/factorial.rs
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{Array, Int64Array};
+use arrow::datatypes::DataType;
+use arrow::datatypes::DataType::{Int32, Int64};
+use datafusion_common::cast::as_int32_array;
+use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion_expr::Signature;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
+
+/// <https://spark.apache.org/docs/latest/api/sql/index.html#factorial>
+#[derive(Debug)]
+pub struct SparkFactorial {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for SparkFactorial {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkFactorial {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::exact(vec![Int32], Volatility::Immutable),
+            aliases: vec![],
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkFactorial {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "factorial"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(Int64)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        spark_factorial(&args.args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+const FACTORIALS: [i64; 21] = [
+    1,
+    1,
+    2,
+    6,
+    24,
+    120,
+    720,
+    5040,
+    40320,
+    362880,
+    3628800,
+    39916800,
+    479001600,
+    6227020800,
+    87178291200,
+    1307674368000,
+    20922789888000,
+    355687428096000,
+    6402373705728000,
+    121645100408832000,
+    2432902008176640000,
+];
+
+pub fn spark_factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 1 {
+        return Err(DataFusionError::Internal(
+            "`factorial` expects exactly one argument".to_string(),
+        ));
+    }
+
+    match &args[0] {
+        ColumnarValue::Scalar(ScalarValue::Int32(value)) => {
+            let result = compute_factorial(*value);
+            Ok(ColumnarValue::Scalar(ScalarValue::Int64(result)))
+        }
+        ColumnarValue::Scalar(other) => {
+            exec_err!("`factorial` got an unexpected scalar type: {:?}", other)
+        }
+        ColumnarValue::Array(array) => match array.data_type() {
+            Int32 => {
+                let array = as_int32_array(array)?;
+
+                let result: Int64Array = array.iter().map(compute_factorial).collect();
+
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            other => {
+                exec_err!("`factorial` got an unexpected argument type: {:?}", other)
+            }
+        },
+    }
+}
+
+#[inline]
+fn compute_factorial(num: Option<i32>) -> Option<i64> {
+    num.filter(|&v| (0..=20).contains(&v))
+        .map(|v| FACTORIALS[v as usize])
+}
+
+#[cfg(test)]
+mod test {
+    use crate::function::math::factorial::spark_factorial;
+    use arrow::array::{Int32Array, Int64Array};
+    use datafusion_common::cast::as_int64_array;
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::ColumnarValue;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_spark_factorial_array() {
+        let input = Int32Array::from(vec![
+            Some(-1),
+            Some(0),
+            Some(1),
+            Some(2),
+            Some(4),
+            Some(20),
+            Some(21),
+            None,
+        ]);
+
+        let args = ColumnarValue::Array(Arc::new(input));
+        let result = spark_factorial(&[args]).unwrap();
+        let result = match result {
+            ColumnarValue::Array(array) => array,
+            _ => panic!("Expected array"),
+        };
+
+        let actual = as_int64_array(&result).unwrap();
+        let expected = Int64Array::from(vec![
+            None,
+            Some(1),
+            Some(1),
+            Some(2),
+            Some(24),
+            Some(2432902008176640000),
+            None,
+            None,
+        ]);
+
+        assert_eq!(actual, &expected);
+    }
+
+    #[test]
+    fn test_spark_factorial_scalar() {
+        let input = ScalarValue::Int32(Some(5));
+
+        let args = ColumnarValue::Scalar(input);
+        let result = spark_factorial(&[args]).unwrap();
+        let result = match result {
+            ColumnarValue::Scalar(ScalarValue::Int64(val)) => val,
+            _ => panic!("Expected scalar"),
+        };
+        let actual = result.unwrap();
+        let expected = 120_i64;
+
+        assert_eq!(actual, expected);
+    }
+}
diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs
index 80bcdc39a41d..1f2a3b1d67f6 100644
--- a/datafusion/spark/src/function/math/mod.rs
+++ b/datafusion/spark/src/function/math/mod.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 pub mod expm1;
+pub mod factorial;
 pub mod hex;
 
 use datafusion_expr::ScalarUDF;
@@ -23,15 +24,21 @@ use datafusion_functions::make_udf_function;
 use std::sync::Arc;
 
 make_udf_function!(expm1::SparkExpm1, expm1);
+make_udf_function!(factorial::SparkFactorial, factorial);
 make_udf_function!(hex::SparkHex, hex);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
 
     export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1));
+    export_functions!((
+        factorial,
+        "Returns the factorial of expr. expr is [0..20]. Otherwise, null.",
+        arg1
+    ));
     export_functions!((hex, "Computes hex value of the given column.", arg1));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![expm1(), hex()]
+    vec![expm1(), factorial(), hex()]
 }
diff --git a/datafusion/sqllogictest/test_files/spark/math/factorial.slt b/datafusion/sqllogictest/test_files/spark/math/factorial.slt
index eecdfcffc8e4..f8eae5d95ab8 100644
--- a/datafusion/sqllogictest/test_files/spark/math/factorial.slt
+++ b/datafusion/sqllogictest/test_files/spark/math/factorial.slt
@@ -23,6 +23,44 @@
 
 ## Original Query: SELECT factorial(5);
 ## PySpark 3.5.5 Result: {'factorial(5)': 120, 'typeof(factorial(5))': 'bigint', 'typeof(5)': 'int'}
-#query
-#SELECT factorial(5::int);
+query I
+SELECT factorial(5::INT);
+----
+120
+
+query I
+SELECT factorial(a)
+FROM VALUES
+    (-1::INT),
+    (0::INT), (1::INT), (2::INT), (3::INT), (4::INT), (5::INT), (6::INT), (7::INT), (8::INT), (9::INT), (10::INT),
+    (11::INT), (12::INT), (13::INT), (14::INT), (15::INT), (16::INT), (17::INT), (18::INT), (19::INT), (20::INT),
+    (21::INT),
+    (NULL) AS t(a);
+----
+NULL
+1
+1
+2
+6
+24
+120
+720
+5040
+40320
+362880
+3628800
+39916800
+479001600
+6227020800
+87178291200
+1307674368000
+20922789888000
+355687428096000
+6402373705728000
+121645100408832000
+2432902008176640000
+NULL
+NULL
+
+query error Error during planning: Failed to coerce arguments to satisfy a call to 'factorial' function
+SELECT factorial(5::BIGINT);