19
19
20
20
from __future__ import annotations
21
21
22
+ import functools
22
23
from abc import ABCMeta , abstractmethod
23
24
from enum import Enum
24
25
from typing import TYPE_CHECKING , Callable , List , Optional , TypeVar
@@ -110,43 +111,102 @@ def __call__(self, *args: Expr) -> Expr:
110
111
args_raw = [arg .expr for arg in args ]
111
112
return Expr (self ._udf .__call__ (* args_raw ))
112
113
113
- @staticmethod
114
- def udf (
115
- func : Callable [..., _R ],
116
- input_types : list [pyarrow .DataType ],
117
- return_type : _R ,
118
- volatility : Volatility | str ,
119
- name : Optional [str ] = None ,
120
- ) -> ScalarUDF :
121
- """Create a new User-Defined Function.
114
+ class udf :
115
+ """Create a new User-Defined Function (UDF).
116
+
117
+ This class can be used both as a **function** and as a **decorator**.
118
+
119
+ Usage:
120
+ - **As a function**: Call `udf(func, input_types, return_type, volatility,
121
+ name)`.
122
+ - **As a decorator**: Use `@udf(input_types, return_type, volatility,
123
+ name)`. In this case, do **not** pass `func` explicitly.
122
124
123
125
Args:
124
- func: A callable python function.
125
- input_types: The data types of the arguments to ``func``. This list
126
- must be of the same length as the number of arguments.
127
- return_type: The data type of the return value from the python
128
- function.
129
- volatility: See ``Volatility`` for allowed values.
130
- name: A descriptive name for the function.
126
+ func (Callable, optional): **Only needed when calling as a function.**
127
+ Skip this argument when using `udf` as a decorator.
128
+ input_types (list[pyarrow.DataType]): The data types of the arguments
129
+ to `func`. This list must be of the same length as the number of
130
+ arguments.
131
+ return_type (_R): The data type of the return value from the function.
132
+ volatility (Volatility | str): See `Volatility` for allowed values.
133
+ name (Optional[str]): A descriptive name for the function.
131
134
132
135
Returns:
133
- A user-defined aggregate function, which can be used in either data
134
- aggregation or window function calls.
136
+ A user-defined function that can be used in SQL expressions,
137
+ data aggregation, or window function calls.
138
+
139
+ Example:
140
+ **Using `udf` as a function:**
141
+ ```
142
+ def double_func(x):
143
+ return x * 2
144
+ double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(),
145
+ "volatile", "double_it")
146
+ ```
147
+
148
+ **Using `udf` as a decorator:**
149
+ ```
150
+ @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it")
151
+ def double_udf(x):
152
+ return x * 2
153
+ ```
135
154
"""
136
- if not callable (func ):
137
- raise TypeError ("`func` argument must be callable" )
138
- if name is None :
139
- if hasattr (func , "__qualname__" ):
140
- name = func .__qualname__ .lower ()
155
+
156
+ def __new__ (cls , * args , ** kwargs ):
157
+ """Create a new UDF.
158
+
159
+ Trigger UDF function or decorator depending on if the first args is callable
160
+ """
161
+ if args and callable (args [0 ]):
162
+ # Case 1: Used as a function, require the first parameter to be callable
163
+ return cls ._function (* args , ** kwargs )
141
164
else :
142
- name = func .__class__ .__name__ .lower ()
143
- return ScalarUDF (
144
- name = name ,
145
- func = func ,
146
- input_types = input_types ,
147
- return_type = return_type ,
148
- volatility = volatility ,
149
- )
165
+ # Case 2: Used as a decorator with parameters
166
+ return cls ._decorator (* args , ** kwargs )
167
+
168
+ @staticmethod
169
+ def _function (
170
+ func : Callable [..., _R ],
171
+ input_types : list [pyarrow .DataType ],
172
+ return_type : _R ,
173
+ volatility : Volatility | str ,
174
+ name : Optional [str ] = None ,
175
+ ) -> ScalarUDF :
176
+ if not callable (func ):
177
+ raise TypeError ("`func` argument must be callable" )
178
+ if name is None :
179
+ if hasattr (func , "__qualname__" ):
180
+ name = func .__qualname__ .lower ()
181
+ else :
182
+ name = func .__class__ .__name__ .lower ()
183
+ return ScalarUDF (
184
+ name = name ,
185
+ func = func ,
186
+ input_types = input_types ,
187
+ return_type = return_type ,
188
+ volatility = volatility ,
189
+ )
190
+
191
+ @staticmethod
192
+ def _decorator (
193
+ input_types : list [pyarrow .DataType ],
194
+ return_type : _R ,
195
+ volatility : Volatility | str ,
196
+ name : Optional [str ] = None ,
197
+ ):
198
+ def decorator (func ):
199
+ udf_caller = ScalarUDF .udf (
200
+ func , input_types , return_type , volatility , name
201
+ )
202
+
203
+ @functools .wraps (func )
204
+ def wrapper (* args , ** kwargs ):
205
+ return udf_caller (* args , ** kwargs )
206
+
207
+ return wrapper
208
+
209
+ return decorator
150
210
151
211
152
212
class Accumulator (metaclass = ABCMeta ):
@@ -212,25 +272,27 @@ def __call__(self, *args: Expr) -> Expr:
212
272
args_raw = [arg .expr for arg in args ]
213
273
return Expr (self ._udaf .__call__ (* args_raw ))
214
274
215
- @staticmethod
216
- def udaf (
217
- accum : Callable [[], Accumulator ],
218
- input_types : pyarrow .DataType | list [pyarrow .DataType ],
219
- return_type : pyarrow .DataType ,
220
- state_type : list [pyarrow .DataType ],
221
- volatility : Volatility | str ,
222
- name : Optional [str ] = None ,
223
- ) -> AggregateUDF :
224
- """Create a new User-Defined Aggregate Function.
275
+ class udaf :
276
+ """Create a new User-Defined Aggregate Function (UDAF).
225
277
226
- If your :py:class:`Accumulator` can be instantiated with no arguments, you
227
- can simply pass it's type as ``accum``. If you need to pass additional arguments
228
- to it's constructor, you can define a lambda or a factory method. During runtime
229
- the :py:class:`Accumulator` will be constructed for every instance in
230
- which this UDAF is used. The following examples are all valid.
278
+ This class allows you to define an **aggregate function** that can be used in
279
+ data aggregation or window function calls.
231
280
232
- .. code-block:: python
281
+ Usage:
282
+ - **As a function**: Call `udaf(accum, input_types, return_type, state_type,
283
+ volatility, name)`.
284
+ - **As a decorator**: Use `@udaf(input_types, return_type, state_type,
285
+ volatility, name)`.
286
+ When using `udaf` as a decorator, **do not pass `accum` explicitly**.
233
287
288
+ **Function example:**
289
+
290
+ If your `:py:class:Accumulator` can be instantiated with no arguments, you
291
+ can simply pass it's type as `accum`. If you need to pass additional
292
+ arguments to it's constructor, you can define a lambda or a factory method.
293
+ During runtime the `:py:class:Accumulator` will be constructed for every
294
+ instance in which this UDAF is used. The following examples are all valid.
295
+ ```<
F987
/span>
234
296
import pyarrow as pa
235
297
import pyarrow.compute as pc
236
298
@@ -253,12 +315,24 @@ def evaluate(self) -> pa.Scalar:
253
315
def sum_bias_10() -> Summarize:
254
316
return Summarize(10.0)
255
317
256
- udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
257
- udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
258
- udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
318
+ udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()],
319
+ "immutable")
320
+ udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()],
321
+ "immutable")
322
+ udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(),
323
+ [pa.float64()], "immutable")
324
+ ```
325
+
326
+ **Decorator example:**
327
+ ```
328
+ @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
329
+ def udf4() -> Summarize:
330
+ return Summarize(10.0)
331
+ ```
259
332
260
333
Args:
261
- accum: The accumulator python function.
334
+ accum: The accumulator python function. **Only needed when calling as a
335
+ function. Skip this argument when using `udaf` as a decorator.**
262
336
input_types: The data types of the arguments to ``accum``.
263
337
return_type: The data type of the return value.
264
338
state_type: The data types of the intermediate accumulation.
@@ -268,26 +342,69 @@ def sum_bias_10() -> Summarize:
268
342
Returns:
269
343
A user-defined aggregate function, which can be used in either data
270
344
aggregation or window function calls.
271
- """ # noqa W505
272
- if not callable (accum ):
273
- raise TypeError ("`func` must be callable." )
274
- if not isinstance (accum .__call__ (), Accumulator ):
275
- raise TypeError (
276
- "Accumulator must implement the abstract base class Accumulator"
345
+ """
346
+
347
+ def __new__ (cls , * args , ** kwargs ):
348
+ """Create a new UDAF.
349
+
350
+ Trigger UDAF function or decorator depending on if the first args is
351
+ callable
352
+ """
353
+ if args and callable (args [0 ]):
354
+ # Case 1: Used as a function, require the first parameter to be callable
355
+ return cls ._function (* args , ** kwargs )
356
+ else :
357
+ # Case 2: Used as a decorator with parameters
358
+ return cls ._decorator (* args , ** kwargs )
359
+
360
+ @staticmethod
361
+ def _function (
362
+ accum : Callable [[], Accumulator ],
363
+ input_types : pyarrow .DataType | list [pyarrow .DataType ],
364
+ return_type : pyarrow .DataType ,
365
+ state_type : list [pyarrow .DataType ],
366
+ volatility : Volatility | str ,
367
+ name : Optional [str ] = None ,
368
+ ) -> AggregateUDF :
369
+ if not callable (accum ):
370
+ raise TypeError ("`func` must be callable." )
371
+ if not isinstance (accum .__call__ (), Accumulator ):
372
+ raise TypeError (
373
+ "Accumulator must implement the abstract base class Accumulator"
374
+ )
375
+ if name is None :
376
+ name = accum .__call__ ().__class__ .__qualname__ .lower ()
377
+ if isinstance (input_types , pyarrow .DataType ):
378
+ input_types = [input_types ]
379
+ return AggregateUDF (
380
+ name = name ,
381
+ accumulator = accum ,
382
+ input_types = input_types ,
383
+ return_type = return_type ,
384
+ state_type = state_type ,
385
+ volatility = volatility ,
277
386
)
278
- if name is None :
279
- name = accum .__call__ ().__class__ .__qualname__ .lower ()
280
- assert name is not None
281
- if isinstance (input_types , pyarrow .DataType ):
282
- input_types = [input_types ]
283
- return AggregateUDF (
284
- name = name ,
285
- accumulator = accum ,
286
- input_types = input_types ,
287
- return_type = return_type ,
288
- state_type = state_type ,
289
- volatility = volatility ,
290
- )
387
+
388
+ @staticmethod
389
+ def _decorator (
390
+ input_types : pyarrow .DataType | list [pyarrow .DataType ],
391
+ return_type : pyarrow .DataType ,
392
+ state_type : list [pyarrow .DataType ],
393
+ volatility : Volatility | str ,
394
+ name : Optional [str ] = None ,
395
+ ):
396
+ def decorator (accum : Callable [[], Accumulator ]):
397
+ udaf_caller = AggregateUDF .udaf (
398
+ accum , input_types , return_type , state_type , volatility , name
399
+ )
400
+
401
+ @functools .wraps (accum )
402
+ def wrapper (* args , ** kwargs ):
403
+ return udaf_caller (* args , ** kwargs )
404
+
405
+ return wrapper
406
+
407
+ return decorator
291
408
292
409
293
410
class WindowEvaluator (metaclass = ABCMeta ):
0 commit comments