@@ -24,39 +24,56 @@ use datafusion::arrow::datatypes::DataType;
24
24
use datafusion:: arrow:: pyarrow:: FromPyArrow ;
25
25
use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
26
26
use datafusion:: error:: DataFusionError ;
27
- use datafusion:: logical_expr:: create_udf;
28
27
use datafusion:: logical_expr:: function:: ScalarFunctionImplementation ;
29
28
use datafusion:: logical_expr:: ScalarUDF ;
29
+ use datafusion:: logical_expr:: { create_udf, ColumnarValue } ;
30
30
31
31
use crate :: expr:: PyExpr ;
32
32
use crate :: utils:: parse_volatility;
33
33
34
+ /// Create a Rust callable function fr a python function that expects pyarrow arrays
35
+ fn pyarrow_function_to_rust (
36
+ func : PyObject ,
37
+ ) -> impl Fn ( & [ ArrayRef ] ) -> Result < ArrayRef , DataFusionError > {
38
+ move |args : & [ ArrayRef ] | -> Result < ArrayRef , DataFusionError > {
39
+ Python :: with_gil ( |py| {
40
+ // 1. cast args to Pyarrow arrays
41
+ let py_args = args
42
+ . iter ( )
43
+ . map ( |arg| {
44
+ arg. into_data ( )
45
+ . to_pyarrow ( py)
46
+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) )
47
+ } )
48
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
49
+ let py_args = PyTuple :: new_bound ( py, py_args) ;
50
+
51
+ // 2. call function
52
+ let value = func
53
+ . call_bound ( py, py_args, None )
54
+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
55
+
56
+ // 3. cast to arrow::array::Array
57
+ let array_data = ArrayData :: from_pyarrow_bound ( value. bind ( py) )
58
+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
59
+ Ok ( make_array ( array_data) )
60
+ } )
61
+ }
62
+ }
63
+
34
64
/// Create a DataFusion's UDF implementation from a python function
35
65
/// that expects pyarrow arrays. This is more efficient as it performs
36
66
/// a zero-copy of the contents.
37
- fn to_rust_function ( func : PyObject ) -> ScalarFunctionImplementation {
38
- #[ allow( deprecated) ]
39
- datafusion:: physical_plan:: functions:: make_scalar_function (
40
- move |args : & [ ArrayRef ] | -> Result < ArrayRef , DataFusionError > {
41
- Python :: with_gil ( |py| {
42
- // 1. cast args to Pyarrow arrays
43
- let py_args = args
44
- . iter ( )
45
- . map ( |arg| arg. into_data ( ) . to_pyarrow ( py) . unwrap ( ) )
46
- . collect :: < Vec < _ > > ( ) ;
47
- let py_args = PyTuple :: new_bound ( py, py_args) ;
48
-
49
- // 2. call function
50
- let value = func
51
- . call_bound ( py, py_args, None )
52
- . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
67
+ fn to_scalar_function_impl ( func : PyObject ) -> ScalarFunctionImplementation {
68
+ // Make the python function callable from rust
69
+ let pyarrow_func = pyarrow_function_to_rust ( func) ;
53
70
54
- // 3. cast to arrow::array::Array
55
- let array_data = ArrayData :: from_pyarrow_bound ( value . bind ( py ) ) . unwrap ( ) ;
56
- Ok ( make_array ( array_data ) )
57
- } )
58
- } ,
59
- )
71
+ // Convert input/output from datafusion ColumnarValue to arrow arrays
72
+ Arc :: new ( move | args : & [ ColumnarValue ] | {
73
+ let array_refs = ColumnarValue :: values_to_arrays ( args ) ? ;
74
+ let array_result = pyarrow_func ( & array_refs ) ? ;
75
+ Ok ( array_result . into ( ) )
76
+ } )
60
77
}
61
78
62
79
/// Represents a PyScalarUDF
@@ -82,7 +99,7 @@ impl PyScalarUDF {
82
99
input_types. 0 ,
83
100
Arc :: new ( return_type. 0 ) ,
84
101
parse_volatility ( volatility) ?,
85
- to_rust_function ( func) ,
102
+ to_scalar_function_impl ( func) ,
86
103
) ;
87
104
Ok ( Self { function } )
88
105
}
0 commit comments