@@ -218,7 +218,7 @@ def test_udwf_errors(complex_window_df):
218
218
def test_udwf_errors_with_message ():
219
219
"""Test error cases for UDWF creation."""
220
220
with pytest .raises (
221
- TypeError , match = "`func` must implement the abstract base class WindowEvaluator"
221
+ TypeError , match = "`func` must implement the WindowEvaluator protocol "
222
222
):
223
223
udwf (
224
224
NotSubclassOfWindowEvaluator , pa .int64 (), pa .int64 (), volatility = "immutable"
@@ -466,3 +466,51 @@ def test_udwf_named_function(ctx, count_window_df):
466
466
FOLLOWING) FROM test_table"""
467
467
).collect ()[0 ]
468
468
assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
469
+
470
+
471
def test_window_evaluator_protocol(count_window_df):
    """Check that a class can back a UDWF without inheriting from WindowEvaluator.

    The evaluator below only defines ``evaluate_all``; it is registered twice,
    once by passing the class itself and once through a zero-argument factory
    that supplies a constructor argument.
    """

    # Implements the evaluator interface structurally, with no base class.
    class CounterWithoutInheritance:
        def __init__(self, base: int = 0) -> None:
            self.base = base

        def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
            # One output row per input row, counting upwards from ``base``.
            return pa.array(list(range(self.base, self.base + num_rows)))

        # Protocol methods with default implementations don't need to be defined

    def first_batch_counts(window_fn):
        """Apply ``window_fn`` over column ``a`` and return the first batch's column."""
        windowed = (
            window_fn(column("a"))
            .window_frame(WindowFrame("rows", None, None))
            .build()
            .alias("count")
        )
        return count_window_df.select(windowed).collect()[0].column(0)

    # Register the class directly, relying on its default constructor argument.
    default_counter = udwf(
        CounterWithoutInheritance, pa.int64(), pa.int64(), volatility="immutable"
    )
    assert first_batch_counts(default_counter) == pa.array([0, 1, 2])

    # Register through a factory so a constructor argument can be supplied.
    offset_counter = udwf(
        lambda: CounterWithoutInheritance(10),
        pa.int64(),
        pa.int64(),
        volatility="immutable",
    )
    assert first_batch_counts(offset_counter) == pa.array([10, 11, 12])
0 commit comments