23
23
from datafusion .expr import Expr
24
24
from typing import Callable , TYPE_CHECKING , TypeVar
25
25
from abc import ABCMeta , abstractmethod
26
- from typing import List
26
+ from typing import List , Optional
27
27
from enum import Enum
28
28
import pyarrow
29
29
@@ -84,16 +84,18 @@ class ScalarUDF:
84
84
85
85
def __init__ (
86
86
self ,
87
- name : str | None ,
87
+ name : Optional [ str ] ,
88
88
func : Callable [..., _R ],
89
- input_types : list [pyarrow .DataType ],
89
+ input_types : pyarrow . DataType | list [pyarrow .DataType ],
90
90
return_type : _R ,
91
91
volatility : Volatility | str ,
92
92
) -> None :
93
93
"""Instantiate a scalar user-defined function (UDF).
94
94
95
95
See helper method :py:func:`udf` for argument details.
96
96
"""
97
+ if isinstance (input_types , pyarrow .DataType ):
98
+ input_types = [input_types ]
97
99
self ._udf = df_internal .ScalarUDF (
98
100
name , func , input_types , return_type , str (volatility )
99
101
)
@@ -104,16 +106,16 @@ def __call__(self, *args: Expr) -> Expr:
104
106
This function is not typically called by an end user. These calls will
105
107
occur during the evaluation of the dataframe.
106
108
"""
107
- args = [arg .expr for arg in args ]
108
- return Expr (self ._udf .__call__ (* args ))
109
+ args_raw = [arg .expr for arg in args ]
110
+ return Expr (self ._udf .__call__ (* args_raw ))
109
111
110
112
@staticmethod
111
113
def udf (
112
114
func : Callable [..., _R ],
113
115
input_types : list [pyarrow .DataType ],
114
116
return_type : _R ,
115
117
volatility : Volatility | str ,
116
- name : str | None = None ,
118
+ name : Optional [ str ] = None ,
117
119
) -> ScalarUDF :
118
120
"""Create a new User-Defined Function.
119
121
@@ -133,7 +135,10 @@ def udf(
133
135
if not callable (func ):
134
136
raise TypeError ("`func` argument must be callable" )
135
137
if name is None :
136
- name = func .__qualname__ .lower ()
138
+ if hasattr (func , "__qualname__" ):
139
+ name = func .__qualname__ .lower ()
140
+ else :
141
+ name = func .__class__ .__name__ .lower ()
137
142
return ScalarUDF (
138
143
name = name ,
139
144
func = func ,
@@ -167,10 +172,6 @@ def evaluate(self) -> pyarrow.Scalar:
167
172
pass
168
173
169
174
170
- if TYPE_CHECKING :
171
- _A = TypeVar ("_A" , bound = (Callable [..., _R ], Accumulator ))
172
-
173
-
174
175
class AggregateUDF :
175
176
"""Class for performing aggregate user-defined functions (UDAF).
176
177
@@ -180,10 +181,10 @@ class AggregateUDF:
180
181
181
182
def __init__ (
182
183
self ,
183
- name : str | None ,
184
- accumulator : _A ,
184
+ name : Optional [ str ] ,
185
+ accumulator : Callable [[], Accumulator ] ,
185
186
input_types : list [pyarrow .DataType ],
186
- return_type : _R ,
187
+ return_type : pyarrow . DataType ,
187
188
state_type : list [pyarrow .DataType ],
188
189
volatility : Volatility | str ,
189
190
) -> None :
@@ -193,7 +194,12 @@ def __init__(
193
194
descriptions.
194
195
"""
195
196
self ._udaf = df_internal .AggregateUDF (
196
- name , accumulator , input_types , return_type , state_type , str (volatility )
197
+ name ,
198
+ accumulator ,
199
+ input_types ,
200
+ return_type ,
201
+ state_type ,
202
+ str (volatility ),
197
203
)
198
204
199
205
def __call__ (self , * args : Expr ) -> Expr :
@@ -202,21 +208,52 @@ def __call__(self, *args: Expr) -> Expr:
202
208
This function is not typically called by an end user. These calls will
203
209
occur during the evaluation of the dataframe.
204
210
"""
205
- args = [arg .expr for arg in args ]
206
- return Expr (self ._udaf .__call__ (* args ))
211
+ args_raw = [arg .expr for arg in args ]
212
+ return Expr (self ._udaf .__call__ (* args_raw ))
207
213
208
214
@staticmethod
209
215
def udaf (
210
- accum : _A ,
211
- input_types : list [pyarrow .DataType ],
212
- return_type : _R ,
216
+ accum : Callable [[], Accumulator ] ,
217
+ input_types : pyarrow . DataType | list [pyarrow .DataType ],
218
+ return_type : pyarrow . DataType ,
213
219
state_type : list [pyarrow .DataType ],
214
220
volatility : Volatility | str ,
215
- name : str | None = None ,
221
+ name : Optional [ str ] = None ,
216
222
) -> AggregateUDF :
217
223
"""Create a new User-Defined Aggregate Function.
218
224
219
- The accumulator function must be callable and implement :py:class:`Accumulator`.
225
+ If your :py:class:`Accumulator` can be instantiated with no arguments, you
226
+ can simply pass its type as ``accum``. If you need to pass additional arguments
227
+ to its constructor, you can define a lambda or a factory method. During runtime
228
+ the :py:class:`Accumulator` will be constructed for every instance in
229
+ which this UDAF is used. The following examples are all valid.
230
+
231
+ .. code-block:: python
232
+ import pyarrow as pa
233
+ import pyarrow.compute as pc
234
+
235
+ class Summarize(Accumulator):
236
+ def __init__(self, bias: float = 0.0):
237
+ self._sum = pa.scalar(bias)
238
+
239
+ def state(self) -> List[pa.Scalar]:
240
+ return [self._sum]
241
+
242
+ def update(self, values: pa.Array) -> None:
243
+ self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
244
+
245
+ def merge(self, states: List[pa.Array]) -> None:
246
+ self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
247
+
248
+ def evaluate(self) -> pa.Scalar:
249
+ return self._sum
250
+
251
+ def sum_bias_10() -> Summarize:
252
+ return Summarize(10.0)
253
+
254
+ udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
255
+ udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
256
+ udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
220
257
221
258
Args:
222
259
accum: The accumulator python function.
@@ -229,14 +266,16 @@ def udaf(
229
266
Returns:
230
267
A user-defined aggregate function, which can be used in either data
231
268
aggregation or window function calls.
232
- """
233
- if not issubclass (accum , Accumulator ):
269
+ """ # noqa W505
270
+ if not callable (accum ):
271
+ raise TypeError ("`func` must be callable." )
272
+ if not isinstance (accum .__call__ (), Accumulator ):
234
273
raise TypeError (
235
- "`accum` must implement the abstract base class Accumulator"
274
+ "Accumulator must implement the abstract base class Accumulator"
236
275
)
237
276
if name is None :
238
- name = accum .__qualname__ .lower ()
239
- if isinstance (input_types , pyarrow .lib . DataType ):
277
+ name = accum .__call__ (). __class__ . __qualname__ .lower ()
278
+ if isinstance (input_types , pyarrow .DataType ):
240
279
input_types = [input_types ]
241
280
return AggregateUDF (
242
281
name = name ,
@@ -421,8 +460,8 @@ class WindowUDF:
421
460
422
461
def __init__ (
423
462
self ,
424
- name : str | None ,
425
- func : WindowEvaluator ,
463
+ name : Optional [ str ] ,
464
+ func : Callable [[], WindowEvaluator ] ,
426
465
input_types : list [pyarrow .DataType ],
427
466
return_type : pyarrow .DataType ,
428
467
volatility : Volatility | str ,
@@ -447,30 +486,57 @@ def __call__(self, *args: Expr) -> Expr:
447
486
448
487
@staticmethod
449
488
def udwf (
450
- func : WindowEvaluator ,
489
+ func : Callable [[], WindowEvaluator ] ,
451
490
input_types : pyarrow .DataType | list [pyarrow .DataType ],
452
491
return_type : pyarrow .DataType ,
453
492
volatility : Volatility | str ,
454
- name : str | None = None ,
493
+ name : Optional [ str ] = None ,
455
494
) -> WindowUDF :
456
495
"""Create a new User-Defined Window Function.
457
496
497
+ If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
498
+ can simply pass its type as ``func``. If you need to pass additional arguments
499
+ to its constructor, you can define a lambda or a factory method. During runtime
500
+ the :py:class:`WindowEvaluator` will be constructed for every instance in
501
+ which this UDWF is used. The following examples are all valid.
502
+
503
+ .. code-block:: python
504
+
505
+ import pyarrow as pa
506
+
507
+ class BiasedNumbers(WindowEvaluator):
508
+ def __init__(self, start: int = 0) -> None:
509
+ self.start = start
510
+
511
+ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
512
+ return pa.array([self.start + i for i in range(num_rows)])
513
+
514
+ def bias_10() -> BiasedNumbers:
515
+ return BiasedNumbers(10)
516
+
517
+ udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
518
+ udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
519
+ udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
520
+
458
521
Args:
459
- func: The python function.
522
+ func: A callable to create the window function.
460
523
input_types: The data types of the arguments to ``func``.
461
524
return_type: The data type of the return value.
462
525
volatility: See :py:class:`Volatility` for allowed values.
526
+ arguments: A list of arguments to pass in to the __init__ method for accum.
463
527
name: A descriptive name for the function.
464
528
465
529
Returns:
466
530
A user-defined window function.
467
- """
468
- if not isinstance (func , WindowEvaluator ):
531
+ """ # noqa W505
532
+ if not callable (func ):
533
+ raise TypeError ("`func` must be callable." )
534
+ if not isinstance (func .__call__ (), WindowEvaluator ):
469
535
raise TypeError (
470
536
"`func` must implement the abstract base class WindowEvaluator"
471
537
)
472
538
if name is None :
473
- name = func .__class__ .__qualname__ .lower ()
539
+ name = func .__call__ (). __class__ .__qualname__ .lower ()
474
540
if isinstance (input_types , pyarrow .DataType ):
475
541
input_types = [input_types ]
476
542
return WindowUDF (
0 commit comments