Skip to content

Commit 7bab0a3

Browse files
fix: remove use of deprecated make_scalar_function
`make_scalar_function` has been deprecated since v36 [0]. It is being removed from the public api in v43 [1]. [0]: apache/datafusion#8878 [1]: apache/datafusion#12505
1 parent cdec202 commit 7bab0a3

File tree

1 file changed

+36
-24
lines changed

1 file changed

+36
-24
lines changed

src/udf.rs

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,39 +24,51 @@ use datafusion::arrow::datatypes::DataType;
2424
use datafusion::arrow::pyarrow::FromPyArrow;
2525
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
2626
use datafusion::error::DataFusionError;
27-
use datafusion::logical_expr::create_udf;
2827
use datafusion::logical_expr::function::ScalarFunctionImplementation;
2928
use datafusion::logical_expr::ScalarUDF;
29+
use datafusion::logical_expr::{create_udf, ColumnarValue};
3030

3131
use crate::expr::PyExpr;
3232
use crate::utils::parse_volatility;
3333

34+
/// Create a Rust callable function fr a python function that expects pyarrow arrays
35+
fn pyarrow_function_to_rust(
36+
func: PyObject,
37+
) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
38+
move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
39+
Python::with_gil(|py| {
40+
// 1. cast args to Pyarrow arrays
41+
let py_args = args
42+
.iter()
43+
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
44+
.collect::<Vec<_>>();
45+
let py_args = PyTuple::new_bound(py, py_args);
46+
47+
// 2. call function
48+
let value = func
49+
.call_bound(py, py_args, None)
50+
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
51+
52+
// 3. cast to arrow::array::Array
53+
let array_data = ArrayData::from_pyarrow_bound(value.bind(py)).unwrap();
54+
Ok(make_array(array_data))
55+
})
56+
}
57+
}
58+
3459
/// Create a DataFusion's UDF implementation from a python function
3560
/// that expects pyarrow arrays. This is more efficient as it performs
3661
/// a zero-copy of the contents.
37-
fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation {
38-
#[allow(deprecated)]
39-
datafusion::physical_plan::functions::make_scalar_function(
40-
move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
41-
Python::with_gil(|py| {
42-
// 1. cast args to Pyarrow arrays
43-
let py_args = args
44-
.iter()
45-
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
46-
.collect::<Vec<_>>();
47-
let py_args = PyTuple::new_bound(py, py_args);
48-
49-
// 2. call function
50-
let value = func
51-
.call_bound(py, py_args, None)
52-
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
62+
fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
63+
// Make the python function callable from rust
64+
let pyarrow_func = pyarrow_function_to_rust(func);
5365

54-
// 3. cast to arrow::array::Array
55-
let array_data = ArrayData::from_pyarrow_bound(value.bind(py)).unwrap();
56-
Ok(make_array(array_data))
57-
})
58-
},
59-
)
66+
// Convert input/output from datafusion ColumnarValue to arrow arrays
67+
Arc::new(move |args: &[ColumnarValue]| {
68+
let array_refs = ColumnarValue::values_to_arrays(args)?;
69+
let array_result = pyarrow_func(&array_refs)?;
70+
Ok(array_result.into())
71+
})
6072
}
6173

6274
/// Represents a PyScalarUDF
@@ -82,7 +94,7 @@ impl PyScalarUDF {
8294
input_types.0,
8395
Arc::new(return_type.0),
8496
parse_volatility(volatility)?,
85-
to_rust_function(func),
97+
to_scalar_function_impl(func),
8698
);
8799
Ok(Self { function })
88100
}

0 commit comments

Comments
 (0)