8000 MNT Raise erorr when normalize is invalid in confusion_matrix (#15888) · scikit-learn/scikit-learn@4ad4cc6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4ad4cc6

Browse files
MNT Raise erorr when normalize is invalid in confusion_matrix (#15888)
1 parent 13134a8 commit 4ad4cc6

File tree

4 files changed

+18
-4
lines changed

4 files changed

+18
-4
lines changed

doc/whats_new/v0.22.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ This is a bug-fix release to primarily resolve some packaging issues in version
1515
Changelog
1616
---------
1717

18+
:mod:`sklearn.metrics`
19+
......................
20+
21+
- |Fix| :func:`metrics.plot_confusion_matrix` now raises error when `normalize`
22+
is invalid. Previously, it runs fine with no normalization.
23+
:pr:`15888` by `Hanmin Qin`_.
24+
1825
:mod:`sklearn.utils`
1926
....................
2027

sklearn/metrics/_classification.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,10 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None,
283283

284284
check_consistent_length(y_true, y_pred, sample_weight)
285285

286+
if normalize not in ['true', 'pred', 'all', None]:
287+
raise ValueError("normalize must be one of {'true', 'pred', "
288+
"'all', None}")
8000 289+
286290
n_labels = labels.size
287291
label_to_ind = {y: x for x, y in enumerate(labels)}
288292
# convert yt, yp into index

sklearn/metrics/_plot/confusion_matrix.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,6 @@ def plot_confusion_matrix(estimator, X, y_true, labels=None,
184184
if not is_classifier(estimator):
185185
raise ValueError("plot_confusion_matrix only supports classifiers")
186186

187-
if normalize not in {'true', 'pred', 'all', None}:
188-
raise ValueError("normalize must be one of {'true', 'pred', "
189-
"'all', None}")
190-
191187
y_pred = estimator.predict(X)
192188
cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight,
193189
labels=labels, normalize=normalize)

sklearn/metrics/tests/test_classification.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,13 @@ def test_confusion_matrix_normalize(normalize, cm_dtype, expected_results):
526526
assert cm.dtype.kind == cm_dtype
527527

528528

529+
def test_confusion_matrix_normalize_wrong_option():
530+
y_test = [0, 0, 0, 0, 1, 1, 1, 1]
531+
y_pred = [0, 0, 0, 0, 0, 0, 0, 0]
532+
with pytest.raises(ValueError, match='normalize must be one of'):
533+
confusion_matrix(y_test, y_pred, normalize=True)
534+
535+
529536
def test_confusion_matrix_normalize_single_class():
530537
y_test = [0, 0, 0, 0, 1, 1, 1, 1]
531538
y_pred = [0, 0, 0, 0, 0, 0, 0, 0]

0 commit comments

Comments
 (0)
0