diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md
index be7c911814..6c4364f26c 100644
--- a/docs/spark_expressions_support.md
+++ b/docs/spark_expressions_support.md
@@ -351,6 +351,8 @@
 - [ ] input_file_name
 - [ ] monotonically_increasing_id
 - [ ] raise_error
+- [x] rand
+- [x] randn
 - [ ] spark_partition_id
 - [ ] typeof
 - [x] user
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 327e57bcfd..f76e10199f 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -103,7 +103,7 @@ use datafusion_comet_proto::{
 use datafusion_comet_spark_expr::{
     ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
     GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike,
-    RandExpr, SparkCastOptions, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
+    RandExpr, RandnExpr, SparkCastOptions, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
     TimestampTruncExpr, ToJson, UnboundColumn, Variance,
 };
 use itertools::Itertools;
@@ -791,8 +791,12 @@ impl PhysicalPlanner {
                 )))
             }
             ExprStruct::Rand(expr) => {
-                let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
-                Ok(Arc::new(RandExpr::new(child, self.partition)))
+                let seed = expr.seed.wrapping_add(self.partition.into());
+                Ok(Arc::new(RandExpr::new(seed)))
+            }
+            ExprStruct::Randn(expr) => {
+                let seed = expr.seed.wrapping_add(self.partition.into());
+                Ok(Arc::new(RandnExpr::new(seed)))
             }
             expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
         }
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 8f4c875eec..9f31beffdd 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -80,7 +80,8 @@ message Expr {
     ArrayInsert array_insert = 58;
     MathExpr integral_divide = 59;
     ToPrettyString to_pretty_string = 60;
-    UnaryExpr rand = 61;
+    Rand rand = 61;
+    Rand randn = 62;
   }
 }
@@ -415,6 +416,10 @@ message ArrayJoin {
   Expr null_replacement_expr = 3;
 }

+message Rand {
+  int64 seed = 1;
+}
+
 message DataType {
   enum DataTypeId {
     BOOL = 0;
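A note on the planner change above: the literal seed is now combined with the partition index on the native side, mirroring Spark, which seeds each partition's XORShiftRandom with `seed + partitionIndex`. A minimal sketch of the intended wrapping semantics (the helper name `per_partition_seed` is illustrative, not part of the diff):

```rust
fn per_partition_seed(seed: i64, partition: i32) -> i64 {
    // wrapping_add matches Java's overflowing `+`: rand(i64::MAX) on
    // partition 1 wraps to i64::MIN instead of panicking in debug builds,
    // which is the behaviour the removed test_overflow_shift_seed pinned down.
    seed.wrapping_add(partition as i64)
}

fn main() {
    assert_eq!(per_partition_seed(i64::MAX, 1), i64::MIN);
    assert_eq!(per_partition_seed(42, 3), 45);
}
```
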
diff --git a/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs b/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs
new file mode 100644
index 0000000000..c7437f0667
--- /dev/null
+++ b/native/spark-expr/src/nondetermenistic_funcs/internal/mod.rs
@@ -0,0 +1,21 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod rand_utils;
+
+pub use rand_utils::evaluate_batch_for_rand;
+pub use rand_utils::StatefulSeedValueGenerator;
diff --git a/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs b/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs
new file mode 100644
index 0000000000..9abaaa9396
--- /dev/null
+++ b/native/spark-expr/src/nondetermenistic_funcs/internal/rand_utils.rs
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Float64Array, Float64Builder};
+use datafusion::logical_expr::ColumnarValue;
+use std::ops::Deref;
+use std::sync::{Arc, Mutex};
+
+pub fn evaluate_batch_for_rand<R, S>(
+    state_holder: &Arc<Mutex<Option<S>>>,
+    seed: i64,
+    num_rows: usize,
+) -> datafusion::common::Result<ColumnarValue>
+where
+    R: StatefulSeedValueGenerator<S, f64>,
+    S: Copy,
+{
+    let seed_state = state_holder.lock().unwrap();
+    let mut rnd = R::from_state_ref(seed_state, seed);
+    let mut arr_builder = Float64Builder::with_capacity(num_rows);
+    std::iter::repeat_with(|| rnd.next_value())
+        .take(num_rows)
+        .for_each(|v| arr_builder.append_value(v));
+    let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
+    let mut seed_state = state_holder.lock().unwrap();
+    seed_state.replace(rnd.get_current_state());
+    Ok(ColumnarValue::Array(array_ref))
+}
+
+pub trait StatefulSeedValueGenerator<State: Copy, Value>: Sized {
+    fn from_init_seed(init_seed: i64) -> Self;
+
+    fn from_stored_state(stored_state: State) -> Self;
+
+    fn next_value(&mut self) -> Value;
+
+    fn get_current_state(&self) -> State;
+
+    fn from_state_ref(state: impl Deref<Target = Option<State>>, init_value: i64) -> Self {
+        if state.is_none() {
+            Self::from_init_seed(init_value)
+        } else {
+            Self::from_stored_state(state.unwrap())
+        }
+    }
+}
diff --git a/native/spark-expr/src/nondetermenistic_funcs/mod.rs b/native/spark-expr/src/nondetermenistic_funcs/mod.rs
index c5ff894e8e..94774acd51 100644
--- a/native/spark-expr/src/nondetermenistic_funcs/mod.rs
+++ b/native/spark-expr/src/nondetermenistic_funcs/mod.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.

+pub mod internal;
 pub mod rand;
+pub mod randn;

 pub use rand::RandExpr;
+pub use randn::RandnExpr;
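The `rand_utils` helper above threads generator state through a `Mutex<Option<S>>` so that one logical random stream spans consecutive record batches: the generator is rebuilt from the stored state (or from the initial seed on the first batch), and its final state is written back after each batch. A self-contained sketch of that contract, with the trait repeated in simplified form (without the `Deref`-based helper) so it compiles standalone, and a toy LCG standing in for Comet's Spark-compatible XorShift:

```rust
trait StatefulSeedValueGenerator<State: Copy, Value>: Sized {
    fn from_init_seed(init_seed: i64) -> Self;
    fn from_stored_state(stored_state: State) -> Self;
    fn next_value(&mut self) -> Value;
    fn get_current_state(&self) -> State;
}

// Toy linear-congruential generator; NOT the Spark-compatible XorShiftRandom.
struct TinyLcg {
    state: i64,
}

impl StatefulSeedValueGenerator<i64, f64> for TinyLcg {
    fn from_init_seed(init_seed: i64) -> Self {
        TinyLcg { state: init_seed }
    }
    fn from_stored_state(stored_state: i64) -> Self {
        TinyLcg { state: stored_state }
    }
    fn next_value(&mut self) -> f64 {
        self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (self.state as u64 >> 11) as f64 / (1u64 << 53) as f64
    }
    fn get_current_state(&self) -> i64 {
        self.state
    }
}

// Simplified stand-in for evaluate_batch_for_rand, minus Arrow and Mutex.
fn evaluate_batch<R: StatefulSeedValueGenerator<i64, f64>>(
    state: &mut Option<i64>,
    seed: i64,
    num_rows: usize,
) -> Vec<f64> {
    let mut rng = match *state {
        None => R::from_init_seed(seed),
        Some(s) => R::from_stored_state(s),
    };
    let values: Vec<f64> = (0..num_rows).map(|_| rng.next_value()).collect();
    *state = Some(rng.get_current_state()); // carried over to the next batch
    values
}

fn main() {
    let (mut split, mut whole) = (None, None);
    let batches = [
        evaluate_batch::<TinyLcg>(&mut split, 42, 2),
        evaluate_batch::<TinyLcg>(&mut split, 42, 3),
    ]
    .concat();
    // Two small batches produce exactly the same stream as one large batch.
    assert_eq!(batches, evaluate_batch::<TinyLcg>(&mut whole, 42, 5));
}
```

The split-versus-whole assertion is exactly the property the multi-batch unit tests below pin down against Spark's reference values.
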
diff --git a/native/spark-expr/src/nondetermenistic_funcs/rand.rs b/native/spark-expr/src/nondetermenistic_funcs/rand.rs
index d82c2cd92e..e548f78909 100644
--- a/native/spark-expr/src/nondetermenistic_funcs/rand.rs
+++ b/native/spark-expr/src/nondetermenistic_funcs/rand.rs
@@ -16,11 +16,11 @@
 // under the License.

 use crate::hash_funcs::murmur3::spark_compatible_murmur3_hash;
-use arrow::array::{Float64Array, Float64Builder, RecordBatch};
+
+use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
+use arrow::array::RecordBatch;
 use arrow::datatypes::{DataType, Schema};
 use datafusion::common::Result;
-use datafusion::common::ScalarValue;
-use datafusion::error::DataFusionError;
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr::PhysicalExpr;
 use std::any::Any;
@@ -42,21 +42,11 @@ const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;
 const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c074a61;

 #[derive(Debug, Clone)]
-struct XorShiftRandom {
-    seed: i64,
+pub(crate) struct XorShiftRandom {
+    pub(crate) seed: i64,
 }

 impl XorShiftRandom {
-    fn from_init_seed(init_seed: i64) -> Self {
-        XorShiftRandom {
-            seed: Self::init_seed(init_seed),
-        }
-    }
-
-    fn from_stored_seed(stored_seed: i64) -> Self {
-        XorShiftRandom { seed: stored_seed }
-    }
-
     fn next(&mut self, bits: u8) -> i32 {
         let mut next_seed = self.seed ^ (self.seed << 21);
         next_seed ^= ((next_seed as u64) >> 35) as i64;
@@ -70,60 +60,43 @@ impl XorShiftRandom {
         let b = self.next(27) as i64;
         ((a << 27) + b) as f64 * DOUBLE_UNIT
     }
+}

-    fn init_seed(init: i64) -> i64 {
-        let bytes_repr = init.to_be_bytes();
+impl StatefulSeedValueGenerator<i64, f64> for XorShiftRandom {
+    fn from_init_seed(init_seed: i64) -> Self {
+        let bytes_repr = init_seed.to_be_bytes();
         let low_bits = spark_compatible_murmur3_hash(bytes_repr, SPARK_MURMUR_ARRAY_SEED);
         let high_bits = spark_compatible_murmur3_hash(bytes_repr, low_bits);
-        ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64)
+        let init_seed = ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64);
+        XorShiftRandom { seed: init_seed }
+    }
+
+    fn from_stored_state(stored_state: i64) -> Self {
+        XorShiftRandom { seed: stored_state }
+    }
+
+    fn next_value(&mut self) -> f64 {
+        self.next_f64()
+    }
+
+    fn get_current_state(&self) -> i64 {
+        self.seed
     }
 }

 #[derive(Debug)]
 pub struct RandExpr {
-    seed: Arc<dyn PhysicalExpr>,
-    init_seed_shift: i32,
+    seed: i64,
     state_holder: Arc<Mutex<Option<i64>>>,
 }

 impl RandExpr {
-    pub fn new(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Self {
+    pub fn new(seed: i64) -> Self {
         Self {
             seed,
-            init_seed_shift,
             state_holder: Arc::new(Mutex::new(None::<i64>)),
         }
     }
-
-    fn extract_init_state(seed: ScalarValue) -> Result<i64> {
-        if let ScalarValue::Int64(seed_opt) = seed.cast_to(&DataType::Int64)? {
-            Ok(seed_opt.unwrap_or(0))
-        } else {
-            Err(DataFusionError::Internal(
-                "unexpected execution branch".to_string(),
-            ))
-        }
-    }
-
-    fn evaluate_batch(&self, seed: ScalarValue, num_rows: usize) -> Result<ColumnarValue> {
-        let mut seed_state = self.state_holder.lock().unwrap();
-        let mut rnd = if seed_state.is_none() {
-            let init_seed = RandExpr::extract_init_state(seed)?;
-            let init_seed = init_seed.wrapping_add(self.init_seed_shift as i64);
-            *seed_state = Some(init_seed);
-            XorShiftRandom::from_init_seed(init_seed)
-        } else {
-            let stored_seed = seed_state.unwrap();
-            XorShiftRandom::from_stored_seed(stored_seed)
-        };
-
-        let mut arr_builder = Float64Builder::with_capacity(num_rows);
-        std::iter::repeat_with(|| rnd.next_f64())
-            .take(num_rows)
-            .for_each(|v| arr_builder.append_value(v));
-        let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
-        *seed_state = Some(rnd.seed);
-        Ok(ColumnarValue::Array(array_ref))
-    }
 }

 impl Display for RandExpr {
@@ -134,7 +107,7 @@ impl Display for RandExpr {

 impl PartialEq for RandExpr {
     fn eq(&self, other: &Self) -> bool {
-        self.seed.eq(&other.seed) && self.init_seed_shift == other.init_seed_shift
+        self.seed.eq(&other.seed)
     }
 }

@@ -160,16 +133,15 @@ impl PhysicalExpr for RandExpr {
     }

     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
-        match self.seed.evaluate(batch)? {
-            ColumnarValue::Scalar(seed) => self.evaluate_batch(seed, batch.num_rows()),
-            ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
-                "Only literal seeds are supported for {self}"
-            ))),
-        }
+        evaluate_batch_for_rand::<XorShiftRandom, i64>(
+            &self.state_holder,
+            self.seed,
+            batch.num_rows(),
+        )
     }

     fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.seed]
+        vec![]
     }

     fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
@@ -178,26 +150,22 @@ impl PhysicalExpr for RandExpr {

     fn with_new_children(
         self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
+        _children: Vec<Arc<dyn PhysicalExpr>>,
     ) -> Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(RandExpr::new(
-            Arc::clone(&children[0]),
-            self.init_seed_shift,
-        )))
+        Ok(Arc::new(RandExpr::new(self.seed)))
     }
 }

-pub fn rand(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Result<Arc<dyn PhysicalExpr>> {
-    Ok(Arc::new(RandExpr::new(seed, init_seed_shift)))
+pub fn rand(seed: i64) -> Arc<dyn PhysicalExpr> {
+    Arc::new(RandExpr::new(seed))
 }

 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::array::{Array, BooleanArray, Int64Array};
+    use arrow::array::{Array, Float64Array, Int64Array};
     use arrow::{array::StringArray, compute::concat, datatypes::*};
     use datafusion::common::cast::as_float64_array;
-    use datafusion::physical_expr::expressions::lit;

     const SPARK_SEED_42_FIRST_5: [f64; 5] = [
         0.619189370225301,
@@ -212,7 +180,7 @@
         let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
         let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
         let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
-        let rand_expr = rand(lit(42), 0)?;
+        let rand_expr = rand(42);
         let result = rand_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
         let result = as_float64_array(&result)?;
         let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
@@ -226,7 +194,7 @@
         let first_batch_data = Int64Array::from(vec![Some(42), None]);
         let second_batch_schema = first_batch_schema.clone();
         let second_batch_data = Int64Array::from(vec![None, Some(-42), None]);
-        let rand_expr = rand(lit(42), 0)?;
+        let rand_expr = rand(42);
         let first_batch = RecordBatch::try_new(
             Arc::new(first_batch_schema),
             vec![Arc::new(first_batch_data)],
@@ -251,23 +219,4 @@
         assert_eq!(final_result, expected);
         Ok(())
     }
-
-    #[test]
-    fn test_overflow_shift_seed() -> Result<()> {
-        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
-        let data = BooleanArray::from(vec![Some(true), Some(false)]);
-        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
-        let max_seed_and_shift_expr = rand(lit(i64::MAX), 1)?;
-        let min_seed_no_shift_expr = rand(lit(i64::MIN), 0)?;
-        let first_expr_result = max_seed_and_shift_expr
-            .evaluate(&batch)?
-            .into_array(batch.num_rows())?;
-        let first_expr_result = as_float64_array(&first_expr_result)?;
-        let second_expr_result = min_seed_no_shift_expr
-            .evaluate(&batch)?
-            .into_array(batch.num_rows())?;
-        let second_expr_result = as_float64_array(&second_expr_result)?;
-        assert_eq!(first_expr_result, second_expr_result);
-        Ok(())
-    }
 }
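The new `randn.rs` below implements the Marsaglia polar method: pairs of uniforms are rejected until they fall inside the unit disk, and each accepted pair yields two independent standard normals. A quick standalone sanity check of the transform (a toy LCG supplies the uniforms here, so the stream intentionally differs from Spark's XorShift; only the coarse mean/variance properties are asserted):

```rust
fn main() {
    // Toy uniform source in [0, 1); assumed adequate for a coarse check only.
    let mut s: u64 = 42;
    let mut uniform = move || {
        s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        (s >> 11) as f64 / (1u64 << 53) as f64
    };
    let mut samples = Vec::new();
    while samples.len() < 100_000 {
        let v1 = 2.0 * uniform() - 1.0;
        let v2 = 2.0 * uniform() - 1.0;
        let s2 = v1 * v1 + v2 * v2;
        // Reject points outside the open unit disk (and the origin).
        if s2 > 0.0 && s2 < 1.0 {
            let m = (-2.0 * s2.ln() / s2).sqrt();
            samples.push(v1 * m); // first normal, returned immediately
            samples.push(v2 * m); // second normal: randn.rs caches this one
        }
    }
    let n = samples.len() as f64;
    let mean = samples.iter().sum::<f64>() / n;
    let var = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    assert!(mean.abs() < 0.05, "mean {mean} not near 0");
    assert!((var - 1.0).abs() < 0.05, "variance {var} not near 1");
}
```
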
diff --git a/native/spark-expr/src/nondetermenistic_funcs/randn.rs b/native/spark-expr/src/nondetermenistic_funcs/randn.rs
new file mode 100644
index 0000000000..e1455b68e8
--- /dev/null
+++ b/native/spark-expr/src/nondetermenistic_funcs/randn.rs
@@ -0,0 +1,247 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::nondetermenistic_funcs::rand::XorShiftRandom;
+
+use crate::internal::{evaluate_batch_for_rand, StatefulSeedValueGenerator};
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Schema};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::fmt::{Display, Formatter};
+use std::hash::{Hash, Hasher};
+use std::sync::{Arc, Mutex};
+
+/// Stateful extension of the Marsaglia polar method (https://en.wikipedia.org/wiki/Marsaglia_polar_method),
+/// which Apache Spark uses to convert a uniform distribution into the standard normal one.
+/// To process batches with an odd number of elements correctly, the value that has been
+/// generated but not yet consumed is kept as part of the state.
+/// Note on Comet <-> Spark equivalence:
+/// Under the hood, the Spark algorithm relies on java.util.Random, which delegates to StrictMath.
+/// StrictMath uses native implementations of the floating-point operations (ln, exp, sin, cos)
+/// and guarantees that they are stable across platforms.
+/// See: https://github.com/openjdk/jdk/blob/07c9f7138affdf0d42ecdc30adcb854515569985/src/java.base/share/classes/java/util/Random.java#L745
+/// The Rust standard library gives no such guarantee (https://doc.rust-lang.org/std/primitive.f64.html#method.ln),
+/// and an external library such as rug (https://docs.rs/rug/latest/rug/) would not help, since there
+/// is no guarantee that it matches the StrictMath JVM implementation either.
+/// So we can only ensure equivalence between Rust and Spark (JVM) within some error tolerance.
+
+#[derive(Debug, Clone)]
+struct XorShiftRandomForGaussian {
+    base_generator: XorShiftRandom,
+    next_gaussian: Option<f64>,
+}
+
+impl XorShiftRandomForGaussian {
+    pub fn next_gaussian(&mut self) -> f64 {
+        if let Some(stored_value) = self.next_gaussian {
+            self.next_gaussian = None;
+            return stored_value;
+        }
+        let mut v1: f64;
+        let mut v2: f64;
+        let mut s: f64;
+        loop {
+            v1 = 2f64 * self.base_generator.next_f64() - 1f64;
+            v2 = 2f64 * self.base_generator.next_f64() - 1f64;
+            s = v1 * v1 + v2 * v2;
+            if s < 1f64 && s != 0f64 {
+                break;
+            }
+        }
+        let multiplier = (-2f64 * s.ln() / s).sqrt();
+        self.next_gaussian = Some(v2 * multiplier);
+        v1 * multiplier
+    }
+}
+
+type RandomGaussianState = (i64, Option<f64>);
+
+impl StatefulSeedValueGenerator<RandomGaussianState, f64> for XorShiftRandomForGaussian {
+    fn from_init_seed(init_value: i64) -> Self {
+        XorShiftRandomForGaussian {
+            base_generator: XorShiftRandom::from_init_seed(init_value),
+            next_gaussian: None,
+        }
+    }
+
+    fn from_stored_state(stored_state: RandomGaussianState) -> Self {
+        XorShiftRandomForGaussian {
+            base_generator: XorShiftRandom::from_stored_state(stored_state.0),
+            next_gaussian: stored_state.1,
+        }
+    }
+
+    fn next_value(&mut self) -> f64 {
+        self.next_gaussian()
+    }
+
+    fn get_current_state(&self) -> RandomGaussianState {
+        (self.base_generator.seed, self.next_gaussian)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct RandnExpr {
+    seed: i64,
+    state_holder: Arc<Mutex<Option<RandomGaussianState>>>,
+}
+
+impl RandnExpr {
+    pub fn new(seed: i64) -> Self {
+        Self {
+            seed,
+            state_holder: Arc::new(Mutex::new(None)),
+        }
+    }
+}
+
+impl Display for RandnExpr {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "RANDN({})", self.seed)
+    }
+}
+
+impl PartialEq for RandnExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.seed.eq(&other.seed)
+    }
+}
+
+impl Eq for RandnExpr {}
+
+impl Hash for RandnExpr {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.children().hash(state);
+    }
+}
+
+impl PhysicalExpr for RandnExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> datafusion::common::Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> datafusion::common::Result<bool> {
+        Ok(false)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
+        evaluate_batch_for_rand::<XorShiftRandomForGaussian, RandomGaussianState>(
+            &self.state_holder,
+            self.seed,
+            batch.num_rows(),
+        )
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![]
+    }
+
+    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
+        unimplemented!()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(RandnExpr::new(self.seed)))
+    }
+}
+
+pub fn randn(seed: i64) -> Arc<dyn PhysicalExpr> {
+    Arc::new(RandnExpr::new(seed))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Array, Float64Array, Int64Array};
+    use arrow::{array::StringArray, compute::concat, datatypes::*};
+    use datafusion::common::cast::as_float64_array;
+
+    const PRECISION_TOLERANCE: f64 = 1e-6;
+
+    const SPARK_SEED_42_FIRST_5_GAUSSIAN: [f64; 5] = [
+        2.384479054241165,
+        0.1920934041293524,
+        0.7337336533286575,
+        -0.5224480195716871,
+        2.060084179317831,
+    ];
+
+    #[test]
+    fn test_randn_single_batch() -> datafusion::common::Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
+        let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
+        let randn_expr = randn(42);
+        let result = randn_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+        let result = as_float64_array(&result)?;
+        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN));
+        assert_eq_with_tolerance(result, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn test_randn_multi_batch() -> datafusion::common::Result<()> {
+        let first_batch_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
+        let first_batch_data = Int64Array::from(vec![Some(24), None, None]);
+        let second_batch_schema = first_batch_schema.clone();
+        let second_batch_data = Int64Array::from(vec![None, Some(22)]);
+        let randn_expr = randn(42);
+        let first_batch = RecordBatch::try_new(
+            Arc::new(first_batch_schema),
+            vec![Arc::new(first_batch_data)],
+        )?;
+        let first_batch_result = randn_expr
+            .evaluate(&first_batch)?
+            .into_array(first_batch.num_rows())?;
+        let second_batch = RecordBatch::try_new(
+            Arc::new(second_batch_schema),
+            vec![Arc::new(second_batch_data)],
+        )?;
+        let second_batch_result = randn_expr
+            .evaluate(&second_batch)?
+            .into_array(second_batch.num_rows())?;
+        let result_arrays: Vec<&dyn Array> = vec![
+            as_float64_array(&first_batch_result)?,
+            as_float64_array(&second_batch_result)?,
+        ];
+        let result_arrays = &concat(&result_arrays)?;
+        let final_result = as_float64_array(result_arrays)?;
+        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5_GAUSSIAN));
+        assert_eq_with_tolerance(final_result, expected);
+        Ok(())
+    }
+
+    fn assert_eq_with_tolerance(left: &Float64Array, right: &Float64Array) {
+        assert_eq!(left.len(), right.len());
+        left.iter().zip(right.iter()).for_each(|(l, r)| {
+            assert!(
+                (l.unwrap() - r.unwrap()).abs() < PRECISION_TOLERANCE,
+                "difference between {:?} and {:?} is larger than acceptable precision",
+                l.unwrap(),
+                r.unwrap()
+            )
+        })
+    }
+}
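One subtlety in the file above deserves emphasis before the Scala changes: the polar method produces normals in pairs, so when a batch ends between the two values of a pair, the unused one must be carried across the batch boundary, which is why `RandomGaussianState` includes an `Option<f64>`. A stripped-down illustration of just that caching, with a deterministic pair source standing in for the real generator:

```rust
use std::sync::{Arc, Mutex};

// Stand-in for the polar method: values come out in pairs, and the second
// element of a pair must be cached if the batch ends between the two.
struct PairSource {
    counter: i64,
    cached: Option<i64>,
}

impl PairSource {
    fn next(&mut self) -> i64 {
        if let Some(v) = self.cached.take() {
            return v;
        }
        let first = self.counter;
        self.cached = Some(first + 1); // the "second gaussian" of the pair
        self.counter += 2;
        first
    }
}

type State = (i64, Option<i64>);

fn batch(holder: &Arc<Mutex<Option<State>>>, n: usize) -> Vec<i64> {
    let mut guard = holder.lock().unwrap();
    let (counter, cached) = guard.unwrap_or((0, None));
    let mut src = PairSource { counter, cached };
    let out: Vec<i64> = (0..n).map(|_| src.next()).collect();
    *guard = Some((src.counter, src.cached)); // cached value survives the boundary
    out
}

fn main() {
    let split = Arc::new(Mutex::new(None));
    let whole = Arc::new(Mutex::new(None));
    // An odd first batch (0, 1, 2) leaves the generated-but-unused "3" in state,
    // so the split stream still matches a single 5-row batch exactly.
    let joined = [batch(&split, 3), batch(&split, 2)].concat();
    assert_eq!(joined, batch(&whole, 5)); // 0, 1, 2, 3, 4
}
```
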
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 970329b28f..4e5631ed2c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1860,12 +1860,27 @@ object QueryPlanSerde extends Logging with CometExprShim {
       case _: ArrayExcept =>
         convert(CometArrayExcept)
       case Rand(child, _) =>
-        createUnaryExpr(
-          expr,
-          child,
-          inputs,
-          binding,
-          (builder, unaryExpr) => builder.setRand(unaryExpr))
+        val seed = child match {
+          case Literal(seed: Long, _) => Some(seed)
+          case Literal(null, _) => Some(0L)
+          case _ => None
+        }
+        seed.map(seed =>
+          ExprOuterClass.Expr
+            .newBuilder()
+            .setRand(ExprOuterClass.Rand.newBuilder().setSeed(seed))
+            .build())
+      case Randn(child, _) =>
+        val seed = child match {
+          case Literal(seed: Long, _) => Some(seed)
+          case Literal(null, _) => Some(0L)
+          case _ => None
+        }
+        seed.map(seed =>
+          ExprOuterClass.Expr
+            .newBuilder()
+            .setRandn(ExprOuterClass.Rand.newBuilder().setSeed(seed))
+            .build())
       case expr =>
         QueryPlanSerde.exprSerdeMap.get(expr.getClass) match {
          case Some(handler) => convert(handler)
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 6f29d48878..a09f337f8b 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2745,26 +2745,57 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

-  test("rand expression with random parameters") {
+  private def testOnShuffledRangeWithRandomParameters(testLogic: DataFrame => Unit): Unit = {
     val partitionsNumber = Random.nextInt(10) + 1
     val rowsNumber = Random.nextInt(500)
-    val seed = Random.nextLong()
     // use this value to have both single-batch and multi-batch partitions
-    val cometBatchSize = math.max(1, math.ceil(rowsNumber.toDouble / partitionsNumber).toInt)
+    val cometBatchSize = math.max(1, math.floor(rowsNumber.toDouble / partitionsNumber).toInt)
     withSQLConf("spark.comet.batchSize" -> cometBatchSize.toString) {
       withParquetDataFrame((0 until rowsNumber).map(Tuple1.apply)) { df =>
-        val dfWithRandParameters = df.repartition(partitionsNumber).withColumn("rnd", rand(seed))
-        checkSparkAnswerAndOperator(dfWithRandParameters)
-        val dfWithOverflowSeed =
-          df.repartition(partitionsNumber).withColumn("rnd", rand(Long.MaxValue))
-        checkSparkAnswerAndOperator(dfWithOverflowSeed)
-        val dfWithNullSeed =
-          df.repartition(partitionsNumber).selectExpr("_1", "rand(null) as rnd")
-        checkSparkAnswerAndOperator(dfWithNullSeed)
+        testLogic(df.repartition(partitionsNumber))
       }
     }
   }

+  test("rand expression with random parameters") {
+    testOnShuffledRangeWithRandomParameters { df =>
+      val seed = Random.nextLong()
+      val dfWithRandParameters = df.withColumn("rnd", rand(seed))
+      checkSparkAnswerAndOperator(dfWithRandParameters)
+      val dfWithOverflowSeed = df.withColumn("rnd", rand(Long.MaxValue))
+      checkSparkAnswerAndOperator(dfWithOverflowSeed)
+      val dfWithNullSeed = df.selectExpr("_1", "rand(null) as rnd")
+      checkSparkAnswerAndOperator(dfWithNullSeed)
+    }
+  }
+
+  test("randn expression with random parameters") {
+    testOnShuffledRangeWithRandomParameters { df =>
+      val seed = Random.nextLong()
+      val dfWithRandParameters = df.withColumn("randn", randn(seed))
+      checkSparkAnswerAndOperatorWithTol(dfWithRandParameters)
+      val dfWithOverflowSeed = df.withColumn("randn", randn(Long.MaxValue))
+      checkSparkAnswerAndOperatorWithTol(dfWithOverflowSeed)
+      val dfWithNullSeed = df.selectExpr("_1", "randn(null) as randn")
+      checkSparkAnswerAndOperatorWithTol(dfWithNullSeed)
+    }
+  }
+
+  test("multiple nondeterministic expressions with shuffle") {
+    testOnShuffledRangeWithRandomParameters { df =>
+      val seed1 = Random.nextLong()
+      val seed2 = Random.nextLong()
+      val complexRandDf = df
+        .withColumn("rand1", rand(seed1))
+        .withColumn("randn1", randn(seed1))
+        .repartition(2, col("_1"))
+        .sortWithinPartitions("_1")
+        .withColumn("rand2", rand(seed2))
+        .withColumn("randn2", randn(seed2))
+      checkSparkAnswerAndOperatorWithTol(complexRandDf)
+    }
+  }
+
   test("window query with rangeBetween") {
     // values are int
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 9d51c69196..365003aa8c 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -179,6 +179,20 @@ abstract class CometTestBase
     checkSparkAnswer(df)
   }

+  protected def checkSparkAnswerAndOperatorWithTol(df: => DataFrame, tol: Double = 1e-6): Unit = {
+    checkSparkAnswerAndOperatorWithTol(df, tol, Seq.empty)
+  }
+
+  protected def checkSparkAnswerAndOperatorWithTol(
+      df: => DataFrame,
+      tol: Double,
+      includeClasses: Seq[Class[_]],
+      excludedClasses: Class[_]*): Unit = {
+    checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), excludedClasses: _*)
+    checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), includeClasses: _*)
+    checkSparkAnswerWithTol(df, tol)
+  }
+
   protected def checkCometOperators(plan: SparkPlan, excludedClasses: Class[_]*): Unit = {
     val wrapped = wrapCometSparkToColumnar(plan)
     wrapped.foreach {