8000 FIX Array validation for DecisionBoundaryDisplay (#25077) · scikit-learn/scikit-learn@c047241 · GitHub
[go: up one dir, main page]

Skip to content

Commit c047241

Browse files
FIX Array validation for DecisionBoundaryDisplay (#25077)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 7af5297 commit c047241

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

doc/whats_new/v1.2.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,10 @@ Changelog
416416
:pr:`18298` by :user:`Madhura Jayaratne <madhuracj>` and
417417
:user:`Guillaume Lemaitre <glemaitre>`.
418418

419+
- |Fix| :class:`inspection.DecisionBoundaryDisplay` now raises error if input
420+
data is not 2-dimensional.
421+
:pr:`25077` by :user:`Arturo Amor <ArturoAmorQ>`.
422+
419423
:mod:`sklearn.kernel_approximation`
420424
...................................
421425

sklearn/inspection/_plot/decision_boundary.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
from ...utils import check_matplotlib_support
77
from ...utils import _safe_indexing
88
from ...base import is_regressor
9-
from ...utils.validation import check_is_fitted, _is_arraylike_not_scalar
9+
from ...utils.validation import (
10+
check_is_fitted,
11+
_is_arraylike_not_scalar,
12+
_num_features,
13+
)
1014

1115

1216
def _check_boundary_response_method(estimator, response_method):
@@ -316,6 +320,12 @@ def from_estimator(
316320
f"Got {plot_method} instead."
317321
)
318322

323+
num_features = _num_features(X)
324+
if num_features != 2:
325+
raise ValueError(
326+
f"n_features must be equal to 2. Got {num_features} instead."
327+
)
328+
319329
x0, x1 = _safe_indexing(X, 0, axis=1), _safe_indexing(X, 1, axis=1)
320330

321331
x0_min, x0_max = x0.min() - eps, x0.max() + eps

sklearn/inspection/_plot/tests/test_boundary_decision_display.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ def fitted_clf():
3838
return LogisticRegression().fit(X, y)
3939

4040

41+
def test_input_data_dimension():
42+
"""Check that we raise an error when `X` does not have exactly 2 features."""
43+
X, y = make_classification(n_samples=10, n_features=4, random_state=0)
44+
45+
clf = LogisticRegression().fit(X, y)
46+
msg = "n_features must be equal to 2. Got 4 instead."
47+
with pytest.raises(ValueError, match=msg):
48+
DecisionBoundaryDisplay.from_estimator(estimator=clf, X=X)
49+
50+
4151
def test_check_boundary_response_method_auto():
4252
"""Check _check_boundary_response_method behavior with 'auto'."""
4353

0 commit comments

Comments
 (0)
0