feat: Implementation of udf and udaf decorator (#1040) · satwikmishra11/datafusion-python@973d7ec · GitHub

Commit 973d7ec

feat: Implementation of udf and udaf decorator (apache#1040)
* Implementation of udf and udaf decorator
* Rename decorators back to udf and udaf, update documentation
* Minor typo fixes
* Fix linting errors
* ruff formatting

Co-authored-by: Tim Saucer <timsaucer@gmail.com>
1 parent acd7040 commit 973d7ec

3 files changed: +265 −76 lines changed
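The commit replaces the `ScalarUDF.udf` and `AggregateUDF.udaf` static methods with `udf`/`udaf` classes whose `__new__` dispatches on whether the first argument is callable, so each can be used either as a plain function or as a decorator. As a quick orientation, here is a minimal sketch of the two call styles, adapted from the docstring examples added in the diff below; the top-level `datafusion` import path and the use of `pyarrow.compute` in the function body are assumptions, not part of this commit.

```
import pyarrow as pa
import pyarrow.compute as pc

from datafusion import udf  # assumed top-level re-export of ScalarUDF.udf


# Function style: pass the callable as the first argument.
def double_func(values: pa.Array) -> pa.Array:
    return pc.multiply(values, 2)


double_udf = udf(double_func, [pa.int64()], pa.int64(), "immutable", "double_it")


# Decorator style (new in this commit): omit the callable argument.
@udf([pa.int64()], pa.int64(), "immutable", "double_it_decorated")
def double_it(values: pa.Array) -> pa.Array:
    return pc.multiply(values, 2)
```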

python/datafusion/udf.py

Lines changed: 187 additions & 70 deletions
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import functools
 from abc import ABCMeta, abstractmethod
 from enum import Enum
 from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
@@ -110,43 +111,102 @@ def __call__(self, *args: Expr) -> Expr:
         args_raw = [arg.expr for arg in args]
         return Expr(self._udf.__call__(*args_raw))
 
-    @staticmethod
-    def udf(
-        func: Callable[..., _R],
-        input_types: list[pyarrow.DataType],
-        return_type: _R,
-        volatility: Volatility | str,
-        name: Optional[str] = None,
-    ) -> ScalarUDF:
-        """Create a new User-Defined Function.
+    class udf:
+        """Create a new User-Defined Function (UDF).
+
+        This class can be used both as a **function** and as a **decorator**.
+
+        Usage:
+        - **As a function**: Call `udf(func, input_types, return_type, volatility,
+          name)`.
+        - **As a decorator**: Use `@udf(input_types, return_type, volatility,
+          name)`. In this case, do **not** pass `func` explicitly.
 
         Args:
-            func: A callable python function.
-            input_types: The data types of the arguments to ``func``. This list
-                must be of the same length as the number of arguments.
-            return_type: The data type of the return value from the python
-                function.
-            volatility: See ``Volatility`` for allowed values.
-            name: A descriptive name for the function.
+            func (Callable, optional): **Only needed when calling as a function.**
+                Skip this argument when using `udf` as a decorator.
+            input_types (list[pyarrow.DataType]): The data types of the arguments
+                to `func`. This list must be of the same length as the number of
+                arguments.
+            return_type (_R): The data type of the return value from the function.
+            volatility (Volatility | str): See `Volatility` for allowed values.
+            name (Optional[str]): A descriptive name for the function.
 
         Returns:
-            A user-defined aggregate function, which can be used in either data
-            aggregation or window function calls.
+            A user-defined function that can be used in SQL expressions,
+            data aggregation, or window function calls.
+
+        Example:
+            **Using `udf` as a function:**
+            ```
+            def double_func(x):
+                return x * 2
+            double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(),
+                "volatile", "double_it")
+            ```
+
+            **Using `udf` as a decorator:**
+            ```
+            @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it")
+            def double_udf(x):
+                return x * 2
+            ```
         """
-        if not callable(func):
-            raise TypeError("`func` argument must be callable")
-        if name is None:
-            if hasattr(func, "__qualname__"):
-                name = func.__qualname__.lower()
+
+        def __new__(cls, *args, **kwargs):
+            """Create a new UDF.
+
+            Trigger UDF function or decorator depending on if the first args is callable
+            """
+            if args and callable(args[0]):
+                # Case 1: Used as a function, require the first parameter to be callable
+                return cls._function(*args, **kwargs)
             else:
-                name = func.__class__.__name__.lower()
-        return ScalarUDF(
-            name=name,
-            func=func,
-            input_types=input_types,
-            return_type=return_type,
-            volatility=volatility,
-        )
+                # Case 2: Used as a decorator with parameters
+                return cls._decorator(*args, **kwargs)
+
+        @staticmethod
+        def _function(
+            func: Callable[..., _R],
+            input_types: list[pyarrow.DataType],
+            return_type: _R,
+            volatility: Volatility | str,
+            name: Optional[str] = None,
+        ) -> ScalarUDF:
+            if not callable(func):
+                raise TypeError("`func` argument must be callable")
+            if name is None:
+                if hasattr(func, "__qualname__"):
+                    name = func.__qualname__.lower()
+                else:
+                    name = func.__class__.__name__.lower()
+            return ScalarUDF(
+                name=name,
+                func=func,
+                input_types=input_types,
+                return_type=return_type,
+                volatility=volatility,
+            )
+
+        @staticmethod
+        def _decorator(
+            input_types: list[pyarrow.DataType],
+            return_type: _R,
+            volatility: Volatility | str,
+            name: Optional[str] = None,
+        ):
+            def decorator(func):
+                udf_caller = ScalarUDF.udf(
+                    func, input_types, return_type, volatility, name
+                )
+
+                @functools.wraps(func)
+                def wrapper(*args, **kwargs):
+                    return udf_caller(*args, **kwargs)
+
+                return wrapper
+
+            return decorator
 
 
 class Accumulator(metaclass=ABCMeta):
@@ -212,25 +272,27 @@ def __call__(self, *args: Expr) -> Expr:
         args_raw = [arg.expr for arg in args]
         return Expr(self._udaf.__call__(*args_raw))
 
-    @staticmethod
-    def udaf(
-        accum: Callable[[], Accumulator],
-        input_types: pyarrow.DataType | list[pyarrow.DataType],
-        return_type: pyarrow.DataType,
-        state_type: list[pyarrow.DataType],
-        volatility: Volatility | str,
-        name: Optional[str] = None,
-    ) -> AggregateUDF:
-        """Create a new User-Defined Aggregate Function.
+    class udaf:
+        """Create a new User-Defined Aggregate Function (UDAF).
 
-        If your :py:class:`Accumulator` can be instantiated with no arguments, you
-        can simply pass it's type as ``accum``. If you need to pass additional arguments
-        to it's constructor, you can define a lambda or a factory method. During runtime
-        the :py:class:`Accumulator` will be constructed for every instance in
-        which this UDAF is used. The following examples are all valid.
+        This class allows you to define an **aggregate function** that can be used in
+        data aggregation or window function calls.
 
-        .. code-block:: python
+        Usage:
+        - **As a function**: Call `udaf(accum, input_types, return_type, state_type,
+          volatility, name)`.
+        - **As a decorator**: Use `@udaf(input_types, return_type, state_type,
+          volatility, name)`.
+          When using `udaf` as a decorator, **do not pass `accum` explicitly**.
 
+        **Function example:**
+
+        If your `:py:class:Accumulator` can be instantiated with no arguments, you
+        can simply pass it's type as `accum`. If you need to pass additional
+        arguments to it's constructor, you can define a lambda or a factory method.
+        During runtime the `:py:class:Accumulator` will be constructed for every
+        instance in which this UDAF is used. The following examples are all valid.
+        ```
         import pyarrow as pa
         import pyarrow.compute as pc
 
@@ -253,12 +315,24 @@ def evaluate(self) -> pa.Scalar:
         def sum_bias_10() -> Summarize:
             return Summarize(10.0)
 
-        udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
-        udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
-        udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
+        udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()],
+            "immutable")
+        udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()],
+            "immutable")
+        udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(),
+            [pa.float64()], "immutable")
+        ```
+
+        **Decorator example:**
+        ```
+        @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
+        def udf4() -> Summarize:
+            return Summarize(10.0)
+        ```
 
         Args:
-            accum: The accumulator python function.
+            accum: The accumulator python function. **Only needed when calling as a
+                function. Skip this argument when using `udaf` as a decorator.**
             input_types: The data types of the arguments to ``accum``.
             return_type: The data type of the return value.
             state_type: The data types of the intermediate accumulation.
@@ -268,26 +342,69 @@ def sum_bias_10() -> Summarize:
         Returns:
             A user-defined aggregate function, which can be used in either data
             aggregation or window function calls.
-        """  # noqa W505
-        if not callable(accum):
-            raise TypeError("`func` must be callable.")
-        if not isinstance(accum.__call__(), Accumulator):
-            raise TypeError(
-                "Accumulator must implement the abstract base class Accumulator"
+        """
+
+        def __new__(cls, *args, **kwargs):
+            """Create a new UDAF.
+
+            Trigger UDAF function or decorator depending on if the first args is
+            callable
+            """
+            if args and callable(args[0]):
+                # Case 1: Used as a function, require the first parameter to be callable
+                return cls._function(*args, **kwargs)
+            else:
+                # Case 2: Used as a decorator with parameters
+                return cls._decorator(*args, **kwargs)
+
+        @staticmethod
+        def _function(
+            accum: Callable[[], Accumulator],
+            input_types: pyarrow.DataType | list[pyarrow.DataType],
+            return_type: pyarrow.DataType,
+            state_type: list[pyarrow.DataType],
+            volatility: Volatility | str,
+            name: Optional[str] = None,
+        ) -> AggregateUDF:
+            if not callable(accum):
+                raise TypeError("`func` must be callable.")
+            if not isinstance(accum.__call__(), Accumulator):
+                raise TypeError(
+                    "Accumulator must implement the abstract base class Accumulator"
+                )
+            if name is None:
+                name = accum.__call__().__class__.__qualname__.lower()
+            if isinstance(input_types, pyarrow.DataType):
+                input_types = [input_types]
+            return AggregateUDF(
+                name=name,
+                accumulator=accum,
+                input_types=input_types,
+                return_type=return_type,
+                state_type=state_type,
+                volatility=volatility,
             )
-        if name is None:
-            name = accum.__call__().__class__.__qualname__.lower()
-        assert name is not None
-        if isinstance(input_types, pyarrow.DataType):
-            input_types = [input_types]
-        return AggregateUDF(
-            name=name,
-            accumulator=accum,
-            input_types=input_types,
-            return_type=return_type,
-            state_type=state_type,
-            volatility=volatility,
-        )
+
+        @staticmethod
+        def _decorator(
+            input_types: pyarrow.DataType | list[pyarrow.DataType],
+            return_type: pyarrow.DataType,
+            state_type: list[pyarrow.DataType],
+            volatility: Volatility | str,
+            name: Optional[str] = None,
+        ):
+            def decorator(accum: Callable[[], Accumulator]):
+                udaf_caller = AggregateUDF.udaf(
+                    accum, input_types, return_type, state_type, volatility, name
+                )
+
+                @functools.wraps(accum)
+                def wrapper(*args, **kwargs):
+                    return udaf_caller(*args, **kwargs)
+
+                return wrapper
+
+            return decorator
 
 
 class WindowEvaluator(metaclass=ABCMeta):
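Because the decorator path wraps the constructed ScalarUDF with `functools.wraps` and forwards calls to it, a decorated function can be applied to column expressions exactly like the object returned by the function form. A rough usage sketch follows; the `SessionContext`, `from_pydict`, `column`, and `show` calls are assumptions about the surrounding datafusion-python API and are not part of this diff.

```
import pyarrow as pa
import pyarrow.compute as pc

from datafusion import SessionContext, column, udf


@udf([pa.int64()], pa.int64(), "immutable", "double_it")
def double_it(values: pa.Array) -> pa.Array:
    # The UDF body receives one pyarrow Array per input column.
    return pc.multiply(values, 2)


ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]}, name="t")
df.select(column("a"), double_it(column("a"))).show()
```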

python/tests/test_udaf.py

Lines changed: 42 additions & 0 deletions
@@ -117,6 +117,26 @@ def test_udaf_aggregate(df):
     assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
 
 
+def test_udaf_decorator_aggregate(df):
+    @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
+    def summarize():
+        return Summarize()
+
+    df1 = df.aggregate([], [summarize(column("a"))])
+
+    # execute and collect the first (and only) batch
+    result = df1.collect()[0]
+
+    assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
+
+    df2 = df.aggregate([], [summarize(column("a"))])
+
+    # Run a second time to ensure the state is properly reset
+    result = df2.collect()[0]
+
+    assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
+
+
 def test_udaf_aggregate_with_arguments(df):
     bias = 10.0
 
@@ -143,6 +163,28 @@ def test_udaf_aggregate_with_arguments(df):
     assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
 
 
+def test_udaf_decorator_aggregate_with_arguments(df):
+    bias = 10.0
+
+    @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
+    def summarize():
+        return Summarize(bias)
+
+    df1 = df.aggregate([], [summarize(column("a"))])
+
+    # execute and collect the first (and only) batch
+    result = df1.collect()[0]
+
+    assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
+
+    df2 = df.aggregate([], [summarize(column("a"))])
+
+    # Run a second time to ensure the state is properly reset
+    result = df2.collect()[0]
+
+    assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
+
+
 def test_group_by(df):
     summarize = udaf(
         Summarize,