8000 refactor: Enhance typing in udf.py by introducing Protocol for Window… · kosiew/datafusion-python@16dbe5f · GitHub
[go: up one dir, main page]

Skip to content

Commit 16dbe5f

Browse files
committed
refactor: Enhance typing in udf.py by introducing Protocol for WindowEvaluator and improving import organization
1 parent 20d5dd9 commit 16dbe5f

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

python/datafusion/udf.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,16 @@
2222
import functools
2323
from abc import ABCMeta, abstractmethod
2424
from enum import Enum
25-
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload
25+
from typing import (
26+
TYPE_CHECKING,
27+
Any,
28+
Callable,
29+
Optional,
30+
Protocol,
31+
runtime_checkable,
32+
overload,
33+
TypeVar,
34+
)
2635

2736
import pyarrow as pa
2837

@@ -429,8 +438,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
429438
return _decorator(*args, **kwargs)
430439

431440

432-
class WindowEvaluator:
433-
"""Evaluator class for user-defined window functions (UDWF).
441+
@runtime_checkable
442+
class WindowEvaluator(Protocol):
443+
"""Protocol defining interface for user-defined window functions (UDWF).
434444
435445
It is up to the user to decide which evaluate function is appropriate.
436446
@@ -711,7 +721,7 @@ def _create_window_udf(
711721
msg = "`func` must be callable."
712722
raise TypeError(msg)
713723
if not isinstance(func(), WindowEvaluator):
714-
msg = "`func` must implement the abstract base class WindowEvaluator"
724+
msg = "`func` must implement the WindowEvaluator protocol"
715725
raise TypeError(msg)
716726

717727
name = name or func.__qualname__.lower()

python/tests/test_udwf.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def test_udwf_errors(complex_window_df):
218218
def test_udwf_errors_with_message():
219219
"""Test error cases for UDWF creation."""
220220
with pytest.raises(
221-
TypeError, match="`func` must implement the abstract base class WindowEvaluator"
221+
TypeError, match="`func` must implement the WindowEvaluator protocol"
222222
):
223223
udwf(
224224
NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable"
@@ -466,3 +466,51 @@ def test_udwf_named_function(ctx, count_window_df):
466466
FOLLOWING) FROM test_table"""
467467
).collect()[0]
468468
assert result.column(0) == pa.array([0, 1, 2])
469+
470+
471+
def test_window_evaluator_protocol(count_window_df):
472+
"""Test that WindowEvaluator works as a Protocol without explicit inheritance."""
473+
474+
# Define a class that implements the Protocol interface without inheriting
475+
class CounterWithoutInheritance:
476+
def __init__(self, base: int = 0) -> None:
477+
self.base = base
478+
479+
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
480+
return pa.array([self.base + i for i in range(num_rows)])
481+
482+
# Protocol methods with default implementations don't need to be defined
483+
484+
# Create a UDWF using the class that doesn't inherit from WindowEvaluator
485+
protocol_counter = udwf(
486+
CounterWithoutInheritance, pa.int64(), pa.int64(), volatility="immutable"
487+
)
488+
489+
# Use the window function
490+
df = count_window_df.select(
491+
protocol_counter(column("a"))
492+
.window_frame(WindowFrame("rows", None, None))
493+
.build()
494+
.alias("count")
495+
)
496+
497+
result = df.collect()[0]
498+
assert result.column(0) == pa.array([0, 1, 2])
499+
500+
# Also test with constructor args
501+
protocol_counter_with_args = udwf(
502+
lambda: CounterWithoutInheritance(10),
503+
pa.int64(),
504+
pa.int64(),
505+
volatility="immutable",
506+
)
507+
508+
df = count_window_df.select(
509+
protocol_counter_with_args(column("a"))
510+
.window_frame(WindowFrame("rows", None, None))
511+
.build()
512+
.alias("count")
513+
)
514+
515+
result = df.collect()[0]
516+
assert result.column(0) == pa.array([10, 11, 12])

0 commit comments

Comments
 (0)
0