8000 Add test for passing polars dataframes into from_estimator · scikit-learn/scikit-learn@bcbc83a · GitHub
[go: up one dir, main page]

Skip to content

Commit bcbc83a

Browse files
committed
Add test for passing polars dataframes into from_estimator
1 parent 9a4b37c commit bcbc83a

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

sklearn/inspection/_plot/tests/test_boundary_decision_display.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,8 +468,8 @@ def test_string_target(pyplot):
468468
)
469469

470470

471-
def test_dataframe_support(pyplot):
472-
"""Check that passing a dataframe at fit and to the Display does not
471+
def test_pandas_dataframe_support(pyplot):
472+
"""Check that passing a pandas dataframe at fit and to the Display does not
473473
raise warnings.
474474
475475
Non-regression test for:
@@ -485,6 +485,23 @@ def test_dataframe_support(pyplot):
485485
DecisionBoundaryDisplay.from_estimator(estimator, df, response_method="predict")
486486

487487

488+
def test_polars_dataframe_support(pyplot):
489+
"""Check that passing a polars dataframe at fit and to the Display does not
490+
raise warnings.
491+
492+
Non-regression test for:
493+
https://github.com/scikit-learn/scikit-learn/issues/23311
494+
"""
495+
pl = pytest.importorskip("polars")
496+
df = pl.DataFrame({"col_x": X[:, 0], "col_y": X[:, 1]})
497+
estimator = LogisticRegression().fit(df, y)
498+
499+
with warnings.catch_warnings():
500+
# no warnings linked to feature names validation should be raised
501+
warnings.simplefilter("error", UserWarning)
502+
DecisionBoundaryDisplay.from_estimator(estimator, df, response_method="predict")
503+
504+
488505
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
489506
def test_class_of_interest_binary(pyplot, response_method):
490507
"""Check the behaviour of passing `class_of_interest` for plotting the output of

0 commit comments

Comments
 (0)
0