Skip to content

Commit 234c978

Browse files
committed
Derive UDF equality from PartialEq, Hash
Reduce boilerplate in cases where implementation of `{ScalarUDFImpl,AggregateUDFImpl,WindowUDFImpl}::{equals,hash_code}` can be derived using standard `PartialEq` and `Hash` traits. This is code complexity reduction. While valuable on its own, this also prepares for more automatic derivation of UDF equals/hash and/or removal of default implementations (which currently are error-prone).
1 parent 350c61b commit 234c978

File tree

2 files changed

+23
-31
lines changed

2 files changed

+23
-31
lines changed

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ use datafusion_common::{
4343
use datafusion_expr::expr::FieldMetadata;
4444
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
4545
use datafusion_expr::{
46-
lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody,
47-
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
48-
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
46+
lit_with_metadata, udf_equals_hash, Accumulator, ColumnarValue, CreateFunction,
47+
CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs,
48+
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
4949
};
5050
use datafusion_functions_nested::range::range_udf;
5151
use parking_lot::Mutex;
@@ -517,7 +517,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
517517
}
518518

519519
/// Volatile UDF that should append a different value to each row
520-
#[derive(Debug)]
520+
#[derive(Debug, PartialEq, Hash)]
521521
struct AddIndexToStringVolatileScalarUDF {
522522
name: String,
523523
signature: Signature,
@@ -586,33 +586,7 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
586586
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
587587
}
588588

589-
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
590-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
591-
return false;
592-
};
593-
let Self {
594-
name,
595-
signature,
596-
return_type,
597-
} = self;
598-
name == &other.name
599-
&& signature == &other.signature
600-
&& return_type == &other.return_type
601-
}
602-
603-
fn hash_value(&self) -> u64 {
604-
let Self {
605-
name,
606-
signature,
607-
return_type,
608-
} = self;
609-
let mut hasher = DefaultHasher::new();
610-
std::any::type_name::<Self>().hash(&mut hasher);
611-
name.hash(&mut hasher);
612-
signature.hash(&mut hasher);
613-
return_type.hash(&mut hasher);
614-
hasher.finish()
615-
}
589+
udf_equals_hash!(ScalarUDFImpl);
616590
}
617591

618592
#[tokio::test]

datafusion/expr/src/utils.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,24 @@ pub fn collect_subquery_cols(
12601260
})
12611261
}
12621262

1263+
#[macro_export]
1264+
macro_rules! udf_equals_hash {
1265+
($udf_type:tt) => {
1266+
fn equals(&self, other: &dyn $udf_type) -> bool {
1267+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
1268+
return false;
1269+
};
1270+
self == other
1271+
}
1272+
1273+
fn hash_value(&self) -> u64 {
1274+
let hasher = &mut DefaultHasher::new();
1275+
self.hash(hasher);
1276+
hasher.finish()
1277+
}
1278+
};
1279+
}
1280+
12631281
#[cfg(test)]
12641282
mod tests {
12651283
use super::*;

0 commit comments

Comments
 (0)