@@ -24,39 +24,51 @@ use datafusion::arrow::datatypes::DataType;
24
24
use datafusion:: arrow:: pyarrow:: FromPyArrow ;
25
25
use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
26
26
use datafusion:: error:: DataFusionError ;
27
- use datafusion:: logical_expr:: create_udf;
28
27
use datafusion:: logical_expr:: function:: ScalarFunctionImplementation ;
29
28
use datafusion:: logical_expr:: ScalarUDF ;
29
+ use datafusion:: logical_expr:: { create_udf, ColumnarValue } ;
30
30
31
31
use crate :: expr:: PyExpr ;
32
32
use crate :: utils:: parse_volatility;
33
33
34
+ /// Create a Rust callable function fr a python function that expects pyarrow arrays
35
+ fn pyarrow_function_to_rust (
36
+ func : PyObject ,
37
+ ) -> impl Fn ( & [ ArrayRef ] ) -> Result < ArrayRef , DataFusionError > {
38
+ move |args : & [ ArrayRef ] | -> Result < ArrayRef , DataFusionError > {
39
+ Python :: with_gil ( |py| {
40
+ // 1. cast args to Pyarrow arrays
41
+ let py_args = args
42
+ . iter ( )
43
+ . map ( |arg| arg. into_data ( ) . to_pyarrow ( py) . unwrap ( ) )
44
+ . collect :: < Vec < _ > > ( ) ;
45
+ let py_args = PyTuple :: new_bound ( py, py_args) ;
46
+
47
+ // 2. call function
48
+ let value = func
49
+ . call_bound ( py, py_args, None )
50
+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
51
+
52
+ // 3. cast to arrow::array::Array
53
+ let array_data = ArrayData :: from_pyarrow_bound ( value. bind ( py) ) . unwrap ( ) ;
54
+ Ok ( make_array ( array_data) )
55
+ } )
56
+ }
57
+ }
58
+
34
59
/// Create a DataFusion's UDF implementation from a python function
35
60
/// that expects pyarrow arrays. This is more efficient as it performs
36
61
/// a zero-copy of the contents.
37
- fn to_rust_function ( func : PyObject ) -> ScalarFunctionImplementation {
38
- #[ allow( deprecated) ]
39
- datafusion:: physical_plan:: functions:: make_scalar_function (
40
- move |args : & [ ArrayRef ] | -> Result < ArrayRef , DataFusionError > {
41
- Python :: with_gil ( |py| {
42
- // 1. cast args to Pyarrow arrays
43
- let py_args = args
44
- . iter ( )
45
- . map ( |arg| arg. into_data ( ) . to_pyarrow ( py) . unwrap ( ) )
46
- . collect :: < Vec < _ > > ( ) ;
47
- let py_args = PyTuple :: new_bound ( py, py_args) ;
48
-
49
- // 2. call function
50
- let value = func
51
- . call_bound ( py, py_args, None )
52
- . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
62
+ fn to_scalar_function_impl ( func : PyObject ) -> ScalarFunctionImplementation {
63
+ // Make the python function callable from rust
64
+ let pyarrow_func = pyarrow_function_to_rust ( func) ;
53
65
54
- // 3. cast to arrow::array::Array
55
- let array_data = ArrayData :: from_pyarrow_bound ( value . bind ( py ) ) . unwrap ( ) ;
56
- Ok ( make_array ( array_data ) )
57
- } )
58
- } ,
59
- )
66
+ // Convert input/output from datafusion ColumnarValue to arrow arrays
67
+ Arc :: new ( move | args : & [ ColumnarValue ] | {
68
+ let array_refs = ColumnarValue :: values_to_arrays ( args ) ? ;
69
+ let array_result = pyarrow_func ( & array_refs ) ? ;
70
+ Ok ( array_result . into ( ) )
71
+ } )
60
72
}
61
73
62
74
/// Represents a PyScalarUDF
@@ -82,7 +94,7 @@ impl PyScalarUDF {
82
94
input_types. 0 ,
83
95
Arc :: new ( return_type. 0 ) ,
84
96
parse_volatility ( volatility) ?,
85
- to_rust_function ( func) ,
97
+ to_scalar_function_impl ( func) ,
86
98
) ;
87
99
Ok ( Self { function } )
88
100
}
0 commit comments