FIX DecisionBoundaryPlot should not raise spurious warning (#23318) · scikit-learn/scikit-learn@b0b8a39 · GitHub
[go: up one dir, main page]

Skip to content

Commit b0b8a39

Browse files
authored
FIX DecisionBoundaryPlot should not raise spurious warning (#23318)
1 parent 8fb8607 commit b0b8a39

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

sklearn/inspection/_plot/decision_boundary.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,16 @@ def from_estimator(
294294
np.linspace(x0_min, x0_max, grid_resolution),
295295
np.linspace(x1_min, x1_max, grid_resolution),
296296
)
297+
if hasattr(X, "iloc"):
298+
# we need to preserve the feature names and therefore get an empty dataframe
299+
X_grid = X.iloc[[], :].copy()
300+
X_grid.iloc[:, 0] = xx0.ravel()
301+
X_grid.iloc[:, 1] = xx1.ravel()
302+
else:
303+
X_grid = np.c_[xx0.ravel(), xx1.ravel()]
297304

298305
pred_func = _check_boundary_response_method(estimator, response_method)
299-
response = pred_func(np.c_[xx0.ravel(), xx1.ravel()])
306+
response = pred_func(X_grid)
300307

301308
# convert classes predictions into integers
302309
if pred_func.__name__ == "predict" and hasattr(estimator, "classes_"):

sklearn/inspection/_plot/tests/test_boundary_decision_display.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import pytest
24
import numpy as np
35
from numpy.testing import assert_allclose
@@ -265,6 +267,11 @@ def test_multioutput_regressor_error(pyplot):
265267
DecisionBoundaryDisplay.from_estimator(tree, X)
266268

267269

270+
@pytest.mark.filterwarnings(
271+
# We expect to raise the following warning because the classifier is fit on a
272+
# NumPy array
273+
"ignore:X has feature names, but LogisticRegression was fitted without"
274+
)
268275
def test_dataframe_labels_used(pyplot, fitted_clf):
269276
"""Check that column names are used for pandas."""
270277
pd = pytest.importorskip("pandas")
@@ -319,3 +326,20 @@ def test_string_target(pyplot):
319326
grid_resolution=5,
320327
response_method="predict",
321328
)
329+
330+
331+
def test_dataframe_support():
332+
"""Check that passing a dataframe at fit and to the Display does not
333+
raise warnings.
334+
335+
Non-regression test for:
336+
https://github.com/scikit-learn/scikit-learn/issues/23311
337+
"""
338+
pd = pytest.importorskip("pandas")
339+
df = pd.DataFrame(X, columns=["col_x", "col_y"])
340+
estimator = LogisticRegression().fit(df, y)
341+
342+
with warnings.catch_warnings():
343+
# no warnings linked to feature names validation should be raised
344+
warnings.simplefilter("error", UserWarning)
345+
DecisionBoundaryDisplay.from_estimator(estimator, df, response_method="predict")

0 commit comments

Comments
 (0)
0