|
22 | 22 | import functools
|
23 | 23 | from abc import ABCMeta, abstractmethod
|
24 | 24 | from enum import Enum
|
25 |
| -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload |
| 25 | +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, overload |
26 | 26 |
|
27 | 27 | import pyarrow as pa
|
28 | 28 |
|
29 | 29 | import datafusion._internal as df_internal
|
30 | 30 | from datafusion.expr import Expr
|
31 | 31 |
|
32 | 32 | if TYPE_CHECKING:
|
33 |
| - _R = TypeVar("_R", bound=pa.DataType) |
| 33 | + from typing import TypeAlias |
| 34 | + |
| 35 | + _R: TypeAlias = pa.DataType |
34 | 36 |
|
35 | 37 |
|
36 | 38 | class Volatility(Enum):
|
@@ -684,9 +686,103 @@ def bias_10() -> BiasedNumbers:
|
684 | 686 | volatility=volatility,
|
685 | 687 | )
|
686 | 688 |
|
| 689 | + @staticmethod |
| 690 | + def create_udwf( |
| 691 | + *args: Any, **kwargs: Any |
| 692 | + ) -> Union[WindowUDF, Callable[[Callable[[], WindowEvaluator]], WindowUDF]]: |
| 693 | + """Create a new User-Defined Window Function (UDWF). |
| 694 | +
|
| 695 | + This class can be used both as a **function** and as a **decorator**. |
| 696 | +
|
| 697 | + Usage: |
| 698 | + - **As a function**: Call `udwf(func, input_types, return_type, volatility, name)`. |
| 699 | + - **As a decorator**: Use `@udwf(input_types, return_type, volatility, name)`. |
| 700 | + When using `udwf` as a decorator, **do not pass `func` explicitly**. |
| 701 | +
|
| 702 | + **Function example:** |
| 703 | + ``` |
| 704 | + import pyarrow as pa |
| 705 | +
|
| 706 | + class BiasedNumbers(WindowEvaluator): |
| 707 | + def __init__(self, start: int = 0) -> None: |
| 708 | + self.start = start |
| 709 | +
|
| 710 | + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: |
| 711 | + return pa.array([self.start + i for i in range(num_rows)]) |
| 712 | +
|
| 713 | + def bias_10() -> BiasedNumbers: |
| 714 | + return BiasedNumbers(10) |
| 715 | +
|
| 716 | + udwf1 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") |
| 717 | + ``` |
| 718 | +
|
| 719 | + **Decorator example:** |
| 720 | + ``` |
| 721 | + @udwf(pa.int64(), pa.int64(), "immutable") |
| 722 | + def biased_numbers() -> BiasedNumbers: |
| 723 | + return BiasedNumbers(10) |
| 724 | + ``` |
| 725 | +
|
| 726 | + Args: |
| 727 | + func: **Only needed when calling as a function. Skip this argument when using |
| 728 | + `udwf` as a decorator.** |
| 729 | + input_types: The data types of the arguments. |
| 730 | + return_type: The data type of the return value. |
| 731 | + volatility: See :py:class:`Volatility` for allowed values. |
| 732 | + name: A descriptive name for the function. |
| 733 | +
|
| 734 | + Returns: |
| 735 | + A user-defined window function that can be used in window function calls. |
| 736 | + """ |
| 737 | + |
| 738 | + def _function( |
| 739 | + func: Callable[[], WindowEvaluator], |
| 740 | + input_types: pa.DataType | list[pa.DataType], |
| 741 | + return_type: pa.DataType, |
| 742 | + volatility: Volatility | str, |
| 743 | + name: Optional[str] = None, |
| 744 | + ) -> WindowUDF: |
| 745 | + if not callable(func): |
| 746 | + msg = "`func` argument must be callable" |
| 747 | + raise TypeError(msg) |
| 748 | + if not isinstance(func(), WindowEvaluator): |
| 749 | + msg = "`func` must implement the abstract base class WindowEvaluator" |
| 750 | + raise TypeError(msg) |
| 751 | + if name is None: |
| 752 | + if hasattr(func, "__qualname__"): |
| 753 | + name = func.__qualname__.lower() |
| 754 | + else: |
| 755 | + name = func.__class__.__name__.lower() |
| 756 | + if isinstance(input_types, pa.DataType): |
| 757 | + input_types = [input_types] |
| 758 | + return WindowUDF( |
| 759 | + name=name, |
| 760 | + func=func, |
| 761 | + input_types=input_types, |
| 762 | + return_type=return_type, |
| 763 | + volatility=volatility, |
| 764 | + ) |
| 765 | + |
| 766 | + def _decorator( |
| 767 | + input_types: pa.DataType | list[pa.DataType], |
| 768 | + return_type: pa.DataType, |
| 769 | + volatility: Volatility | str, |
| 770 | + name: Optional[str] = None, |
| 771 | + ) -> Callable[[Callable[[], WindowEvaluator]], WindowUDF]: |
| 772 | + def decorator(func: Callable[[], WindowEvaluator]) -> WindowUDF: |
| 773 | + return _function(func, input_types, return_type, volatility, name) |
| 774 | + |
| 775 | + return decorator |
| 776 | + |
| 777 | + if args and callable(args[0]): |
| 778 | + # Case 1: Used as a function, require the first parameter to be callable |
| 779 | + return _function(*args, **kwargs) |
| 780 | + # Case 2: Used as a decorator with parameters |
| 781 | + return _decorator(*args, **kwargs) |
| 782 | + |
687 | 783 |
|
688 | 784 | # Convenience exports so we can import instead of treating as
|
689 | 785 | # variables at the package root
|
690 | 786 | udf = ScalarUDF.udf
|
691 | 787 | udaf = AggregateUDF.udaf
|
692 |
| -udwf = WindowUDF.udwf |
| 788 | +udwf = WindowUDF.create_udwf |
0 commit comments