@@ -468,8 +468,8 @@ def test_string_target(pyplot):
468
468
)
469
469
470
470
471
- def test_dataframe_support (pyplot ):
472
- """Check that passing a dataframe at fit and to the Display does not
471
+ def test_pandas_dataframe_support (pyplot ):
472
+ """Check that passing a pandas dataframe at fit and to the Display does not
473
473
raise warnings.
474
474
475
475
Non-regression test for:
@@ -485,6 +485,23 @@ def test_dataframe_support(pyplot):
485
485
DecisionBoundaryDisplay .from_estimator (estimator , df , response_method = "predict" )
486
486
487
487
488
+ def test_polars_dataframe_support (pyplot ):
489
+ """Check that passing a polars dataframe at fit and to the Display does not
490
+ raise warnings.
491
+
492
+ Non-regression test for:
493
+ https://github.com/scikit-learn/scikit-learn/issues/23311
494
+ """
495
+ pl = pytest .importorskip ("polars" )
496
+ df = pl .DataFrame ({"col_x" : X [:, 0 ], "col_y" : X [:, 1 ]})
497
+ estimator = LogisticRegression ().fit (df , y )
498
+
499
+ with warnings .catch_warnings ():
500
+ # no warnings linked to feature names validation should be raised
501
+ warnings .simplefilter ("error" , UserWarning )
502
+ DecisionBoundaryDisplay .from_estimator (estimator , df , response_method = "predict" )
503
+
504
+
488
505
@pytest .mark .parametrize ("response_method" , ["predict_proba" , "decision_function" ])
489
506
def test_class_of_interest_binary (pyplot , response_method ):
490
507
"""Check the behaviour of passing `class_of_interest` for plotting the output of
0 commit comments