8000 Revert "refactor: Enhance typing in udf.py by introducing Protocol fo… · kosiew/datafusion-python@78c0203 · GitHub
[go: up one dir, main page]

Skip to content

Commit 78c0203

Browse files
committed
Revert "refactor: Enhance typing in udf.py by introducing Protocol for WindowEvaluator and improving import organization"
This reverts commit 16dbe5f.
1 parent 16dbe5f commit 78c0203

File tree

2 files changed

+5
-63
lines changed

2 files changed

+5
-63
lines changed

python/datafusion/udf.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,7 @@
2222
import functools
2323
from abc import ABCMeta, abstractmethod
2424
from enum import Enum
25-
from typing import (
26-
TYPE_CHECKING,
27-
Any,
28-
Callable,
29-
Optional,
30-
Protocol,
31-
runtime_checkable,
32-
overload,
33-
TypeVar,
34-
)
25+
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload
3526

3627
import pyarrow as pa
3728

@@ -438,9 +429,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
438429
return _decorator(*args, **kwargs)
439430

440431

441-
@runtime_checkable
442-
class WindowEvaluator(Protocol):
443-
"""Protocol defining interface for user-defined window functions (UDWF).
432+
class WindowEvaluator:
433+
"""Evaluator class for user-defined window functions (UDWF).
444434
445435
It is up to the user to decide which evaluate function is appropriate.
446436
@@ -721,7 +711,7 @@ def _create_window_udf(
721711
msg = "`func` must be callable."
722712
raise TypeError(msg)
723713
if not isinstance(func(), WindowEvaluator):
724-
msg = "`func` must implement the WindowEvaluator protocol"
714+
msg = "`func` must implement the abstract base class WindowEvaluator"
725715
raise TypeError(msg)
726716

727717
name = name or func.__qualname__.lower()

python/tests/test_udwf.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def test_udwf_errors(complex_window_df):
218218
def test_udwf_errors_with_message():
219219
"""Test error cases for UDWF creation."""
220220
with pytest.raises(
221-
TypeError, match="`func` must implement the WindowEvaluator protocol"
221+
TypeError, match="`func` must implement the abstract base class WindowEvaluator"
222222
):
223223
udwf(
224224
NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable"
@@ -466,51 +466,3 @@ def test_udwf_named_function(ctx, count_window_df):
466466
FOLLOWING) FROM test_table"""
467467
).collect()[0]
468468
assert result.column(0) == pa.array([0, 1, 2])
469-
470-
471-
def test_window_evaluator_protocol(count_window_df):
472-
"""Test that WindowEvaluator works as a Protocol without explicit inheritance."""
473-
474-
# Define a class that implements the Protocol interface without inheriting
475-
class CounterWithoutInheritance:
476-
def __init__(self, base: int = 0) -> None:
477-
self.base = base
478-
479-
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
480-
return pa.array([self.base + i for i in range(num_rows)])
481-
482-
# Protocol methods with default implementations don't need to be defined
483-
484-
# Create a UDWF using the class that doesn't inherit from WindowEvaluator
485-
protocol_counter = udwf(
486-
CounterWithoutInheritance, pa.int64(), pa.int64(), volatility="immutable"
487-
)
488-
489-
# Use the window function
490-
df = count_window_df.select(
491-
protocol_counter(column("a"))
492-
.window_frame(WindowFrame("rows", None, None))
493-
.build()
494-
.alias("count")
495-
)
496-
497-
result = df.collect()[0]
498-
assert result.column(0) == pa.array([0, 1, 2])
499-
500-
# Also test with constructor args
501-
protocol_counter_with_args = udwf(
502-
lambda: CounterWithoutInheritance(10),
503-
pa.int64(),
504-
pa.int64(),
505-
volatility="immutable",
506-
)
507-
508-
df = count_window_df.select(
509-
protocol_counter_with_args(column("a"))
510-
.window_frame(WindowFrame("rows", None, None))
511-
.build()
512-
.alias("count")
513-
)
514-
515-
result = df.collect()[0]
516-
assert result.column(0) == pa.array([10, 11, 12])

0 commit comments

Comments
 (0)
0