|
| 1 | +import warnings |
| 2 | + |
1 | 3 | import pytest
|
2 | 4 | import numpy as np
|
3 | 5 | from numpy.testing import assert_allclose
|
@@ -265,6 +267,11 @@ def test_multioutput_regressor_error(pyplot):
|
265 | 267 | DecisionBoundaryDisplay.from_estimator(tree, X)
|
266 | 268 |
|
267 | 269 |
|
| 270 | +@pytest.mark.filterwarnings( |
| 271 | + # We expect to raise the following warning because the classifier is fit on a |
| 272 | + # NumPy array |
| 273 | + "ignore:X has feature names, but LogisticRegression was fitted without" |
| 274 | +) |
268 | 275 | def test_dataframe_labels_used(pyplot, fitted_clf):
|
269 | 276 | """Check that column names are used for pandas."""
|
270 | 277 | pd = pytest.importorskip("pandas")
|
@@ -319,3 +326,20 @@ def test_string_target(pyplot):
|
319 | 326 | grid_resolution=5,
|
320 | 327 | response_method="predict",
|
321 | 328 | )
|
| 329 | + |
| 330 | + |
| 331 | +def test_dataframe_support(): |
| 332 | + """Check that passing a dataframe at fit and to the Display does not |
| 333 | + raise warnings. |
| 334 | +
|
| 335 | + Non-regression test for: |
| 336 | + https://github.com/scikit-learn/scikit-learn/issues/23311 |
| 337 | + """ |
| 338 | + pd = pytest.importorskip("pandas") |
| 339 | + df = pd.DataFrame(X, columns=["col_x", "col_y"]) |
| 340 | + estimator = LogisticRegression().fit(df, y) |
| 341 | + |
| 342 | + with warnings.catch_warnings(): |
| 343 | + # no warnings linked to feature names validation should be raised |
| 344 | + warnings.simplefilter("error", UserWarning) |
| 345 | + DecisionBoundaryDisplay.from_estimator(estimator, df, response_method="predict") |
0 commit comments