Feature/instance udfs (#890) · kylebarron/datafusion-python@1fd3762 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1fd3762

Browse files
authored
Feature/instance udfs (apache#890)
* Add option for passing in constructor arguments to the udaf
* Fix small warnings in pylance
* Improve type hinting for udaf and fix one pylance warning
* Set up UDWF to take arguments as constructor just like UDAF to ensure we get a clean state when functions are reused
* Improve handling of udf when user provides a class instead of bare function
* Add unit tests for UDF showing callable class
* Add license text
* Switching to use factory methods for udaf and udwf
* Move new tests to the new testing directory
1 parent d181a30 commit 1fd3762

File tree

7 files changed

+272
-97
lines changed

7 files changed

+272
-97
lines changed

python/datafusion/udf.py

Lines changed: 101 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from datafusion.expr import Expr
2424
from typing import Callable, TYPE_CHECKING, TypeVar
2525
from abc import ABCMeta, abstractmethod
26-
from typing import List
26+
from typing import List, Optional
2727
from enum import Enum
2828
import pyarrow
2929

@@ -84,16 +84,18 @@ class ScalarUDF:
8484

8585
def __init__(
8686
self,
87-
name: str | None,
87+
name: Optional[str],
8888
func: Callable[..., _R],
89-
input_types: list[pyarrow.DataType],
89+
input_types: pyarrow.DataType | list[pyarrow.DataType],
9090
return_type: _R,
9191
volatility: Volatility | str,
9292
) -> None:
9393
"""Instantiate a scalar user-defined function (UDF).
9494
9595
See helper method :py:func:`udf` for argument details.
9696
"""
97+
if isinstance(input_types, pyarrow.DataType):
98+
input_types = [input_types]
9799
self._udf = df_internal.ScalarUDF(
98100
name, func, input_types, return_type, str(volatility)
99101
)
@@ -104,16 +106,16 @@ def __call__(self, *args: Expr) -> Expr:
104106
This function is not typically called by an end user. These calls will
105107
occur during the evaluation of the dataframe.
106108
"""
107-
args = [arg.expr for arg in args]
108-
return Expr(self._udf.__call__(*args))
109+
args_raw = [arg.expr for arg in args]
110+
return Expr(self._udf.__call__(*args_raw))
109111

110112
@staticmethod
111113
def udf(
112114
func: Callable[..., _R],
113115
input_types: list[pyarrow.DataType],
114116
return_type: _R,
115117
volatility: Volatility | str,
116-
name: str | None = None,
118+
name: Optional[str] = None,
117119
) -> ScalarUDF:
118120
"""Create a new User-Defined Function.
119121
@@ -133,7 +135,10 @@ def udf(
133135
if not callable(func):
134136
raise TypeError("`func` argument must be callable")
135137
if name is None:
136-
name = func.__qualname__.lower()
138+
if hasattr(func, "__qualname__"):
139+
name = func.__qualname__.lower()
140+
else:
141+
name = func.__class__.__name__.lower()
137142
return ScalarUDF(
138143
name=name,
139144
func=func,
@@ -167,10 +172,6 @@ def evaluate(self) -> pyarrow.Scalar:
167172
pass
168173

169174

170-
if TYPE_CHECKING:
171-
_A = TypeVar("_A", bound=(Callable[..., _R], Accumulator))
172-
173-
174175
class AggregateUDF:
175176
"""Class for performing scalar user-defined functions (UDF).
176177
@@ -180,10 +181,10 @@ class AggregateUDF:
180181

181182
def __init__(
182183
self,
183-
name: str | None,
184-
accumulator: _A,
184+
name: Optional[str],
185+
accumulator: Callable[[], Accumulator],
185186
input_types: list[pyarrow.DataType],
186-
return_type: _R,
187+
return_type: pyarrow.DataType,
187188
state_type: list[pyarrow.DataType],
188189
volatility: Volatility | str,
189190
) -> None:
@@ -193,7 +194,12 @@ def __init__(
193194
descriptions.
194195
"""
195196
self._udaf = df_internal.AggregateUDF(
196-
name, accumulator, input_types, return_type, state_type, str(volatility)
197+
name,
198+
accumulator,
199+
input_types,
200+
return_type,
201+
state_type,
202+
str(volatility),
197203
)
198204

199205
def __call__(self, *args: Expr) -> Expr:
@@ -202,21 +208,52 @@ def __call__(self, *args: Expr) -> Expr:
202208
This function is not typically called by an end user. These calls will
203209
occur during the evaluation of the dataframe.
204210
"""
205-
args = [arg.expr for arg in args]
206-
return Expr(self._udaf.__call__(*args))
211+
args_raw = [arg.expr for arg in args]
212+
return Expr(self._udaf.__call__(*args_raw))
207213

208214
@staticmethod
209215
def udaf(
210-
accum: _A,
211-
input_types: list[pyarrow.DataType],
212-
return_type: _R,
216+
accum: Callable[[], Accumulator],
217+
input_types: pyarrow.DataType | list[pyarrow.DataType],
218+
return_type: pyarrow.DataType,
213219
state_type: list[pyarrow.DataType],
214220
volatility: Volatility | str,
215-
name: str | None = None,
221+
name: Optional[str] = None,
216222
) -> AggregateUDF:
217223
"""Create a new User-Defined Aggregate Function.
218224
219-
The accumulator function must be callable and implement :py:class:`Accumulator`.
225+
If your :py:class:`Accumulator` can be instantiated with no arguments, you
226+
can simply pass its type as ``accum``. If you need to pass additional arguments
227+
to its constructor, you can define a lambda or a factory method. During runtime
228+
the :py:class:`Accumulator` will be constructed for every instance in
229+
which this UDAF is used. The following examples are all valid.
230+
231+
.. code-block:: python
232+
import pyarrow as pa
233+
import pyarrow.compute as pc
234+
235+
class Summarize(Accumulator):
236+
def __init__(self, bias: float = 0.0):
237+
self._sum = pa.scalar(bias)
238+
239+
def state(self) -> List[pa.Scalar]:
240+
return [self._sum]
241+
242+
def update(self, values: pa.Array) -> None:
243+
self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
244+
245+
def merge(self, states: List[pa.Array]) -> None:
246+
self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
247+
248+
def evaluate(self) -> pa.Scalar:
249+
return self._sum
250+
251+
def sum_bias_10() -> Summarize:
252+
return Summarize(10.0)
253+
254+
udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
255+
udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
256+
udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
220257
221258
Args:
222259
accum: The accumulator python function.
@@ -229,14 +266,16 @@ def udaf(
229266
Returns:
230267
A user-defined aggregate function, which can be used in either data
231268
aggregation or window function calls.
232-
"""
233-
if not issubclass(accum, Accumulator):
269+
""" # noqa W505
270+
if not callable(accum):
271+
raise TypeError("`func` must be callable.")
272+
if not isinstance(accum.__call__(), Accumulator):
234273
raise TypeError(
235-
"`accum` must implement the abstract base class Accumulator"
274+
"Accumulator must implement the abstract base class Accumulator"
236275
)
237276
if name is None:
238-
name = accum.__qualname__.lower()
239-
if isinstance(input_types, pyarrow.lib.DataType):
277+
name = accum.__call__().__class__.__qualname__.lower()
278+
if isinstance(input_types, pyarrow.DataType):
240279
input_types = [input_types]
241280
return AggregateUDF(
242281
name=name,
@@ -421,8 +460,8 @@ class WindowUDF:
421460

422461
def __init__(
423462
self,
424-
name: str | None,
425-
func: WindowEvaluator,
463+
name: Optional[str],
464+
func: Callable[[], WindowEvaluator],
426465
input_types: list[pyarrow.DataType],
427466
return_type: pyarrow.DataType,
428467
volatility: Volatility | str,
@@ -447,30 +486,57 @@ def __call__(self, *args: Expr) -> Expr:
447486

448487
@staticmethod
449488
def udwf(
450-
func: WindowEvaluator,
489+
func: Callable[[], WindowEvaluator],
451490
input_types: pyarrow.DataType | list[pyarrow.DataType],
452491
return_type: pyarrow.DataType,
453492
volatility: Volatility | str,
454-
name: str | None = None,
493+
name: Optional[str] = None,
455494
) -> WindowUDF:
456495
"""Create a new User-Defined Window Function.
457496
497+
If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
498+
can simply pass its type as ``func``. If you need to pass additional arguments
499+
to its constructor, you can define a lambda or a factory method. During runtime
500+
the :py:class:`WindowEvaluator` will be constructed for every instance in
501+
which this UDWF is used. The following examples are all valid.
502+
503+
.. code-block:: python
504+
505+
import pyarrow as pa
506+
507+
class BiasedNumbers(WindowEvaluator):
508+
def __init__(self, start: int = 0) -> None:
509+
self.start = start
510+
511+
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
512+
return pa.array([self.start + i for i in range(num_rows)])
513+
514+
def bias_10() -> BiasedNumbers:
515+
return BiasedNumbers(10)
516+
517+
udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
518+
udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
519+
udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
520+
458521
Args:
459-
func: The python function.
522+
func: A callable to create the window function.
460523
input_types: The data types of the arguments to ``func``.
461524
return_type: The data type of the return value.
462525
volatility: See :py:class:`Volatility` for allowed values.
526+
arguments: A list of arguments to pass in to the __init__ method for accum.
463527
name: A descriptive name for the function.
464528
465529
Returns:
466530
A user-defined window function.
467-
"""
468-
if not isinstance(func, WindowEvaluator):
531+
""" # noqa W505
532+
if not callable(func):
533+
raise TypeError("`func` must be callable.")
534+
if not isinstance(func.__call__(), WindowEvaluator):
469535
raise TypeError(
470536
"`func` must implement the abstract base class WindowEvaluator"
471537
)
472538
if name is None:
473-
name = func.__class__.__qualname__.lower()
539+
name = func.__call__().__class__.__qualname__.lower()
474540
if isinstance(input_types, pyarrow.DataType):
475541
input_types = [input_types]
476542
return WindowUDF(

python/tests/test_dataframe.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
WindowFrame,
3030
column,
3131
literal,
32-
udf,
3332
)
3433
from datafusion.expr import Window
3534

@@ -236,21 +235,6 @@ def test_unnest_without_nulls(nested_df):
236235
assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])
237236

238237

239-
def test_udf(df):
240-
# is_null is a pa function over arrays
241-
is_null = udf(
242-
lambda x: x.is_null(),
243-
[pa.int64()],
244-
pa.bool_(),
245-
volatility="immutable",
246-
)
247-
248-
df = df.select(is_null(column("a")))
249-
result = df.collect()[0].column(0)
250-
251-
assert result == pa.array([False, False, False])
252-
253-
254238
def test_join():
255239
ctx = SessionContext()
256240

File renamed without changes.

0 commit comments

Comments
 (0)
0