8000 feat: Introduce create_udwf method for User-Defined Window Functions · kosiew/datafusion-python@b708133 · GitHub
[go: up one dir, main page]

Skip to content

Commit b708133

Browse files
committed
feat: Introduce create_udwf method for User-Defined Window Functions
- Added `create_udwf` static method to `WindowUDF` class, allowing users to create User-Defined Window Functions (UDWF) as both a function and a decorator. - Updated type hinting for `_R` using `TypeAlias` for better clarity. - Enhanced documentation with usage examples for both function and decorator styles, improving usability and understanding.
1 parent b194a87 commit b708133

File tree

1 file changed

+99
-3
lines changed

1 file changed

+99
-3
lines changed

python/datafusion/udf.py

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@
2222
import functools
2323
from abc import ABCMeta, abstractmethod
2424
from enum import Enum
25-
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload
25+
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, overload
2626

2727
import pyarrow as pa
2828

2929
import datafusion._internal as df_internal
3030
from datafusion.expr import Expr
3131

3232
if TYPE_CHECKING:
33-
_R = TypeVar("_R", bound=pa.DataType)
33+
from typing import TypeAlias
34+
35+
_R: TypeAlias = pa.DataType
3436

3537

3638
class Volatility(Enum):
@@ -684,9 +686,103 @@ def bias_10() -> BiasedNumbers:
684686
volatility=volatility,
685687
)
686688

689+
@staticmethod
690+
def create_udwf(
691+
*args: Any, **kwargs: Any
692+
) -> Union[WindowUDF, Callable[[Callable[[], WindowEvaluator]], WindowUDF]]:
693+
"""Create a new User-Defined Window Function (UDWF).
694+
695+
This class can be used both as a **function** and as a **decorator**.
696+
697+
Usage:
698+
- **As a function**: Call `udwf(func, input_types, return_type, volatility, name)`.
699+
- **As a decorator**: Use `@udwf(input_types, return_type, volatility, name)`.
700+
When using `udwf` as a decorator, **do not pass `func` explicitly**.
701+
702+
**Function example:**
703+
```
704+
import pyarrow as pa
705+
706+
class BiasedNumbers(WindowEvaluator):
707+
def __init__(self, start: int = 0) -> None:
708+
self.start = start
709+
710+
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
711+
return pa.array([self.start + i for i in range(num_rows)])
712+
713+
def bias_10() -> BiasedNumbers:
714+
return BiasedNumbers(10)
715+
716+
udwf1 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
717+
```
718+
719+
**Decorator example:**
720+
```
721+
@udwf(pa.int64(), pa.int64(), "immutable")
722+
def biased_numbers() -> BiasedNumbers:
723+
return BiasedNumbers(10)
724+
```
725+
726+
Args:
727+
func: **Only needed when calling as a function. Skip this argument when using
728+
`udwf` as a decorator.**
729+
input_types: The data types of the arguments.
730+
return_type: The data type of the return value.
731+
volatility: See :py:class:`Volatility` for allowed values.
732+
name: A descriptive name for the function.
733+
734+
Returns:
735+
A user-defined window function that can be used in window function calls.
736+
"""
737+
738+
def _function(
739+
func: Callable[[], WindowEvaluator],
740+
input_types: pa.DataType | list[pa.DataType],
741+
return_type: pa.DataType,
742+
volatility: Volatility | str,
743+
name: Optional[str] = None,
744+
) -> WindowUDF:
745+
if not callable(func):
746+
msg = "`func` argument must be callable"
747+
raise TypeError(msg)
748+
if not isinstance(func(), WindowEvaluator):
749+
msg = "`func` must implement the abstract base class WindowEvaluator"
750+
raise TypeError(msg)
751+
if name is None:
752+
if hasattr(func, "__qualname__"):
753+
name = func.__qualname__.lower()
754+
else:
755+
name = func.__class__.__name__.lower()
756+
if isinstance(input_types, pa.DataType):
757+
input_types = [input_types]
758+
return WindowUDF(
759+
name=name,
760+
func=func,
761+
input_types=input_types,
762+
return_type=return_type,
763+
volatility=volatility,
764+
)
765+
766+
def _decorator(
767+
input_types: pa.DataType | list[pa.DataType],
768+
return_type: pa.DataType,
769+
volatility: Volatility | str,
770+
name: Optional[str] = None,
771+
) -> Callable[[Callable[[], WindowEvaluator]], WindowUDF]:
772+
def decorator(func: Callable[[], WindowEvaluator]) -> WindowUDF:
773+
return _function(func, input_types, return_type, volatility, name)
774+
775+
return decorator
776+
777+
if args and callable(args[0]):
778+
# Case 1: Used as a function, require the first parameter to be callable
779+
return _function(*args, **kwargs)
780+
# Case 2: Used as a decorator with parameters
781+
return _decorator(*args, **kwargs)
782+
687783

688784
# Convenience exports so we can import instead of treating as
689785
# variables at the package root
690786
udf = ScalarUDF.udf
691787
udaf = AggregateUDF.udaf
692-
udwf = WindowUDF.udwf
788+
udwf = WindowUDF.create_udwf

0 commit comments

Comments
 (0)
0