@@ -218,7 +218,7 @@ def test_udwf_errors(complex_window_df):
218
218
def test_udwf_errors_with_message ():
219
219
"""Test error cases for UDWF creation."""
220
220
with pytest .raises (
221
- TypeError , match = "`func` must implement the WindowEvaluator protocol "
221
+ TypeError , match = "`func` must implement the abstract base class WindowEvaluator "
222
222
):
223
223
udwf (
224
224
NotSubclassOfWindowEvaluator , pa .int64 (), pa .int64 (), volatility = "immutable"
@@ -466,51 +466,3 @@ def test_udwf_named_function(ctx, count_window_df):
466
466
FOLLOWING) FROM test_table"""
467
467
).collect ()[0 ]
468
468
assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
469
-
470
-
471
- def test_window_evaluator_protocol (count_window_df ):
472
- """Test that WindowEvaluator works as a Protocol without explicit inheritance."""
473
-
474
- # Define a class that implements the Protocol interface without inheriting
475
- class CounterWithoutInheritance :
476
- def __init__ (self , base : int = 0 ) -> None :
477
- self .base = base
478
-
479
- def evaluate_all (self , values : list [pa .Array ], num_rows : int ) -> pa .Array :
480
- return pa .array ([self .base + i for i in range (num_rows )])
481
-
482
- # Protocol methods with default implementations don't need to be defined
483
-
484
- # Create a UDWF using the class that doesn't inherit from WindowEvaluator
485
- protocol_counter = udwf (
486
- CounterWithoutInheritance , pa .int64 (), pa .int64 (), volatility = "immutable"
487
- )
488
-
489
- # Use the window function
490
- df = count_window_df .select (
491
- protocol_counter (column ("a" ))
492
- .window_frame (WindowFrame ("rows" , None , None ))
493
- .build ()
494
- .alias ("count" )
495
- )
496
-
497
- result = df .collect ()[0 ]
498
- assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
499
-
500
- # Also test with constructor args
501
- protocol_counter_with_args = udwf (
502
- lambda : CounterWithoutInheritance (10 ),
503
- pa .int64 (),
504
- pa .int64 (),
505
- volatility = "immutable" ,
506
- )
507
-
508
- df = count_window_df .select (
509
- protocol_counter_with_args (column ("a" ))
510
- .window_frame (WindowFrame ("rows" , None , None ))
511
- .build ()
512
- .alias ("count" )
513
- )
514
-
515
- result = df .collect ()[0 ]
516
- assert result .column (0 ) == pa .array ([10 , 11 , 12 ])
0 commit comments